diff --git a/docs/nvbench_compare.md b/docs/nvbench_compare.md new file mode 100644 index 00000000..49ca96fd --- /dev/null +++ b/docs/nvbench_compare.md @@ -0,0 +1,511 @@ +# NVBench Compare + +`nvbench-compare` compares two NVBench JSON outputs and classifies matching +benchmark states as `SAME`, `FAST`, `SLOW`, `UNDECIDED`, or `????`. + +NVBench treats benchmark performance data as describing a timing interval over +which measured timings varied. The interval is not intended as a precise +statistical confidence interval; it is an intuitive representation of the +observed timing range used to decide whether two benchmark results are clearly +separated, clearly compatible, or ambiguous. + +The comparison is intentionally conservative. It reports `FAST` or `SLOW` only +when the timing intervals have a clear gap and the gap is confirmed in cycle +space when clock information is available. Ambiguous cases stay `UNDECIDED` +instead of forcing a pass or regression. + +## Common Invocations + +Compare two JSON files: + +```bash +nvbench-compare reference.json compare.json +``` + +Limit the comparison to one benchmark: + +```bash +nvbench-compare --benchmark copy_type_sweep reference.json compare.json +``` + +Limit the comparison to one benchmark and one axis value: + +```bash +nvbench-compare \ + --benchmark copy_type_sweep \ + --axis T=F32 \ + reference.json compare.json +``` + +Show interval details and decision reasons in the table: + +```bash +nvbench-compare --display explain reference.json compare.json +``` + +Generate Python code with bulk sample/frequency filenames for every displayed +row: + +```bash +nvbench-compare --bulk-debug-python /path/to/output.py reference.json compare.json +``` + +Compare selected devices. Device filters are paired by position, so this +compares reference device `0` against compare device `1`: + +```bash +nvbench-compare \ + --reference-devices 0 \ + --compare-devices 1 \ + reference.json compare.json +``` + +Use emoji status markers instead of ANSI colors, which is useful when pasting +output into GitHub issues or pull requests: + +```bash +nvbench-compare --no-color reference.json compare.json +``` + +Use a built-in comparison preset: + +```bash +nvbench-compare --preset permissive reference.json compare.json +``` + +Use custom settings from TOML: + +```bash +nvbench-compare --config compare.toml reference.json compare.json +``` + +Use a CLI preset as the base while preserving explicit TOML overrides: + +```bash +nvbench-compare --config compare.toml --preset strict reference.json compare.json +``` + +Print the effective default configuration: + +```bash +nvbench-compare --dump-config +``` + +Print the effective configuration for a built-in preset: + +```bash +nvbench-compare --preset permissive --dump-config +``` + +## Matching Inputs + +`nvbench-compare` matches benchmark states by benchmark name, device pairing, +axis filters, and state occurrence order within each device section. + +Device sections must match unless `--ignore-devices` is specified or explicit +device filters are used: + +```bash +nvbench-compare \ + --ignore-devices \ + reference.json compare.json +``` + +```bash +nvbench-compare \ + --reference-devices 0 \ + --compare-devices 1 \ + reference.json compare.json +``` + +The device filter value may be `all`, one non-negative integer device id, or a +comma-separated list of non-negative integer ids. Filtered reference and compare +device lists must have the same length; devices are paired by position. + +Benchmark and axis filters follow NVBench CLI scoping: + +```bash +nvbench-compare -b copy_type_sweep -a T=F32 reference.json compare.json +``` + +`-a` / `--axis` applies to the most recent `-b` / `--benchmark`, or to all +benchmarks if it appears before any benchmark filter. + +## Timing Data Used + +For each matched state, `nvbench-compare` extracts GPU timing summaries emitted +by NVBench cold measurements: + +- `min` +- `max` +- `mean` +- `stdev/absolute` +- `stdev/relative` +- `q1` +- `median` +- `q3` +- `iqr/absolute` +- `iqr/relative` +- `sm_clock_rate/mean` + +When JSON output is generated with the NVBench `--jsonbin` option, +sample-time and sample-frequency binary data are loaded lazily and used for +bulk-data confirmation. + +Bulk data read failures are treated as unavailable data and reported as +warnings. + +## Bulk Debug Python Output + +`--bulk-debug-python /path/to/output.py` writes a Python script to the specified +file. The generated script contains a `bulk_rows` list. Each entry corresponds +to one row that `nvbench-compare` prints in its display tables after all +benchmark, axis, device, and threshold filters are applied. + +Use `stdout` instead of a file path to print the generated Python code: + +```bash +nvbench-compare --bulk-debug-python stdout reference.json compare.json +``` + +Each `bulk_rows` entry includes: + +- `row_index`: zero-based index among displayed comparison rows +- `table_row_index`: zero-based index within the displayed table for a device + section +- `benchmark` +- `reference_json` and `compare_json` +- `reference_device_id` and `compare_device_id` +- `state_key` +- `occurrence` and `occurrence_count`, which disambiguate duplicate states +- `axis_values` +- `status`, `reason`, and `reason_message` +- sample and frequency filenames and counts for reference and compare data + +The generated script also defines `load_bulk_data(row)`, which reads the +float32 sample and frequency files for a selected row. + +Select the first displayed row: + +```python +row = bulk_rows[0] +arrays = load_bulk_data(row) +``` + +Select the second undecided row: + +```python +undecided = [row for row in bulk_rows if row["status"] == "UNDECIDED"] +row = undecided[1] +arrays = load_bulk_data(row) +``` + +If `-b` and `-a` narrow the report to one comparison of interest, the desired +entry is usually available positionally as `bulk_rows[0]`. If duplicate states +remain after filtering, use `occurrence` to distinguish them. + +## Time Estimates And Intervals + +`nvbench-compare` prefers robust timing summaries when both sides provide them: + +- center: `median` +- relative dispersion: `iqr/relative`, or `iqr/absolute` / `median` +- interval: `[min, q3]` + +If robust summaries are not available on both sides, it falls back to classical +summaries: + +- center: `mean` +- relative dispersion: `stdev/relative`, or `stdev/absolute` / `mean` +- interval: `[max(min, mean - stdev), min(max, mean + stdev)]` + +Centers and interval endpoints must be positive and finite. States with unusable +centers are not compared. + +## Decision Tree + +The comparison logic starts from `UNDECIDED` and upgrades only when enough +evidence is available. + +### 1. Check For A Clear Gap + +The reference and compare intervals are checked for non-overlap. + +`FAST` is possible when the compare interval is entirely below the reference +interval: + +```text +cmp.upper < ref.lower +(ref.lower - cmp.upper) / cmp.upper >= clear_gap.relative +``` + +`SLOW` is possible when the compare interval is entirely above the reference +interval: + +```text +cmp.lower > ref.upper +(cmp.lower - ref.upper) / ref.upper >= clear_gap.relative +``` + +These ratios are algebraically equivalent to checking a log-scale relative gap, +but avoid evaluating logarithms for every row. + +### 2. Confirm Clear Gap In Cycle Space + +If sample times and frequencies are available, `nvbench-compare` computes: + +```text +cycles = sample_time * sample_frequency +``` + +It then builds cycle intervals from the bulk cycle samples and requires the +cycle interval comparison to agree with the timing interval comparison. A timing +gap that is not confirmed by bulk cycle intervals is `UNDECIDED`. + +If bulk data are unavailable, `nvbench-compare` falls back to summary clock-rate +confirmation using `sm_clock_rate/mean`. If that clock-rate summary is missing +or invalid, the clear-gap decision remains `UNDECIDED`. + +### 3. Check Bulk-Data Compatibility For SAME + +When there is no clear gap and bulk sample/frequency data are available, +`nvbench-compare` compares both time samples and cycle samples using symmetric +nearest-neighbor coverage in log space. + +For each unique value in one run, the nearest unique value in the other run is +found. A value is covered when the nearest-neighbor log distance is within: + +```text +log(1 + same.center_relative) +``` + +Both directions must pass: + +- sample-weight coverage must be at least `bulk.sample_coverage` +- unique-support coverage must be at least `bulk.support_coverage` + +Sample-weight coverage uses occurrence counts. Unique-support coverage treats +each retained unique value equally. + +### 4. Fall Back To Summary SAME + +If bulk data are unavailable, summary data can still support `SAME` when all of +the following are true: + +- both relative dispersion values are finite +- `max(ref_noise, cmp_noise) <= same.relative_dispersion_ceiling` +- centers are close: + +```text +abs(ref.center - cmp.center) / min(ref.center, cmp.center) + <= same.center_relative +``` + +- intervals overlap strongly enough: + +```text +overlap_fraction >= same.overlap_fraction +``` + +If `sm_clock_rate/mean` is available on both sides, the same check must also be +confirmed in summary cycle space. If clock-rate summaries are unavailable, the +summary timing decision can still report `SAME`. + +### 5. Otherwise Report UNDECIDED + +If none of the clear-gap or same-result paths has enough evidence, +`nvbench-compare` reports `UNDECIDED` and records a reason in the summary. + +## What To Do With UNDECIDED Results + +`UNDECIDED` does not mean a benchmark improved or regressed. It means +`nvbench-compare` did not find enough evidence to classify the result as +`SAME`, `FAST`, or `SLOW`. + +Useful next steps are: + +- Re-run both measurements and collect JSON with bulk sample data: + +```bash +./benchmark --jsonbin reference.json +./benchmark --jsonbin compare.json +nvbench-compare reference.json compare.json +``` + +Here `./benchmark` is the NVBench-instrumented executable or Python script that +uses `cuda.bench`. + +- Use `--display explain` to inspect the interval, noise, and decision reason + for each compared state. +- Use `--bulk-debug-python /path/to/output.py` to generate Python code that + identifies sample and frequency files for every displayed row. +- If cold-start effects are expected, adjust cold warmup controls such as + `--cold-warmup-runs` and `--cold-max-warmup-walltime`. +- Try a different stopping criterion when the default does not collect useful + evidence. For example, use `--stopping-criterion entropy`, or use + `--stopping-criterion sample-count` with an explicit `--target-samples` + value. +- After collecting stable data, use `--dump-config` as a starting point for a + TOML config if the default comparison thresholds are not appropriate for the + benchmark or machine. + +## Configuration + +Configuration files are TOML. The current format version is `1`. + +```toml +version = 1 + +[preset] +name = "default" + +[clear_gap] +relative = 0.005 + +[same] +center_relative = 0.005 +overlap_fraction = 0.5 +relative_dispersion_ceiling = 0.02 + +[bulk] +sample_coverage = 0.97 +support_coverage = 0.8 + +[bulk.rare_support] +sample_fraction = 0.001 +max_removed_sample_fraction = 0.01 +``` + +The parser is strict. Unknown top-level tables, unknown keys, wrong nesting, +unsupported versions, invalid types, non-finite values, and out-of-range values +are rejected. + +TOML parsing is lazy. Python 3.11 and newer use the standard-library +`tomllib`; Python 3.10 requires the optional `tomli` package only when +`--config` is used. + +## Preset And Config Precedence + +Preset resolution is: + +1. Use `default` when neither TOML nor CLI selects a preset. +2. Use `[preset] name = "..."` from TOML as the base preset when present. +3. Use `--preset ...` as the base preset when present, overriding the TOML + preset selection. +4. Apply explicit TOML threshold values over whichever base preset was selected. + +For example, with this config: + +```toml +version = 1 + +[preset] +name = "permissive" + +[bulk] +sample_coverage = 0.99 +``` + +This command uses the `permissive` preset as the base and overrides only +`bulk.sample_coverage`: + +```bash +nvbench-compare --config compare.toml reference.json compare.json +``` + +This command uses the `strict` preset as the base, but still overrides +`bulk.sample_coverage` from TOML: + +```bash +nvbench-compare --config compare.toml --preset strict reference.json compare.json +``` + +## Built-In Presets + +Built-in presets are available through `--preset`. To inspect the exact values +for the default configuration, run: + +```bash +nvbench-compare --dump-config +``` + +To inspect a named preset, combine `--preset` with `--dump-config`: + +```bash +nvbench-compare --preset strict --dump-config +nvbench-compare --preset permissive --dump-config +``` + +This avoids duplicating preset values in documentation and keeps the displayed +configuration tied to the installed `nvbench-compare` version. + +## Configuration Keys + +### `clear_gap.relative` + +Valid range: `>= 0` + +Minimum relative gap required before a non-overlapping timing interval can be +classified as `FAST` or `SLOW`. + +### `same.center_relative` + +Valid range: `>= 0` + +Maximum relative center difference for summary `SAME` decisions. This value is +also used as the log-space tolerance for bulk nearest-neighbor coverage: + +```text +log(1 + same.center_relative) +``` + +### `same.overlap_fraction` + +Valid range: `0 <= value <= 1` + +Minimum interval overlap fraction required for summary `SAME` decisions. The +overlap is measured relative to the narrower interval. + +### `same.relative_dispersion_ceiling` + +Valid range: `>= 0` + +Maximum allowed relative dispersion for summary `SAME` decisions. + +### `bulk.sample_coverage` + +Valid range: `0 <= value <= 1` + +Minimum sample-weight coverage for bulk `SAME` decisions. This check uses +counts of repeated sample values, so common values carry more weight. + +### `bulk.support_coverage` + +Valid range: `0 <= value <= 1` + +Minimum unique-support coverage for bulk `SAME` decisions. This check treats +each retained unique value equally. + +### `bulk.rare_support.sample_fraction` + +Valid range: `0 <= value <= 1` + +Unique values with count below: + +```text +max(2, ceil(sample_fraction * total_sample_count)) +``` + +are considered rare for support coverage. + +This filter only affects unique-support coverage. Sample-weight coverage always +uses all samples. + +### `bulk.rare_support.max_removed_sample_fraction` + +Valid range: `0 <= value <= 1` + +Maximum sample mass that may be removed from unique-support coverage by the rare +value filter. If filtering would remove more sample mass than this, remove every +unique value, or operate on an all-unique dataset, support coverage falls back +to the full unique support. diff --git a/nvbench/detail/measure_cold.cu b/nvbench/detail/measure_cold.cu index aaf49d63..ed228548 100644 --- a/nvbench/detail/measure_cold.cu +++ b/nvbench/detail/measure_cold.cu @@ -295,12 +295,12 @@ void measure_cold_base::generate_summaries() summ.set_string("hide", "Hidden by default."); } { - auto &summ = m_state.add_summary("nv/cold/time/cpu/ir/absolute"); + auto &summ = m_state.add_summary("nv/cold/time/cpu/iqr/absolute"); summ.set_string("name", "IQR"); summ.set_string("hint", "duration"); summ.set_string("description", "Interquartile range of isolated kernel execution CPU times"); - const auto cpu_time_ir = cpu_time_third_quartile - cpu_time_first_quartile; - summ.set_float64("value", cpu_time_ir); + const auto cpu_time_iqr = cpu_time_third_quartile - cpu_time_first_quartile; + summ.set_float64("value", cpu_time_iqr); summ.set_string("hide", "Hidden by default."); } const auto cpu_robust_noise = statistics::compute_robust_noise(m_total_samples, @@ -309,7 +309,7 @@ void measure_cold_base::generate_summaries() cpu_time_third_quartile); if (cpu_robust_noise) { - auto &summ = m_state.add_summary("nv/cold/time/cpu/ir/relative"); + auto &summ = m_state.add_summary("nv/cold/time/cpu/iqr/relative"); summ.set_string("name", "Rel IQR"); summ.set_string("hint", "percentage"); summ.set_string("description", @@ -401,12 +401,12 @@ void measure_cold_base::generate_summaries() summ.set_string("hide", "Hidden by default."); } { - auto &summ = m_state.add_summary("nv/cold/time/gpu/ir/absolute"); + auto &summ = m_state.add_summary("nv/cold/time/gpu/iqr/absolute"); summ.set_string("name", "IQR"); summ.set_string("hint", "duration"); summ.set_string("description", "Interquartile range of isolated kernel execution GPU times"); - const auto cuda_time_ir = cuda_time_third_quartile - cuda_time_first_quartile; - summ.set_float64("value", cuda_time_ir); + const auto cuda_time_iqr = cuda_time_third_quartile - cuda_time_first_quartile; + summ.set_float64("value", cuda_time_iqr); summ.set_string("hide", "Hidden by default."); } const auto cuda_robust_noise = statistics::compute_robust_noise(m_total_samples, @@ -415,7 +415,7 @@ void measure_cold_base::generate_summaries() cuda_time_third_quartile); if (cuda_robust_noise) { - auto &summ = m_state.add_summary("nv/cold/time/gpu/ir/relative"); + auto &summ = m_state.add_summary("nv/cold/time/gpu/iqr/relative"); summ.set_string("name", "Rel IQR"); summ.set_string("hint", "percentage"); summ.set_string("description", diff --git a/nvbench/detail/measure_cpu_only.cxx b/nvbench/detail/measure_cpu_only.cxx index d4a71af1..c291ba67 100644 --- a/nvbench/detail/measure_cpu_only.cxx +++ b/nvbench/detail/measure_cpu_only.cxx @@ -205,13 +205,13 @@ void measure_cpu_only_base::generate_summaries() summ.set_string("hide", "Hidden by default."); } { - auto &summ = m_state.add_summary("nv/cpu_only/time/cpu/ir/absolute"); + auto &summ = m_state.add_summary("nv/cpu_only/time/cpu/iqr/absolute"); summ.set_string("name", "IQR"); summ.set_string("hint", "duration"); summ.set_string("description", "Interquartile range of CPU times of isolated kernel executions"); - const auto cpu_ir = cpu_third_quartile - cpu_first_quartile; - summ.set_float64("value", cpu_ir); + const auto cpu_iqr = cpu_third_quartile - cpu_first_quartile; + summ.set_float64("value", cpu_iqr); summ.set_string("hide", "Hidden by default."); } const auto cpu_robust_noise = statistics::compute_robust_noise(m_total_samples, @@ -220,7 +220,7 @@ void measure_cpu_only_base::generate_summaries() cpu_third_quartile); if (cpu_robust_noise) { - auto &summ = m_state.add_summary("nv/cpu_only/time/cpu/ir/relative"); + auto &summ = m_state.add_summary("nv/cpu_only/time/cpu/iqr/relative"); summ.set_string("name", "Rel IQR"); summ.set_string("hint", "percentage"); summ.set_string("description", diff --git a/python/scripts/nvbench_compare.py b/python/scripts/nvbench_compare.py index c6370332..230e48d2 100644 --- a/python/scripts/nvbench_compare.py +++ b/python/scripts/nvbench_compare.py @@ -1,12 +1,23 @@ #!/usr/bin/env python +# +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import argparse import math import os +import pprint import sys -from enum import StrEnum +import warnings +from collections import Counter +from collections.abc import Mapping +from dataclasses import dataclass, field, replace +from enum import Enum +from functools import cached_property +from typing import Any, BinaryIO, Callable, Protocol import jsondiff +import numpy as np import tabulate from colorama import Fore @@ -23,15 +34,670 @@ def version_tuple(v): tabulate_version = version_tuple(tabulate.__version__) -all_ref_devices = [] -all_cmp_devices = [] -config_count = 0 -unknown_count = 0 -failure_count = 0 -pass_count = 0 +GPU_TIME_MIN_TAG = "nv/cold/time/gpu/min" +GPU_TIME_MAX_TAG = "nv/cold/time/gpu/max" +GPU_TIME_MEAN_TAG = "nv/cold/time/gpu/mean" +GPU_TIME_STDEV_TAG = "nv/cold/time/gpu/stdev/absolute" +GPU_TIME_STDEV_RELATIVE_TAG = "nv/cold/time/gpu/stdev/relative" +GPU_TIME_Q1_TAG = "nv/cold/time/gpu/q1" +GPU_TIME_MEDIAN_TAG = "nv/cold/time/gpu/median" +GPU_TIME_Q3_TAG = "nv/cold/time/gpu/q3" +GPU_TIME_IQR_TAG = "nv/cold/time/gpu/iqr/absolute" +GPU_TIME_IQR_RELATIVE_TAG = "nv/cold/time/gpu/iqr/relative" +LEGACY_GPU_TIME_IR_TAG = "nv/cold/time/gpu/ir/absolute" +LEGACY_GPU_TIME_IR_RELATIVE_TAG = "nv/cold/time/gpu/ir/relative" +GPU_SM_CLOCK_RATE_MEAN_TAG = "nv/cold/sm_clock_rate/mean" +SAMPLE_TIMES_TAG = "nv/json/bin:nv/cold/sample_times" +SAMPLE_FREQUENCIES_TAG = "nv/json/freqs-bin:nv/cold/sample_freqs" + +# The reader returns an object supporting the buffer protocol. Python 3.10 does +# not provide a standard Buffer type annotation. +Float32Reader = Callable[[str], object] + + +class TomlModule(Protocol): + # TOML support is imported lazily. This protocol documents the narrow + # tomllib/tomli module surface used by this script. + @property + def TOMLDecodeError(self) -> type[BaseException]: ... + + def load(self, fp: BinaryIO, /) -> dict[str, Any]: ... + + +def read_float32_file(filename: str) -> object: + return np.fromfile(filename, dtype=" ComparisonThresholds: + try: + return COMPARISON_THRESHOLD_PRESETS[preset_name] + except KeyError as exc: + raise ValueError(f"unknown comparison preset {preset_name!r}") from exc + + +def load_toml_module() -> TomlModule: + try: + import tomllib + + return tomllib + except ModuleNotFoundError: + try: + import tomli + + return tomli + except ModuleNotFoundError as exc: + raise ValueError( + "TOML config support requires Python 3.11+ or the tomli package" + ) from exc + + +def validate_config_table(value: object, table_name: str) -> None: + if not isinstance(value, Mapping): + raise ValueError(f"config table [{table_name}] must be a TOML table") + + +def validate_config_float(value: object, key: str, field_name: str) -> float: + if isinstance(value, bool) or not isinstance(value, int | float): + raise ValueError(f"config value {key!r} must be a finite number") + + value = float(value) + if not math.isfinite(value): + raise ValueError(f"config value {key!r} must be finite") + + minimum, maximum = COMPARISON_THRESHOLD_RANGES[field_name] + if value < minimum: + raise ValueError(f"config value {key!r} must be >= {minimum:g}") + if maximum is not None and value > maximum: + raise ValueError(f"config value {key!r} must be <= {maximum:g}") + return value + + +def parse_config_section( + table: Mapping[str, Any], section_name: str +) -> dict[str, float]: + validate_config_table(table, section_name) + known_keys = COMPARISON_CONFIG_KEYS[section_name] + unknown_keys = set(table) - set(known_keys) + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise ValueError(f"unknown config key(s) in [{section_name}]: {unknown}") + + overrides = {} + for key, field_name in known_keys.items(): + if key not in table: + continue + full_key = f"{section_name}.{key}" + overrides[field_name] = validate_config_float(table[key], full_key, field_name) + return overrides + + +def parse_comparison_config_data( + config_data: Mapping[str, Any], +) -> tuple[str | None, dict[str, float]]: + if not isinstance(config_data, Mapping): + raise ValueError("comparison config must be a TOML table") + + unknown_top_level = set(config_data) - ({"version"} | COMPARISON_CONFIG_TABLES) + if unknown_top_level: + unknown = ", ".join(sorted(unknown_top_level)) + raise ValueError(f"unknown top-level config key(s): {unknown}") + + version = config_data.get("version") + if isinstance(version, bool) or not isinstance(version, int): + raise ValueError( + f"comparison config must specify integer version = {COMPARISON_CONFIG_VERSION}" + ) + if version != COMPARISON_CONFIG_VERSION: + raise ValueError( + f"unsupported comparison config version {version!r}; " + f"expected {COMPARISON_CONFIG_VERSION}" + ) + + preset_name = None + if "preset" in config_data: + preset_table = config_data["preset"] + validate_config_table(preset_table, "preset") + unknown_keys = set(preset_table) - {"name"} + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise ValueError(f"unknown config key(s) in [preset]: {unknown}") + if "name" in preset_table: + preset_name = preset_table["name"] + if not isinstance(preset_name, str): + raise ValueError("config value 'preset.name' must be a string") + get_comparison_thresholds(preset_name) + + overrides = {} + for section_name in ("clear_gap", "same"): + if section_name in config_data: + overrides.update( + parse_config_section(config_data[section_name], section_name) + ) + + if "bulk" in config_data: + bulk_table = config_data["bulk"] + validate_config_table(bulk_table, "bulk") + known_bulk_keys = set(COMPARISON_CONFIG_KEYS["bulk"]) | {"rare_support"} + unknown_keys = set(bulk_table) - known_bulk_keys + if unknown_keys: + unknown = ", ".join(sorted(unknown_keys)) + raise ValueError(f"unknown config key(s) in [bulk]: {unknown}") + + bulk_values = { + key: value for key, value in bulk_table.items() if key != "rare_support" + } + overrides.update(parse_config_section(bulk_values, "bulk")) + if "rare_support" in bulk_table: + overrides.update( + parse_config_section(bulk_table["rare_support"], "bulk.rare_support") + ) + + return preset_name, overrides + + +def read_comparison_config_file( + config_path: str | os.PathLike[str], +) -> tuple[str | None, dict[str, float]]: + toml_module = load_toml_module() + try: + with open(config_path, "rb") as config_file: + config_data = toml_module.load(config_file) + except toml_module.TOMLDecodeError as exc: + raise ValueError( + f"failed to parse comparison config {config_path!r}: {exc}" + ) from exc + except OSError as exc: + raise ValueError( + f"failed to read comparison config {config_path!r}: {exc}" + ) from exc + + return parse_comparison_config_data(config_data) + + +def resolve_comparison_thresholds( + cli_preset_name: str | None = None, + config_path: str | os.PathLike[str] | None = None, +) -> tuple[str, ComparisonThresholds]: + config_preset_name = None + config_overrides: dict[str, float] = {} + if config_path is not None: + config_preset_name, config_overrides = read_comparison_config_file(config_path) + + preset_name = cli_preset_name or config_preset_name or COMPARISON_DEFAULT_PRESET + thresholds = replace(get_comparison_thresholds(preset_name), **config_overrides) + return preset_name, thresholds + + +def format_toml_float(value: float) -> str: + return repr(float(value)) + + +def dump_comparison_config(preset_name: str, thresholds: ComparisonThresholds) -> str: + lines = [ + f"version = {COMPARISON_CONFIG_VERSION}", + "", + "[preset]", + f'name = "{preset_name}"', + "", + "[clear_gap]", + f"relative = {format_toml_float(thresholds.clear_gap_relative)}", + "", + "[same]", + f"center_relative = {format_toml_float(thresholds.same_center_relative)}", + f"overlap_fraction = {format_toml_float(thresholds.same_overlap_fraction)}", + "relative_dispersion_ceiling = " + f"{format_toml_float(thresholds.same_relative_dispersion_ceiling)}", + "", + "[bulk]", + f"sample_coverage = {format_toml_float(thresholds.bulk_same_sample_coverage)}", + f"support_coverage = {format_toml_float(thresholds.bulk_same_support_coverage)}", + "", + "[bulk.rare_support]", + "sample_fraction = " + f"{format_toml_float(thresholds.bulk_support_rare_sample_fraction)}", + "max_removed_sample_fraction = " + f"{format_toml_float(thresholds.bulk_support_max_removed_sample_fraction)}", + ] + return "\n".join(lines) + "\n" + + +@dataclass(frozen=True) +class SupportFilterInfo: + activated: bool + reason: str + removed_sample_fraction: float + + +@dataclass(frozen=True) +class Float32BinarySource: + count: int + filename: str + json_dir: str + description: str + reader: Float32Reader = read_float32_file + + @cached_property + def values(self) -> np.ndarray | None: + return read_float32_binary( + self.count, self.filename, self.json_dir, self.description, self.reader + ) + + +@dataclass(frozen=True) +class GpuTimingData: + minimum: float | None + maximum: float | None + mean: float | None + stdev: float | None + stdev_relative: float | None + first_quartile: float | None + median: float | None + third_quartile: float | None + interquartile_range: float | None + interquartile_range_relative: float | None + sm_clock_rate_mean: float | None = None + sample_source: Float32BinarySource | None = None + frequency_source: Float32BinarySource | None = None + + @cached_property + def samples(self) -> np.ndarray | None: + if self.sample_source is None: + return None + return self.sample_source.values + + @cached_property + def frequencies(self) -> np.ndarray | None: + if self.frequency_source is None: + return None + return self.frequency_source.values + + +@dataclass(frozen=True) +class BulkDebugOutput: + destination: str + + @property + def is_stdout(self) -> bool: + return self.destination.lower() == "stdout" + + +@dataclass(frozen=True) +class TimeEstimate: + center: float | None + relative_dispersion: float | None + + +@dataclass(frozen=True) +class TimingInterval: + lower: float + upper: float + center: float + + +class ComparisonStatus(str, Enum): + UNKNOWN = "????" + UNDECIDED = "UNDECIDED" + SAME = "SAME" + FAST = "FAST" + SLOW = "SLOW" + + +@dataclass(frozen=True) +class DecisionReason: + code: str + message: str + severity: float = 0.0 + + +REASON_DISPLAY_CODES = { + "bulk_cycle_data_unusable": "bc-bad", + "bulk_cycle_gap_not_confirmed": "bc-gap-miss", + "bulk_cycle_same": "bc-same", + "bulk_cycle_support_mismatch": "bc-sup-miss", + "bulk_data_unavailable": "bulk-miss", + "bulk_same": "bulk-same", + "bulk_time_data_unusable": "bt-bad", + "bulk_time_same": "bt-same", + "bulk_time_support_mismatch": "bt-sup-miss", + "centers_not_close": "centers-far", + "clear_gap_confirmed_by_bulk_cycles": "bc-gap", + "clear_gap_confirmed_by_summary_cycles": "sc-gap", + "cycle_same_not_confirmed": "sc-same-miss", + "invalid_clock_rate": "clk-bad", + "missing_clock_rate": "clk-miss", + "missing_interval": "int-miss", + "no_clear_gap": "no-gap", + "noise_too_high": "noise-high", + "noise_unavailable": "noise-miss", + "same_confirmed_by_cycles": "sc-same", + "same_summary": "sum-same", + "same_without_clock_rate": "same-no-clk", + "summary_cycle_gap_not_confirmed": "sc-gap-miss", + "weak_interval_overlap": "weak-overlap", +} + + +def format_reason_display_code(code): + return REASON_DISPLAY_CODES.get(code, code) + + +def format_reason_legend_entries(reason_legend): + entries = [] + for code, reason_summary in sorted(reason_legend.items()): + if code == reason_summary.canonical_code.replace("_", "-"): + continue + entries.append(f"{code} = {reason_summary.canonical_code}") + return entries + + +@dataclass(frozen=True) +class TimingDecision: + status: ComparisonStatus + reason: DecisionReason + + +@dataclass(frozen=True) +class SummaryComparison: + ref_interval: TimingInterval | None + cmp_interval: TimingInterval | None + ref_estimate: TimeEstimate + cmp_estimate: TimeEstimate + ref_time: float + cmp_time: float + ref_noise: float | None + cmp_noise: float | None + diff: float + frac_diff: float + diff_interval: tuple[float, float] | None + frac_diff_interval: tuple[float, float] | None + max_noise: float | None + status: ComparisonStatus + reason: DecisionReason + + +@dataclass +class DecisionReasonSummary: + count: int = 0 + canonical_code: str = "" + message: str = "" + severity: float = 0.0 + + +@dataclass +class ComparisonStats: + config_count: int = 0 + pass_count: int = 0 + improvement_count: int = 0 + regression_count: int = 0 + undecided_count: int = 0 + unknown_count: int = 0 + undecided_reasons: dict[str, DecisionReasonSummary] = field(default_factory=dict) + reason_legend: dict[str, DecisionReasonSummary] = field(default_factory=dict) + + @staticmethod + def record_reason_summary( + summaries: dict[str, DecisionReasonSummary], + reason: DecisionReason, + *, + use_display_code, + ) -> None: + display_code = ( + format_reason_display_code(reason.code) if use_display_code else reason.code + ) + summary = summaries.setdefault( + display_code, DecisionReasonSummary(canonical_code=reason.code) + ) + if summary.count == 0 or reason.severity > summary.severity: + summary.canonical_code = reason.code + summary.message = reason.message + summary.severity = reason.severity + summary.count += 1 + + def record( + self, status: ComparisonStatus, reason: DecisionReason | None = None + ) -> None: + self.config_count += 1 + if reason is not None: + self.record_reason_summary( + self.reason_legend, reason, use_display_code=True + ) + if status == ComparisonStatus.UNKNOWN: + self.unknown_count += 1 + elif status == ComparisonStatus.UNDECIDED: + self.undecided_count += 1 + if reason is not None: + self.record_reason_summary( + self.undecided_reasons, reason, use_display_code=False + ) + elif status == ComparisonStatus.SAME: + self.pass_count += 1 + elif status == ComparisonStatus.FAST: + self.improvement_count += 1 + else: + self.regression_count += 1 + + +DeviceInfo = Mapping[str, Any] + + +@dataclass(frozen=True) +class ComparisonRunData: + # Device metadata fields are treated as read-only; stats is intentionally + # mutable and accumulates counts across one comparison run. + stats: ComparisonStats + ref_devices: tuple[DeviceInfo, ...] + cmp_devices: tuple[DeviceInfo, ...] + +@dataclass(frozen=True) +class BenchmarkFilterScope: + benchmark_name: str + axis_filters: list[dict] -class Emoji(StrEnum): + +@dataclass(frozen=True) +class BenchmarkFilterPlan: + global_axis_filters: list[dict] + benchmark_scopes: list[BenchmarkFilterScope] + + +class OrderedBenchmarkFilterAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + actions = getattr(namespace, self.dest, None) + actions = [] if actions is None else list(actions) + action_kind = "axis" if option_string in {"-a", "--axis"} else "benchmark" + actions.append((action_kind, values)) + setattr(namespace, self.dest, actions) + + +def state_match_key(state): + device_prefix = f"Device={state['device']}" + state_name = state["name"] + if state_name == device_prefix: + return "" + if state_name.startswith(f"{device_prefix} "): + return state_name[len(device_prefix) + 1 :] + return state_name + + +def group_states_by_match_key(states): + grouped = {} + for state in states: + grouped.setdefault(state_match_key(state), []).append(state) + return grouped + + +def state_group_counts(grouped_states): + return Counter( + {state_name: len(states) for state_name, states in grouped_states.items()} + ) + + +def format_device_ids(device_ids): + return ", ".join(str(device_id) for device_id in device_ids) + + +def parse_device_filter(device_arg, option_name): + device_arg = device_arg.strip() + if device_arg.lower() == "all": + return None + + values = [value.strip() for value in device_arg.split(",")] + if not all(values): + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) + + try: + device_ids = [int(value) for value in values] + except ValueError as exc: + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) from exc + if any(device_id < 0 for device_id in device_ids): + raise ValueError( + f"{option_name} must be 'all', a non-negative integer, " + "or comma-separated non-negative integers" + ) + return device_ids + + +def select_devices(all_devices, device_filter, option_name): + if device_filter is None: + return list(all_devices) + + devices_by_id = {device["id"]: device for device in all_devices} + missing_ids = [ + device_id for device_id in device_filter if device_id not in devices_by_id + ] + if missing_ids: + raise ValueError( + f"{option_name} requested device id(s) not present in input: " + f"{format_device_ids(missing_ids)}" + ) + + return [devices_by_id[device_id] for device_id in device_filter] + + +def resolve_benchmark_device_ids(bench, device_filter, option_name): + if device_filter is None: + return list(bench["devices"]) + + benchmark_device_ids = set(bench["devices"]) + missing_ids = [ + device_id + for device_id in device_filter + if device_id not in benchmark_device_ids + ] + if missing_ids: + raise ValueError( + f"benchmark {bench['name']!r} does not contain {option_name} " + f"device id(s): {format_device_ids(missing_ids)}" + ) + + return device_filter + + +def require_matching_device_sections(reference_device_filter, compare_device_filter): + return reference_device_filter is None and compare_device_filter is None + + +# TODO(opavlyk): replace with Emoji(StrEnum) after EOL of Python 3.10 +class Emoji(str, Enum): YELLOW = "\U0001f7e1" BLUE = "\U0001f535" GREEN = "\U0001f7e2" @@ -42,13 +708,1064 @@ class Emoji(StrEnum): def colorize(msg: str, fore: Fore, emoji: Emoji, no_color: bool) -> str: if no_color: prefix = "" - if emoji_s := str(emoji): + if emoji_s := emoji.value: prefix = f"{emoji_s} " return f"{prefix}{msg}" else: return f"{fore}{msg}{Fore.RESET}" +def lookup_summary(summaries, tag): + return next((summary for summary in summaries if summary["tag"] == tag), None) + + +def extract_summary_data_value(summary, name, expected_type): + summary_tag = summary.get("tag", "") + for value_data in summary.get("data", []): + if value_data.get("name") != name: + continue + + value_type = value_data.get("type") + if value_type != expected_type: + raise ValueError( + f"summary {summary_tag!r} field {name!r} has type " + f"{value_type!r}; expected {expected_type!r}" + ) + if "value" not in value_data: + raise ValueError(f"summary {summary_tag!r} field {name!r} is missing value") + return value_data["value"] + + raise ValueError(f"summary {summary_tag!r} is missing field {name!r}") + + +def extract_summary_value(summary): + return extract_summary_data_value(summary, "value", "float64") + + +def normalize_float_value(value, *, null_value=None): + if value is None: + return null_value + return float(value) + + +def extract_summary_float(summaries, tag, *, null_value=None): + summary = lookup_summary(summaries, tag) + if summary is None: + return None + return normalize_float_value(extract_summary_value(summary), null_value=null_value) + + +def extract_summary_float_with_fallback( + summaries: list[dict[str, Any]], + primary_tag: str, + fallback_tag: str, + *, + null_value: float | None = None, +) -> float | None: + value = extract_summary_float(summaries, primary_tag, null_value=null_value) + if value is not None: + return value + return extract_summary_float(summaries, fallback_tag, null_value=null_value) + + +def extract_binary_filename(summary): + value = extract_summary_data_value(summary, "filename", "string") + if not isinstance(value, str): + raise ValueError( + f"summary {summary.get('tag', '')!r} field 'filename' " + "value must be a string" + ) + return value + + +def extract_binary_size(summary): + value = extract_summary_data_value(summary, "size", "int64") + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ValueError( + f"summary {summary.get('tag', '')!r} field 'size' " + f"value {value!r} is not an int64" + ) from exc + + +def extract_binary_meta(summaries, tag): + summary = lookup_summary(summaries, tag) + if summary is None: + return None, None + return extract_binary_size(summary), extract_binary_filename(summary) + + +def resolve_binary_filename(json_dir, binary_filename): + if os.path.isabs(binary_filename): + return binary_filename + + json_relative_filename = os.path.join(json_dir, binary_filename) + if os.path.exists(json_relative_filename): + return json_relative_filename + + parent_relative_filename = os.path.join(os.path.dirname(json_dir), binary_filename) + if os.path.exists(parent_relative_filename): + return parent_relative_filename + + if os.path.exists(binary_filename): + return binary_filename + + return json_relative_filename + + +def warn_unavailable_bulk_data(description, message): + warnings.warn( + f"Could not use NVBench {description} data: {message}; treating it as unavailable", + RuntimeWarning, + stacklevel=3, + ) + + +def read_float32_binary(count, filename, json_dir, description, reader): + filename = resolve_binary_filename(json_dir, filename) + try: + values = np.frombuffer(reader(filename), dtype=" str | None: + if source is None: + return None + return resolve_binary_filename(source.json_dir, source.filename) + + +def get_bulk_source_count(source: Float32BinarySource | None) -> int | None: + if source is None: + return None + return source.count + + +def make_axis_debug_values(axis_values, axes) -> list[dict[str, Any]]: + return [ + { + "name": axis_value.get("name"), + "type": axis_value.get("type"), + "value": axis_value.get("value"), + "display": format_axis_value(axis_value["name"], axis_value, axes), + } + for axis_value in axis_values + ] + + +def make_bulk_debug_row( + *, + row_index: int, + table_row_index: int, + benchmark_name: str, + ref_json_path: str | None, + cmp_json_path: str | None, + ref_device_id: int, + cmp_device_id: int, + cmp_state_name: str, + occurrence: int, + occurrence_count: int, + axis_values, + axes, + ref_timing: GpuTimingData, + cmp_timing: GpuTimingData, + comparison: SummaryComparison, +) -> dict[str, Any]: + return { + "row_index": row_index, + "table_row_index": table_row_index, + "benchmark": benchmark_name, + "reference_json": ref_json_path, + "compare_json": cmp_json_path, + "reference_device_id": ref_device_id, + "compare_device_id": cmp_device_id, + "state_key": cmp_state_name, + "occurrence": occurrence, + "occurrence_count": occurrence_count, + "axis_values": make_axis_debug_values(axis_values, axes), + "status": comparison.status.value, + "reason": comparison.reason.code, + "reason_message": comparison.reason.message, + "reference_time": comparison.ref_time, + "compare_time": comparison.cmp_time, + "fractional_difference": comparison.frac_diff, + "reference_sample_filename": resolve_bulk_source_filename( + ref_timing.sample_source + ), + "reference_sample_count": get_bulk_source_count(ref_timing.sample_source), + "reference_frequency_filename": resolve_bulk_source_filename( + ref_timing.frequency_source + ), + "reference_frequency_count": get_bulk_source_count(ref_timing.frequency_source), + "compare_sample_filename": resolve_bulk_source_filename( + cmp_timing.sample_source + ), + "compare_sample_count": get_bulk_source_count(cmp_timing.sample_source), + "compare_frequency_filename": resolve_bulk_source_filename( + cmp_timing.frequency_source + ), + "compare_frequency_count": get_bulk_source_count(cmp_timing.frequency_source), + } + + +def format_bulk_debug_python(bulk_rows: list[dict[str, Any]]) -> str: + return ( + "# Generated by nvbench-compare --bulk-debug-python.\n" + "import numpy as np\n\n" + f"bulk_rows = {pprint.pformat(bulk_rows, sort_dicts=False)}\n\n" + "def read_float32(filename, expected_count=None):\n" + " if filename is None:\n" + " return None\n" + " values = np.fromfile(filename, dtype=' None: + if output is None: + return + + script = format_bulk_debug_python(bulk_rows) + if output.is_stdout: + print(script, end="") + return + + with open(output.destination, "w", encoding="utf-8") as output_file: + output_file.write(script) + + +def compute_relative_dispersion(dispersion, center): + if ( + dispersion is None + or center is None + or center <= 0 + or not math.isfinite(center) + or dispersion < 0 + or math.isnan(dispersion) + ): + return None + return dispersion / center + + +def is_positive_finite(value): + return value is not None and value > 0.0 and math.isfinite(value) + + +def make_timing_interval(lower, upper, center): + if ( + not is_positive_finite(lower) + or not is_positive_finite(upper) + or not is_positive_finite(center) + or lower > center + or center > upper + ): + return None + return TimingInterval(lower=lower, upper=upper, center=center) + + +def compute_timing_interval(timing): + if ( + is_positive_finite(timing.minimum) + and is_positive_finite(timing.first_quartile) + and is_positive_finite(timing.median) + and is_positive_finite(timing.third_quartile) + and timing.minimum <= timing.first_quartile + and timing.first_quartile <= timing.median + and timing.median <= timing.third_quartile + ): + return make_timing_interval( + lower=timing.minimum, + upper=timing.third_quartile, + center=timing.median, + ) + + if ( + is_positive_finite(timing.minimum) + and is_positive_finite(timing.maximum) + and is_positive_finite(timing.mean) + and is_positive_finite(timing.stdev) + and timing.minimum <= timing.mean + and timing.mean <= timing.maximum + ): + return make_timing_interval( + lower=max(timing.minimum, timing.mean - timing.stdev), + upper=min(timing.maximum, timing.mean + timing.stdev), + center=timing.mean, + ) + + return None + + +def compute_timing_interval_from_samples(samples): + values = positive_finite_array(samples) + if values is None: + return None + + first_quartile, median, third_quartile = np.quantile(values, [0.25, 0.5, 0.75]) + return make_timing_interval( + lower=np.min(values), + upper=third_quartile, + center=median, + ) + + +def make_decision(status, code, message, *, severity=0.0): + return TimingDecision( + status=status, + reason=DecisionReason(code=code, message=message, severity=severity), + ) + + +def compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds): + # These ratios are equivalent to log(ref/cmp) >= log(1 + delta), but avoid + # evaluating logarithms on every comparison. + if cmp_interval.upper < ref_interval.lower: + gap = ref_interval.lower - cmp_interval.upper + if gap / cmp_interval.upper >= thresholds.clear_gap_relative: + return ComparisonStatus.FAST + if cmp_interval.lower > ref_interval.upper: + gap = cmp_interval.lower - ref_interval.upper + if gap / ref_interval.upper >= thresholds.clear_gap_relative: + return ComparisonStatus.SLOW + return None + + +def compute_diff_interval(ref_interval, cmp_interval): + return ( + cmp_interval.lower - ref_interval.upper, + cmp_interval.upper - ref_interval.lower, + ) + + +def compute_frac_diff_interval(ref_interval, cmp_interval): + return ( + cmp_interval.lower / ref_interval.upper - 1.0, + cmp_interval.upper / ref_interval.lower - 1.0, + ) + + +def centers_are_close(ref_center, cmp_center, thresholds): + if not is_positive_finite(ref_center) or not is_positive_finite(cmp_center): + return False + return ( + abs(ref_center - cmp_center) / min(ref_center, cmp_center) + <= thresholds.same_center_relative + ) + + +def interval_overlap_fraction(ref_interval, cmp_interval): + intersection_lower = max(ref_interval.lower, cmp_interval.lower) + intersection_upper = min(ref_interval.upper, cmp_interval.upper) + if intersection_upper < intersection_lower: + return 0.0 + + ref_width = ref_interval.upper - ref_interval.lower + cmp_width = cmp_interval.upper - cmp_interval.lower + min_width = min(ref_width, cmp_width) + if min_width > 0.0: + return (intersection_upper - intersection_lower) / min_width + + if ref_width == 0.0 and cmp_width == 0.0: + return 1.0 if ref_interval.lower == cmp_interval.lower else 0.0 + + if ref_width == 0.0: + return ( + 1.0 + if cmp_interval.lower <= ref_interval.lower <= cmp_interval.upper + else 0.0 + ) + + return ( + 1.0 if ref_interval.lower <= cmp_interval.lower <= ref_interval.upper else 0.0 + ) + + +def intervals_overlap_strongly(ref_interval, cmp_interval, thresholds): + return ( + interval_overlap_fraction(ref_interval, cmp_interval) + >= thresholds.same_overlap_fraction + ) + + +def nearest_distances_to_sorted(target, source): + pos = np.searchsorted(source, target, side="left") + left = np.clip(pos - 1, 0, len(source) - 1) + right = np.clip(pos, 0, len(source) - 1) + return np.minimum( + np.abs(target - source[left]), + np.abs(target - source[right]), + ) + + +def symmetric_nearest_distances(x, y): + # This is O(N log M + M log N), but runs in NumPy C code and operates on + # unique supports. If this becomes a bottleneck for very large supports, + # add an optional O(N + M) two-pass merge helper to cuda.bench and fall back + # to this implementation when cuda.bench is unavailable. + return nearest_distances_to_sorted(x, y), nearest_distances_to_sorted(y, x) + + +def symmetric_nearest_log_distances(x, y): + return symmetric_nearest_distances(np.log(x), np.log(y)) + + +def compute_effective_support_mask(counts, thresholds): + """Return the unique-value mask used for support coverage. + + Sample-weight coverage always uses all values. Support coverage may ignore + low-count values only when their total sample mass is small; otherwise it + falls back to full support, preserving all-unique datasets. + """ + counts = np.asarray(counts) + total_count = np.sum(counts) + if ( + len(counts) == 0 + or total_count <= 0 + or thresholds.bulk_support_rare_sample_fraction <= 0.0 + or thresholds.bulk_support_max_removed_sample_fraction <= 0.0 + ): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="disabled", + removed_sample_fraction=0.0, + ) + + if np.all(counts == 1): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="all_values_unique", + removed_sample_fraction=0.0, + ) + + min_count = max( + 2, + math.ceil(thresholds.bulk_support_rare_sample_fraction * total_count), + ) + support_mask = counts >= min_count + if np.all(support_mask): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="no_rare_values", + removed_sample_fraction=0.0, + ) + if not np.any(support_mask): + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="would_remove_all_support", + removed_sample_fraction=0.0, + ) + + removed_sample_fraction = np.sum(counts[~support_mask]) / total_count + if removed_sample_fraction > thresholds.bulk_support_max_removed_sample_fraction: + return np.ones(len(counts), dtype=bool), SupportFilterInfo( + activated=False, + reason="would_remove_too_much_mass", + removed_sample_fraction=0.0, + ) + + return support_mask, SupportFilterInfo( + activated=True, + reason="filtered", + removed_sample_fraction=removed_sample_fraction, + ) + + +def format_support_filter_info(filter_info): + if filter_info.activated: + return f"on({format_coverage(filter_info.removed_sample_fraction)})" + + if filter_info.reason == "no_rare_values": + return "off(no rare values)" + if filter_info.reason == "all_values_unique": + return "off(all values unique)" + if filter_info.reason == "would_remove_too_much_mass": + return "off(would remove too much mass)" + if filter_info.reason == "would_remove_all_support": + return "off(would remove all support)" + return "off(disabled)" + + +def compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds): + ref_unique, ref_counts = np.unique_counts(ref_values) + cmp_unique, cmp_counts = np.unique_counts(cmp_values) + if len(ref_unique) == 0 or len(cmp_unique) == 0: + return None + + ref_distances, cmp_distances = symmetric_nearest_log_distances( + ref_unique, cmp_unique + ) + tolerance = math.log1p(thresholds.same_center_relative) + ref_covered = ref_distances <= tolerance + cmp_covered = cmp_distances <= tolerance + ref_support_mask, ref_filter_info = compute_effective_support_mask( + ref_counts, thresholds + ) + cmp_support_mask, cmp_filter_info = compute_effective_support_mask( + cmp_counts, thresholds + ) + + return { + "ref_sample": np.sum(ref_counts[ref_covered]) / np.sum(ref_counts), + "cmp_sample": np.sum(cmp_counts[cmp_covered]) / np.sum(cmp_counts), + "ref_support": np.mean(ref_covered[ref_support_mask]), + "cmp_support": np.mean(cmp_covered[cmp_support_mask]), + "ref_support_filter": ref_filter_info, + "cmp_support_filter": cmp_filter_info, + } + + +def coverages_support_same(coverages, thresholds): + return ( + coverages["ref_sample"] >= thresholds.bulk_same_sample_coverage + and coverages["cmp_sample"] >= thresholds.bulk_same_sample_coverage + and coverages["ref_support"] >= thresholds.bulk_same_support_coverage + and coverages["cmp_support"] >= thresholds.bulk_same_support_coverage + ) + + +def format_coverage_threshold(threshold): + return f"{threshold * 100.0:.1f}%" + + +def format_coverage(value): + return f"{value * 100.0:.1f}%" + + +def make_bulk_coverage_mismatch_decision(label, coverages, thresholds): + sample_threshold = format_coverage_threshold(thresholds.bulk_same_sample_coverage) + support_threshold = format_coverage_threshold(thresholds.bulk_same_support_coverage) + sample_deficit = max( + thresholds.bulk_same_sample_coverage - coverages["ref_sample"], + thresholds.bulk_same_sample_coverage - coverages["cmp_sample"], + 0.0, + ) + support_deficit = max( + thresholds.bulk_same_support_coverage - coverages["ref_support"], + thresholds.bulk_same_support_coverage - coverages["cmp_support"], + 0.0, + ) + severity = max(sample_deficit, support_deficit) + return make_decision( + ComparisonStatus.UNDECIDED, + f"bulk_{label}_support_mismatch", + f"sample: min(ref={format_coverage(coverages['ref_sample'])}, " + f"cmp={format_coverage(coverages['cmp_sample'])}) >= {sample_threshold}; " + f"support: min(ref={format_coverage(coverages['ref_support'])}, " + f"cmp={format_coverage(coverages['cmp_support'])}) >= {support_threshold}", + severity=severity, + ) + + +def positive_finite_array(values): + if values is None or len(values) == 0: + return None + + array = np.asarray(values, dtype=np.float64) + if np.all(np.isfinite(array) & (array > 0.0)): + return array + return None + + +def get_bulk_time_and_cycles(timing): + samples = positive_finite_array(timing.samples) + frequencies = positive_finite_array(timing.frequencies) + if samples is None or frequencies is None: + return None + if len(samples) != len(frequencies): + return None + return samples, samples * frequencies + + +def scale_interval(interval, scale): + if not is_positive_finite(scale): + return None + return make_timing_interval( + lower=interval.lower * scale, + upper=interval.upper * scale, + center=interval.center * scale, + ) + + +def confirm_clear_gap_with_clock_rate( + status, ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds +): + if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "missing_clock_rate", + "clear timing gap was not confirmed because SM clock summaries are unavailable", + ) + + ref_cycles = scale_interval(ref_interval, ref_timing.sm_clock_rate_mean) + cmp_cycles = scale_interval(cmp_interval, cmp_timing.sm_clock_rate_mean) + if ref_cycles is None or cmp_cycles is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "invalid_clock_rate", + "clear timing gap was not confirmed because SM clock summaries are invalid", + ) + + cycle_status = compare_intervals_for_clear_gap(ref_cycles, cmp_cycles, thresholds) + if cycle_status == status: + return make_decision( + status, + "clear_gap_confirmed_by_summary_cycles", + "clear timing gap was confirmed by SM-clock-adjusted cycle intervals", + ) + return make_decision( + ComparisonStatus.UNDECIDED, + "summary_cycle_gap_not_confirmed", + "clear timing gap was not confirmed by SM-clock-adjusted cycle intervals", + ) + + +def confirm_clear_gap_with_bulk_cycles(status, ref_timing, cmp_timing, thresholds): + ref_bulk = get_bulk_time_and_cycles(ref_timing) + cmp_bulk = get_bulk_time_and_cycles(cmp_timing) + if ref_bulk is None or cmp_bulk is None: + return None + + _, ref_cycles = ref_bulk + _, cmp_cycles = cmp_bulk + ref_cycle_interval = compute_timing_interval_from_samples(ref_cycles) + cmp_cycle_interval = compute_timing_interval_from_samples(cmp_cycles) + if ref_cycle_interval is None or cmp_cycle_interval is None: + return None + + cycle_status = compare_intervals_for_clear_gap( + ref_cycle_interval, cmp_cycle_interval, thresholds + ) + if cycle_status == status: + return make_decision( + status, + "clear_gap_confirmed_by_bulk_cycles", + "clear timing gap was confirmed by bulk cycle intervals", + ) + return make_decision( + ComparisonStatus.UNDECIDED, + "bulk_cycle_gap_not_confirmed", + "clear timing gap was not confirmed by bulk cycle intervals", + ) + + +def compare_timings_for_clear_gap(ref_timing, cmp_timing, thresholds): + ref_interval = compute_timing_interval(ref_timing) + cmp_interval = compute_timing_interval(cmp_timing) + if ref_interval is None or cmp_interval is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "missing_interval", + "could not construct comparable timing intervals", + ) + + status = compare_intervals_for_clear_gap(ref_interval, cmp_interval, thresholds) + if status is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "no_clear_gap", + "timing intervals do not have a sufficient clear gap", + ) + + bulk_decision = confirm_clear_gap_with_bulk_cycles( + status, ref_timing, cmp_timing, thresholds + ) + if bulk_decision is not None: + return bulk_decision + + return confirm_clear_gap_with_clock_rate( + status, ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds + ) + + +def compare_intervals_for_same(ref_interval, cmp_interval, thresholds): + if not centers_are_close(ref_interval.center, cmp_interval.center, thresholds): + return make_decision( + ComparisonStatus.UNDECIDED, + "centers_not_close", + "timing centers are not close enough to declare same", + ) + if not intervals_overlap_strongly(ref_interval, cmp_interval, thresholds): + return make_decision( + ComparisonStatus.UNDECIDED, + "weak_interval_overlap", + "timing intervals do not overlap strongly enough to declare same", + ) + return make_decision( + ComparisonStatus.SAME, + "same_summary", + "timing centers are close and intervals overlap strongly", + ) + + +def confirm_same_with_clock_rate( + ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds +): + if ref_timing.sm_clock_rate_mean is None or cmp_timing.sm_clock_rate_mean is None: + return make_decision( + ComparisonStatus.SAME, + "same_without_clock_rate", + "timing centers are close and intervals overlap strongly; SM clock summaries are unavailable", + ) + + ref_cycles = scale_interval(ref_interval, ref_timing.sm_clock_rate_mean) + cmp_cycles = scale_interval(cmp_interval, cmp_timing.sm_clock_rate_mean) + if ref_cycles is None or cmp_cycles is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "invalid_clock_rate", + "same decision was not confirmed because SM clock summaries are invalid", + ) + + decision = compare_intervals_for_same(ref_cycles, cmp_cycles, thresholds) + if decision.status == ComparisonStatus.SAME: + return make_decision( + ComparisonStatus.SAME, + "same_confirmed_by_cycles", + "timing and SM-clock-adjusted cycle intervals both support same", + ) + return make_decision( + ComparisonStatus.UNDECIDED, + "cycle_same_not_confirmed", + "same decision was not confirmed by SM-clock-adjusted cycle intervals", + ) + + +def compare_values_for_bulk_same(ref_values, cmp_values, *, label, thresholds): + coverages = compute_nearest_neighbor_coverages(ref_values, cmp_values, thresholds) + if coverages is None: + return make_decision( + ComparisonStatus.UNDECIDED, + f"bulk_{label}_data_unusable", + f"bulk {label} data is empty or unusable", + ) + if coverages_support_same(coverages, thresholds): + return make_decision( + ComparisonStatus.SAME, + f"bulk_{label}_same", + f"bulk {label} nearest-neighbor coverage supports same", + ) + return make_bulk_coverage_mismatch_decision(label, coverages, thresholds) + + +def compare_timings_for_bulk_same(ref_timing, cmp_timing, thresholds): + ref_bulk = get_bulk_time_and_cycles(ref_timing) + cmp_bulk = get_bulk_time_and_cycles(cmp_timing) + if ref_bulk is None or cmp_bulk is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "bulk_data_unavailable", + "bulk sample time and frequency data are unavailable", + ) + + ref_times, ref_cycles = ref_bulk + cmp_times, cmp_cycles = cmp_bulk + + time_decision = compare_values_for_bulk_same( + ref_times, cmp_times, label="time", thresholds=thresholds + ) + if time_decision.status != ComparisonStatus.SAME: + return time_decision + + cycle_decision = compare_values_for_bulk_same( + ref_cycles, cmp_cycles, label="cycle", thresholds=thresholds + ) + if cycle_decision.status != ComparisonStatus.SAME: + return cycle_decision + + return make_decision( + ComparisonStatus.SAME, + "bulk_same", + "bulk time and cycle nearest-neighbor coverage both support same", + ) + + +def compare_timings_for_same(ref_timing, cmp_timing, ref_noise, cmp_noise, thresholds): + if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise): + return make_decision( + ComparisonStatus.UNDECIDED, + "noise_unavailable", + "relative dispersion is unavailable or non-finite", + ) + if max(ref_noise, cmp_noise) > thresholds.same_relative_dispersion_ceiling: + return make_decision( + ComparisonStatus.UNDECIDED, + "noise_too_high", + "relative dispersion is too high to declare same", + ) + + ref_interval = compute_timing_interval(ref_timing) + cmp_interval = compute_timing_interval(cmp_timing) + if ref_interval is None or cmp_interval is None: + return make_decision( + ComparisonStatus.UNDECIDED, + "missing_interval", + "could not construct comparable timing intervals", + ) + + decision = compare_intervals_for_same(ref_interval, cmp_interval, thresholds) + if decision.status != ComparisonStatus.SAME: + return decision + + return confirm_same_with_clock_rate( + ref_timing, cmp_timing, ref_interval, cmp_interval, thresholds + ) + + +def has_robust_estimate(summary): + return summary.median is not None and ( + summary.interquartile_range_relative is not None + or summary.interquartile_range is not None + ) + + +def has_mean_estimate(summary): + return summary.mean is not None and ( + summary.stdev_relative is not None or summary.stdev is not None + ) + + +def select_relative_dispersion(relative_dispersion, absolute_dispersion, center): + if relative_dispersion is not None: + return relative_dispersion + return compute_relative_dispersion(absolute_dispersion, center) + + +def compute_common_time_estimates(ref_timing, cmp_timing): + if has_robust_estimate(ref_timing) and has_robust_estimate(cmp_timing): + return ( + TimeEstimate( + center=ref_timing.median, + relative_dispersion=select_relative_dispersion( + ref_timing.interquartile_range_relative, + ref_timing.interquartile_range, + ref_timing.median, + ), + ), + TimeEstimate( + center=cmp_timing.median, + relative_dispersion=select_relative_dispersion( + cmp_timing.interquartile_range_relative, + cmp_timing.interquartile_range, + cmp_timing.median, + ), + ), + ) + + if has_mean_estimate(ref_timing) and has_mean_estimate(cmp_timing): + return ( + TimeEstimate( + center=ref_timing.mean, + relative_dispersion=select_relative_dispersion( + ref_timing.stdev_relative, ref_timing.stdev, ref_timing.mean + ), + ), + TimeEstimate( + center=cmp_timing.mean, + relative_dispersion=select_relative_dispersion( + cmp_timing.stdev_relative, cmp_timing.stdev, cmp_timing.mean + ), + ), + ) + + return ( + TimeEstimate( + center=ref_timing.mean, + relative_dispersion=compute_relative_dispersion( + ref_timing.stdev, ref_timing.mean + ), + ), + TimeEstimate( + center=cmp_timing.mean, + relative_dispersion=compute_relative_dispersion( + cmp_timing.stdev, cmp_timing.mean + ), + ), + ) + + +def compare_gpu_timings(ref_timing, cmp_timing, comparison_thresholds=None): + if comparison_thresholds is None: + comparison_thresholds = ComparisonThresholds() + + ref_estimate, cmp_estimate = compute_common_time_estimates(ref_timing, cmp_timing) + + cmp_time = cmp_estimate.center + ref_time = ref_estimate.center + + if cmp_time is None or ref_time is None: + return None + + if not math.isfinite(cmp_time) or not math.isfinite(ref_time): + return None + + if cmp_time <= 0.0 or ref_time <= 0.0: + return None + + cmp_noise = cmp_estimate.relative_dispersion + ref_noise = ref_estimate.relative_dispersion + + ref_interval = compute_timing_interval(ref_timing) + cmp_interval = compute_timing_interval(cmp_timing) + diff = cmp_time - ref_time + frac_diff = diff / ref_time + diff_interval = None + frac_diff_interval = None + if ref_interval is not None and cmp_interval is not None: + diff_interval = compute_diff_interval(ref_interval, cmp_interval) + frac_diff_interval = compute_frac_diff_interval(ref_interval, cmp_interval) + + if not has_finite_noise(ref_noise) or not has_finite_noise(cmp_noise): + max_noise = None + else: + max_noise = max(ref_noise, cmp_noise) + + decision = compare_timings_for_clear_gap( + ref_timing, cmp_timing, comparison_thresholds + ) + if decision.status == ComparisonStatus.UNDECIDED and decision.reason.code in { + "no_clear_gap", + "missing_interval", + }: + bulk_decision = compare_timings_for_bulk_same( + ref_timing, cmp_timing, comparison_thresholds + ) + if bulk_decision.reason.code == "bulk_data_unavailable": + decision = compare_timings_for_same( + ref_timing, cmp_timing, ref_noise, cmp_noise, comparison_thresholds + ) + else: + decision = bulk_decision + + return SummaryComparison( + ref_interval=ref_interval, + cmp_interval=cmp_interval, + ref_estimate=ref_estimate, + cmp_estimate=cmp_estimate, + ref_time=ref_time, + cmp_time=cmp_time, + ref_noise=ref_noise, + cmp_noise=cmp_noise, + diff=diff, + frac_diff=frac_diff, + diff_interval=diff_interval, + frac_diff_interval=frac_diff_interval, + max_noise=max_noise, + status=decision.status, + reason=decision.reason, + ) + + def find_matching_bench(needle, haystack): for hay in haystack: if hay["name"] == needle["name"]: @@ -69,8 +1786,8 @@ def format_int64_axis_value(axis_name, axis_value, axes): value = int(axis_value["value"]) if axis_flags == "pow2": value = math.log2(value) - return "2^%d" % value - return "%d" % value + return f"2^{value:.0f}" + return f"{value:d}" def format_float64_axis_value(axis_name, axis_value, axes): @@ -78,11 +1795,11 @@ def format_float64_axis_value(axis_name, axis_value, axes): def format_type_axis_value(axis_name, axis_value, axes): - return "%s" % axis_value["value"] + return f"{axis_value['value']}" def format_string_axis_value(axis_name, axis_value, axes): - return "%s" % axis_value["value"] + return f"{axis_value['value']}" def format_axis_value(axis_name, axis_value, axes): @@ -98,10 +1815,10 @@ def format_axis_value(axis_name, axis_value, axes): return format_string_axis_value(axis_name, axis_value, axes) -def make_display(name: str, display_values: [list[str]]) -> str: +def make_display(name: str, display_values: list[str]) -> str: open_bracket, close_bracket = ("[", "]") if len(display_values) > 1 else ("", "") - display_values = ",".join(display_values) - return f"{name}={open_bracket}{display_values}{close_bracket}" + joined_values = ",".join(display_values) + return f"{name}={open_bracket}{joined_values}{close_bracket}" def parse_axis_filters(axis_args): @@ -152,6 +1869,53 @@ def parse_axis_filters(axis_args): return filters +def build_benchmark_filter_plan(filter_actions): + global_axis_args = [] + benchmark_scopes = [] + current_scope = None + + for action_kind, action_value in filter_actions or []: + if action_kind == "benchmark": + current_scope = {"benchmark_name": action_value, "axis_args": []} + benchmark_scopes.append(current_scope) + elif current_scope is None: + global_axis_args.append(action_value) + else: + current_scope["axis_args"].append(action_value) + + return BenchmarkFilterPlan( + global_axis_filters=parse_axis_filters(global_axis_args), + benchmark_scopes=[ + BenchmarkFilterScope( + benchmark_name=scope["benchmark_name"], + axis_filters=parse_axis_filters(scope["axis_args"]), + ) + for scope in benchmark_scopes + ], + ) + + +def benchmark_is_selected(benchmark_name, filter_plan): + return not filter_plan.benchmark_scopes or any( + scope.benchmark_name == benchmark_name for scope in filter_plan.benchmark_scopes + ) + + +def axis_filter_groups_for_benchmark(benchmark_name, filter_plan): + if not filter_plan.benchmark_scopes: + return [filter_plan.global_axis_filters] + + matching_scopes = [ + scope + for scope in filter_plan.benchmark_scopes + if scope.benchmark_name == benchmark_name + ] + return [ + filter_plan.global_axis_filters + scope.axis_filters + for scope in matching_scopes + ] + + def matches_axis_filters(state, axis_filters): if not axis_filters: return True @@ -175,6 +1939,23 @@ def matches_axis_filters(state, axis_filters): return True +def matches_axis_filter_groups(state, axis_filter_groups): + return any( + matches_axis_filters(state, axis_filters) for axis_filters in axis_filter_groups + ) + + +def matching_axis_filters(state, axis_filter_groups): + return next( + ( + axis_filters + for axis_filters in axis_filter_groups + if matches_axis_filters(state, axis_filters) + ), + [], + ) + + def format_duration(seconds): if seconds >= 1: multiplier = 1.0 @@ -188,16 +1969,331 @@ def format_duration(seconds): else: multiplier = 1e6 units = "us" - return "%0.3f %s" % (seconds * multiplier, units) + return f"{seconds * multiplier:0.3f} {units}" + + +def select_duration_units(*seconds_values): + max_abs_seconds = max(abs(value) for value in seconds_values) + if max_abs_seconds >= 1: + return 1.0, "s" + if max_abs_seconds >= 1e-3: + return 1e3, "ms" + return 1e6, "us" + + +def duration_precision_for_center(center, delta_multiplier): + center_multiplier, _ = select_duration_units(center) + center_quantum = 10.0**-3 * (delta_multiplier / center_multiplier) + if center_quantum >= 1.0: + return 0 + return int(math.ceil(-math.log10(center_quantum))) + + +def format_duration_range(bounds): + if bounds is None: + return "n/a" + lower, upper = bounds + multiplier, units = select_duration_units(lower, upper) + return f"[{lower * multiplier:0.2f}, {upper * multiplier:0.2f}] {units}" + + +def format_timing_with_interval( + center, interval, *, center_width=None, interval_width=None +): + if center is None: + return "n/a" + if interval is None: + if center_width is not None and interval_width is not None: + center_multiplier, center_units = select_duration_units(center) + center_text = f"{center * center_multiplier:0.3f}" + center_text = f"{center_text:>{center_width}}" + if interval_width == 0: + return f"{center_text} {center_units}" + return f"{center_text} {' ' * interval_width} {center_units}" + return format_duration(center) + + lower_delta = interval.lower - interval.center + upper_delta = interval.upper - interval.center + center_multiplier, center_units = select_duration_units(center) + delta_multiplier, delta_units = select_duration_units(lower_delta, upper_delta) + precision = duration_precision_for_center(center, delta_multiplier) + if center_units == delta_units: + center_text = f"{center * center_multiplier:0.3f}" + interval_text = ( + f"[{lower_delta * delta_multiplier:+0.{precision}f}, " + f"{upper_delta * delta_multiplier:+0.{precision}f}]" + ) + if center_width is not None: + center_text = f"{center_text:>{center_width}}" + if interval_width is not None: + interval_text = f"{interval_text:>{interval_width}}" + return f"{center_text} {interval_text} {center_units}" + + return ( + f"{format_duration(center)} " + f"[{lower_delta * delta_multiplier:+0.{precision}f}, " + f"{upper_delta * delta_multiplier:+0.{precision}f}] {delta_units}" + ) + + +def longest_common_prefix(strings): + if not strings: + return "" + prefix = strings[0] + for text in strings[1:]: + while not text.startswith(prefix): + prefix = prefix[:-1] + if not prefix: + return "" + return prefix + + +def common_numeric_prefix_is_useful(prefix): + if "." not in prefix: + return False + + numeric_digits = sum(char.isdigit() for char in prefix) + fractional_prefix = prefix.split(".", 1)[1] + fractional_digits = sum(char.isdigit() for char in fractional_prefix) + return numeric_digits >= 2 and fractional_digits >= 1 + + +def align_interval_values(values, widths=None): + if widths is None: + widths = [max(len(value) for value in values)] * len(values) + return [f"{value:>{width}}" for value, width in zip(values, widths, strict=True)] + + +def explicit_interval_values(center, interval): + multiplier, units = select_duration_units( + interval.lower, interval.center, interval.upper + ) + return ( + [ + f"{interval.lower * multiplier:0.3f}", + f"{interval.center * multiplier:0.3f}", + f"{interval.upper * multiplier:0.3f}", + ], + units, + ) + + +def explicit_interval_column_widths(comparisons, center_getter, interval_getter): + widths = [0, 0, 0] + for comparison in comparisons: + center = center_getter(comparison) + interval = interval_getter(comparison) + if center is None or interval is None: + continue + + values, _ = explicit_interval_values(center, interval) + prefix = longest_common_prefix(values) + if common_numeric_prefix_is_useful(prefix): + continue + + widths = [max(width, len(value)) for width, value in zip(widths, values)] + return widths + + +def format_timing_with_explicit_interval(center, interval, *, value_widths=None): + if center is None: + return "n/a" + if interval is None: + return format_duration(center) + + values, units = explicit_interval_values(center, interval) + prefix = longest_common_prefix(values) + if not common_numeric_prefix_is_useful(prefix): + values = align_interval_values(values, value_widths) + return f"[{values[0]} | {values[1]} | {values[2]}] {units}" + + suffixes = [value[len(prefix) :] for value in values] + return f"{prefix}[{suffixes[0]} | {suffixes[1]} | {suffixes[2]}] {units}" def format_percentage(percentage): - # When there aren't enough samples for a meaningful noise measurement, - # the noise is recorded as infinity. Unfortunately, JSON spec doesn't - # allow for inf, so these get turned into null. if percentage is None: + return "n/a" + if math.isnan(percentage): + return "n/a" + if math.isinf(percentage): return "inf" - return "%0.2f%%" % (percentage * 100.0) + return f"{percentage * 100.0:0.2f}%" + + +def format_percentage_bounds(bounds, status): + if bounds is None: + return "n/a" + lower, upper = bounds + if status == ComparisonStatus.FAST: + return f"<= {upper * 100.0:+0.1f}%" + if status == ComparisonStatus.SLOW: + return f">= {lower * 100.0:+0.1f}%" + return f"in [{lower * 100.0:+0.1f}%, {upper * 100.0:+0.1f}%]" + + +def format_change(comparison): + if comparison.status not in {ComparisonStatus.FAST, ComparisonStatus.SLOW}: + return "" + return format_percentage_bounds(comparison.frac_diff_interval, comparison.status) + + +def get_display_headers(display): + if display == "legacy": + return ( + [ + "Ref Time", + "Ref Noise", + "Cmp Time", + "Cmp Noise", + "Diff", + "%Diff", + "Status", + ], + ["right", "right", "right", "right", "right", "right", "center"], + ) + if display == "explain": + return ( + [ + "Ref [Lo | Ce | Hi]", + "Cmp [Lo | Ce | Hi]", + "Ref Noise", + "Cmp Noise", + "Reason", + "Change", + "Status", + ], + ["right", "right", "right", "right", "left", "right", "center"], + ) + return ( + ["Ref", "Cmp", "Change", "Status"], + ["right", "right", "right", "center"], + ) + + +def append_display_row(row, comparison, no_color, display): + if display == "legacy": + row.append(format_duration(comparison.ref_time)) + row.append(format_percentage(comparison.ref_noise)) + row.append(format_duration(comparison.cmp_time)) + row.append(format_percentage(comparison.cmp_noise)) + row.append(format_duration(comparison.diff)) + row.append(format_percentage(comparison.frac_diff)) + row.append(colorize_comparison_status(comparison.status, no_color)) + return + + row.append( + format_timing_with_interval(comparison.ref_time, comparison.ref_interval) + ) + row.append( + format_timing_with_interval(comparison.cmp_time, comparison.cmp_interval) + ) + if display == "explain": + row[-2] = format_timing_with_explicit_interval( + comparison.ref_time, comparison.ref_interval + ) + row[-1] = format_timing_with_explicit_interval( + comparison.cmp_time, comparison.cmp_interval + ) + row.append(format_percentage(comparison.ref_noise)) + row.append(format_percentage(comparison.cmp_noise)) + row.append(format_reason_display_code(comparison.reason.code)) + row.append(format_change(comparison)) + row.append(colorize_comparison_status(comparison.status, no_color)) + + +def align_explain_interval_columns(rows, comparisons, axis_count): + ref_widths = explicit_interval_column_widths( + comparisons, + lambda comparison: comparison.ref_time, + lambda comparison: comparison.ref_interval, + ) + cmp_widths = explicit_interval_column_widths( + comparisons, + lambda comparison: comparison.cmp_time, + lambda comparison: comparison.cmp_interval, + ) + for row, comparison in zip(rows, comparisons, strict=True): + row[axis_count] = format_timing_with_explicit_interval( + comparison.ref_time, comparison.ref_interval, value_widths=ref_widths + ) + row[axis_count + 1] = format_timing_with_explicit_interval( + comparison.cmp_time, comparison.cmp_interval, value_widths=cmp_widths + ) + + +def timing_interval_column_widths(comparisons, center_getter, interval_getter): + center_width = 0 + interval_width = 0 + for comparison in comparisons: + center = center_getter(comparison) + if center is None: + continue + + center_multiplier, center_units = select_duration_units(center) + center_text = f"{center * center_multiplier:0.3f}" + center_width = max(center_width, len(center_text)) + + interval = interval_getter(comparison) + if interval is None: + continue + + lower_delta = interval.lower - interval.center + upper_delta = interval.upper - interval.center + delta_multiplier, delta_units = select_duration_units(lower_delta, upper_delta) + if center_units != delta_units: + continue + + precision = duration_precision_for_center(center, delta_multiplier) + interval_text = ( + f"[{lower_delta * delta_multiplier:+0.{precision}f}, " + f"{upper_delta * delta_multiplier:+0.{precision}f}]" + ) + interval_width = max(interval_width, len(interval_text)) + + return center_width, interval_width + + +def align_timing_interval_columns(rows, comparisons, axis_count): + ref_center_width, ref_interval_width = timing_interval_column_widths( + comparisons, + lambda comparison: comparison.ref_time, + lambda comparison: comparison.ref_interval, + ) + cmp_center_width, cmp_interval_width = timing_interval_column_widths( + comparisons, + lambda comparison: comparison.cmp_time, + lambda comparison: comparison.cmp_interval, + ) + for row, comparison in zip(rows, comparisons, strict=True): + row[axis_count] = format_timing_with_interval( + comparison.ref_time, + comparison.ref_interval, + center_width=ref_center_width, + interval_width=ref_interval_width, + ) + row[axis_count + 1] = format_timing_with_interval( + comparison.cmp_time, + comparison.cmp_interval, + center_width=cmp_center_width, + interval_width=cmp_interval_width, + ) + + +def has_finite_noise(noise): + return noise is not None and math.isfinite(noise) + + +def colorize_comparison_status(status, no_color): + if status == ComparisonStatus.UNKNOWN: + return colorize(status.value, Fore.YELLOW, Emoji.YELLOW, no_color) + if status == ComparisonStatus.UNDECIDED: + return colorize(status.value, Fore.YELLOW, Emoji.YELLOW, no_color) + if status == ComparisonStatus.SAME: + return colorize(status.value, Fore.BLUE, Emoji.BLUE, no_color) + if status == ComparisonStatus.FAST: + return colorize(status.value, Fore.GREEN, Emoji.GREEN, no_color) + return colorize(status.value, Fore.RED, Emoji.RED, no_color) def format_axis_values(axis_values, axes, axis_filters=None): @@ -292,16 +2388,28 @@ def plot_comparison_entries(entries, title=None, dark=False): def compare_benches( + run_data: ComparisonRunData, ref_benches, cmp_benches, threshold, plot_along, plot, dark, - axis_filters, - benchmark_filters, + filter_plan, no_color, + reference_device_filter=None, + compare_device_filter=None, + ref_json_dir=None, + cmp_json_dir=None, + ref_json_path=None, + cmp_json_path=None, + comparison_thresholds=None, + display="intervals", + bulk_debug_rows=None, ): + if comparison_thresholds is None: + comparison_thresholds = ComparisonThresholds() + if plot_along: import matplotlib.pyplot as plt import seaborn as sns @@ -314,12 +2422,28 @@ def compare_benches( ref_bench = find_matching_bench(cmp_bench, ref_benches) if not ref_bench: continue - if benchmark_filters and cmp_bench["name"] not in benchmark_filters: + if not benchmark_is_selected(cmp_bench["name"], filter_plan): continue + axis_filter_groups = axis_filter_groups_for_benchmark( + cmp_bench["name"], filter_plan + ) + + cmp_device_ids = resolve_benchmark_device_ids( + cmp_bench, compare_device_filter, "--compare-devices" + ) + ref_device_ids = resolve_benchmark_device_ids( + ref_bench, reference_device_filter, "--reference-devices" + ) + if len(cmp_device_ids) != len(ref_device_ids): + raise ValueError( + f"benchmark {cmp_bench['name']!r} has {len(ref_device_ids)} " + f"reference device(s) but {len(cmp_device_ids)} compare device(s); " + "nvbench_compare pairs devices by position, so each compared " + "benchmark must contain the same number of devices" + ) print(f"""# {cmp_bench["name"]}\n""") - cmp_device_ids = cmp_bench["devices"] axes = cmp_bench["axes"] ref_states = ref_bench["states"] cmp_states = cmp_bench["states"] @@ -328,36 +2452,53 @@ def compare_benches( headers = [x["name"] for x in axes] colalign = ["center"] * len(headers) - - headers.append("Ref Time") - colalign.append("right") - headers.append("Ref Noise") - colalign.append("right") - headers.append("Cmp Time") - colalign.append("right") - headers.append("Cmp Noise") - colalign.append("right") - headers.append("Diff") - colalign.append("right") - headers.append("%Diff") - colalign.append("right") - headers.append("Status") - colalign.append("center") - - for cmp_device_id in cmp_device_ids: - rows = [] - plot_data = {"cmp": {}, "ref": {}, "cmp_noise": {}, "ref_noise": {}} - - for cmp_state in cmp_states: - cmp_state_name = cmp_state["name"] - ref_state = next( - filter(lambda st: st["name"] == cmp_state_name, ref_states), None + display_headers, display_colalign = get_display_headers(display) + headers.extend(display_headers) + colalign.extend(display_colalign) + + for cmp_device_index, cmp_device_id in enumerate(cmp_device_ids): + ref_device_id = ref_device_ids[cmp_device_index] + ref_device_states = [ + state + for state in ref_states + if state["device"] == ref_device_id + and matches_axis_filter_groups(state, axis_filter_groups) + ] + cmp_device_states = [ + state + for state in cmp_states + if state["device"] == cmp_device_id + and matches_axis_filter_groups(state, axis_filter_groups) + ] + ref_states_by_name = group_states_by_match_key(ref_device_states) + cmp_states_by_name = group_states_by_match_key(cmp_device_states) + ref_state_counts = state_group_counts(ref_states_by_name) + cmp_state_counts = state_group_counts(cmp_states_by_name) + if ref_state_counts != cmp_state_counts: + raise ValueError( + f"benchmark {cmp_bench['name']!r} device pair " + f"ref={ref_device_id} cmp={cmp_device_id} has mismatched " + f"state occurrences: ref={dict(ref_state_counts)}, " + f"cmp={dict(cmp_state_counts)}" ) - if not ref_state: - continue - if not matches_axis_filters(cmp_state, axis_filters): - continue + rows = [] + row_comparisons = [] + plot_data: dict[str, dict[str, dict[float, float | None]]] = { + "cmp": {}, + "ref": {}, + "cmp_noise": {}, + "ref_noise": {}, + } + counters: dict[str, int] = {} + + for cmp_state in cmp_device_states: + cmp_state_name = state_match_key(cmp_state) + occurrence = counters.get(cmp_state_name, 0) + counters[cmp_state_name] = occurrence + 1 + # Duplicate state names are matched by occurrence order within + # the filtered device section. + ref_state = ref_states_by_name[cmp_state_name][occurrence] axis_values = cmp_state["axis_values"] if not axis_values: axis_values = [] @@ -373,121 +2514,69 @@ def compare_benches( if not ref_summaries or not cmp_summaries: continue - def lookup_summary(summaries, tag): - return next(filter(lambda s: s["tag"] == tag, summaries), None) - - cmp_time_summary = lookup_summary( - cmp_summaries, "nv/cold/time/gpu/mean" - ) - ref_time_summary = lookup_summary( - ref_summaries, "nv/cold/time/gpu/mean" - ) - cmp_noise_summary = lookup_summary( - cmp_summaries, "nv/cold/time/gpu/stdev/relative" - ) - ref_noise_summary = lookup_summary( - ref_summaries, "nv/cold/time/gpu/stdev/relative" - ) - # TODO: Use other timings, too. Maybe multiple rows, with a # "Timing" column + values "CPU/GPU/Batch"? - if not all( - [ - cmp_time_summary, - ref_time_summary, - cmp_noise_summary, - ref_noise_summary, - ] - ): + cmp_gpu_time = extract_gpu_timing_data(cmp_summaries, cmp_json_dir) + ref_gpu_time = extract_gpu_timing_data(ref_summaries, ref_json_dir) + comparison = compare_gpu_timings( + ref_gpu_time, cmp_gpu_time, comparison_thresholds + ) + if comparison is None: continue - def extract_value(summary): - summary_data = summary["data"] - value_data = next( - filter(lambda v: v["name"] == "value", summary_data) - ) - assert value_data["type"] == "float64" - return value_data["value"] - - cmp_time = extract_value(cmp_time_summary) - ref_time = extract_value(ref_time_summary) - cmp_noise = extract_value(cmp_noise_summary) - ref_noise = extract_value(ref_noise_summary) - - # Convert string encoding to expected numerics: - cmp_time = float(cmp_time) - ref_time = float(ref_time) - - diff = cmp_time - ref_time - frac_diff = diff / ref_time - - if ref_noise and cmp_noise: - ref_noise = float(ref_noise) - cmp_noise = float(cmp_noise) - min_noise = min(ref_noise, cmp_noise) - elif ref_noise: - ref_noise = float(ref_noise) - min_noise = ref_noise - elif cmp_noise: - cmp_noise = float(cmp_noise) - min_noise = cmp_noise - else: - min_noise = None # Noise is inf - if plot_along: - axis_name = [] - axis_value = "--" + axis_name_parts = [] + axis_value = None for av in axis_values: if av["name"] != plot_along: - axis_name.append(f"""{av["name"]} = {av["value"]}""") + axis_name_parts.append(f"""{av["name"]} = {av["value"]}""") else: axis_value = float(av["value"]) - axis_name = ", ".join(axis_name) - - if axis_name not in plot_data["cmp"]: - plot_data["cmp"][axis_name] = {} - plot_data["ref"][axis_name] = {} - plot_data["cmp_noise"][axis_name] = {} - plot_data["ref_noise"][axis_name] = {} - - plot_data["cmp"][axis_name][axis_value] = cmp_time - plot_data["ref"][axis_name][axis_value] = ref_time - plot_data["cmp_noise"][axis_name][axis_value] = cmp_noise - plot_data["ref_noise"][axis_name][axis_value] = ref_noise - - global config_count - global unknown_count - global pass_count - global failure_count - - config_count += 1 - if not min_noise: - unknown_count += 1 - status_label = "????" - status = colorize(status_label, Fore.YELLOW, Emoji.YELLOW, no_color) - elif abs(frac_diff) <= min_noise: - pass_count += 1 - status_label = "SAME" - status = colorize(status_label, Fore.BLUE, Emoji.BLUE, no_color) - elif diff < 0: - failure_count += 1 - status_label = "FAST" - status = colorize(status_label, Fore.GREEN, Emoji.GREEN, no_color) - else: - failure_count += 1 - status_label = "SLOW" - status = colorize(status_label, Fore.RED, Emoji.RED, no_color) - - if abs(frac_diff) >= threshold: - row.append(format_duration(ref_time)) - row.append(format_percentage(ref_noise)) - row.append(format_duration(cmp_time)) - row.append(format_percentage(cmp_noise)) - row.append(format_duration(diff)) - row.append(format_percentage(frac_diff)) - row.append(status) + if axis_value is not None: + axis_name = ", ".join(axis_name_parts) + + if axis_name not in plot_data["cmp"]: + plot_data["cmp"][axis_name] = {} + plot_data["ref"][axis_name] = {} + plot_data["cmp_noise"][axis_name] = {} + plot_data["ref_noise"][axis_name] = {} + + plot_data["cmp"][axis_name][axis_value] = comparison.cmp_time + plot_data["ref"][axis_name][axis_value] = comparison.ref_time + plot_data["cmp_noise"][axis_name][axis_value] = ( + comparison.cmp_noise + ) + plot_data["ref_noise"][axis_name][axis_value] = ( + comparison.ref_noise + ) + + run_data.stats.record(comparison.status, comparison.reason) + if abs(comparison.frac_diff) >= threshold: + axis_filters = matching_axis_filters(cmp_state, axis_filter_groups) + append_display_row(row, comparison, no_color, display) rows.append(row) + row_comparisons.append(comparison) + if bulk_debug_rows is not None: + bulk_debug_rows.append( + make_bulk_debug_row( + row_index=len(bulk_debug_rows), + table_row_index=len(rows) - 1, + benchmark_name=cmp_bench["name"], + ref_json_path=ref_json_path, + cmp_json_path=cmp_json_path, + ref_device_id=ref_device_id, + cmp_device_id=cmp_device_id, + cmp_state_name=cmp_state_name, + occurrence=occurrence, + occurrence_count=cmp_state_counts[cmp_state_name], + axis_values=axis_values, + axes=axes, + ref_timing=ref_gpu_time, + cmp_timing=cmp_gpu_time, + comparison=comparison, + ) + ) if plot: axis_label = format_axis_values(axis_values, axes, axis_filters) if axis_label: @@ -495,31 +2584,40 @@ def extract_value(summary): else: label = cmp_bench["name"] cmp_device = find_device_by_id( - cmp_state["device"], all_cmp_devices + cmp_state["device"], run_data.cmp_devices ) if cmp_device: comparison_device_names.add(cmp_device["name"]) comparison_entries.append( - (label, frac_diff, status_label, cmp_bench["name"]) + ( + label, + comparison.frac_diff, + comparison.status.value, + cmp_bench["name"], + ) ) if len(rows) == 0: continue - - cmp_device = find_device_by_id(cmp_device_id, all_cmp_devices) - ref_device = find_device_by_id(ref_state["device"], all_ref_devices) + if display == "explain": + align_explain_interval_columns(rows, row_comparisons, len(axes)) + elif display == "intervals": + align_timing_interval_columns(rows, row_comparisons, len(axes)) + + cmp_device = find_device_by_id(cmp_device_id, run_data.cmp_devices) + ref_device = find_device_by_id(ref_device_id, run_data.ref_devices) + if ref_device is None or cmp_device is None: + raise ValueError( + f"benchmark {cmp_bench['name']!r} references device pair " + f"ref={ref_device_id} cmp={cmp_device_id}, but device metadata is missing" + ) if cmp_device == ref_device: - print("## [%d] %s\n" % (cmp_device["id"], cmp_device["name"])) + print(f"## [{cmp_device['id']}] {cmp_device['name']}\n") else: print( - "## [%d] %s vs. [%d] %s\n" - % ( - ref_device["id"], - ref_device["name"], - cmp_device["id"], - cmp_device["name"], - ) + f"## [{ref_device['id']}] {ref_device['name']} vs. " + f"[{cmp_device['id']}] {cmp_device['name']}\n" ) # colalign and github format require tabulate 0.8.3 if tabulate_version >= (0, 8, 3): @@ -534,39 +2632,84 @@ def extract_value(summary): print("") if plot_along: - plt.xscale("log") - plt.yscale("log") - plt.xlabel(plot_along) - plt.ylabel("time [s]") - plt.title(cmp_device["name"]) - - def plot_line(key, shape, label): - x = [float(x) for x in plot_data[key][axis].keys()] - y = list(plot_data[key][axis].values()) - - noise = list(plot_data[key + "_noise"][axis].values()) - - top = [y[i] + y[i] * noise[i] for i in range(len(x))] - bottom = [y[i] - y[i] * noise[i] for i in range(len(x))] - - p = plt.plot(x, y, shape, marker="o", label=label) - plt.fill_between(x, bottom, top, color=p[0].get_color(), alpha=0.1) - - for axis in plot_data["cmp"].keys(): - plot_line("cmp", "-", axis) - plot_line("ref", "--", axis + " ref") - - plt.legend() - plt.show() + fig = plt.figure() + try: + plt.xscale("log") + plt.yscale("log") + plt.xlabel(plot_along) + plt.ylabel("time [s]") + plt.title(cmp_device["name"]) + + def plot_line(key, shape, label, data_axis, data=plot_data): + axis_times = data[key][data_axis] + if not axis_times: + return + axis_noise = data[key + "_noise"][data_axis] + series = sorted( + ( + ( + float(axis_value), + axis_times[axis_value], + axis_noise[axis_value], + ) + for axis_value in axis_times + ), + key=lambda item: item[0], + ) + x, y, noise = map(list, zip(*series, strict=True)) + + p = plt.plot(x, y, shape, marker="o", label=label) + + def plot_confidence_band(first, last): + if last - first < 2: + return + + band_x = x[first:last] + band_y = y[first:last] + band_noise = noise[first:last] + top = [ + band_y[i] + band_y[i] * band_noise[i] + for i in range(len(band_x)) + ] + bottom = [ + max( + band_y[i] - band_y[i] * band_noise[i], + band_y[i] * 0.001, + ) + for i in range(len(band_x)) + ] + plt.fill_between( + band_x, bottom, top, color=p[0].get_color(), alpha=0.1 + ) + + start = None + for i, noise_value in enumerate(noise): + if has_finite_noise(noise_value) and start is None: + start = i + if not has_finite_noise(noise_value) and start is not None: + plot_confidence_band(start, i) + start = None + + if start is not None: + plot_confidence_band(start, len(x)) + + for axis in plot_data["cmp"].keys(): + plot_line("cmp", "-", axis, axis) + plot_line("ref", "--", axis + " ref", axis) + + plt.legend() + plt.show() + finally: + plt.close(fig) if plot: title = "%SOL Bandwidth change" if len(comparison_device_names) == 1: title = f"{title} - {next(iter(comparison_device_names))}" - if axis_filters: + if filter_plan.global_axis_filters: axis_label = ", ".join( axis_filter["display"] - for axis_filter in axis_filters + for axis_filter in filter_plan.global_axis_filters if len(axis_filter["values"]) == 1 ) if axis_label: @@ -574,7 +2717,14 @@ def plot_line(key, shape, label): plot_comparison_entries(comparison_entries, title=title, dark=dark) -def main(): +def main() -> int: + """ + Returns a process exit code. + - 0 means the comparison completed successfully. + - 1 signals an error has occurred. + + The number of detected regressions is reported in the summary output. + """ help_text = "%(prog)s [reference.json compare.json | reference_dir/ compare_dir/]" parser = argparse.ArgumentParser(prog="nvbench_compare", usage=help_text) parser.add_argument( @@ -591,6 +2741,36 @@ def main(): default=0.0, help="only show benchmarks where percentage diff is >= THRESHOLD", ) + parser.add_argument( + "--preset", + choices=sorted(COMPARISON_THRESHOLD_PRESETS), + default=None, + help="comparison threshold preset", + ) + parser.add_argument( + "--config", + default=None, + help="comparison threshold TOML config", + ) + parser.add_argument( + "--dump-config", + action="store_true", + help="print the effective comparison threshold config and exit", + ) + parser.add_argument( + "--display", + choices=["intervals", "legacy", "explain"], + default="intervals", + help="comparison table display mode", + ) + parser.add_argument( + "--bulk-debug-python", + default=None, + help=( + "Write Python code that describes bulk sample/frequency files for " + "each displayed row. Use 'stdout' to print the code to stdout." + ), + ) parser.add_argument( "--plot-along", type=str, dest="plot_along", default=None, help="plot results" ) @@ -612,32 +2792,72 @@ def main(): action="store_true", help="Use emoji instead of ANSI color codes (useful for GitHub issues/PRs)", ) + parser.add_argument( + "--reference-devices", + default="all", + help="Reference devices to compare: all, a non-negative integer id, or comma-separated ids", + ) + parser.add_argument( + "--compare-devices", + default="all", + help="Compare devices to compare: all, a non-negative integer id, or comma-separated ids", + ) parser.add_argument( "-a", "--axis", - action="append", - default=[], - help="Filter on axis value, e.g. -a Elements{io}=2^20 (can repeat)", + dest="filter_actions", + action=OrderedBenchmarkFilterAction, + help=( + "Filter on axis value, e.g. -a Elements{io}=2^20. Applies to the " + "most recent --benchmark, or all benchmarks if specified before any " + "--benchmark arguments." + ), ) parser.add_argument( "-b", "--benchmark", - action="append", - default=[], + dest="filter_actions", + action=OrderedBenchmarkFilterAction, help="Filter by benchmark name (can repeat)", ) args, files_or_dirs = parser.parse_known_args() - print(files_or_dirs) try: - axis_filters = parse_axis_filters(args.axis) + comparison_preset, comparison_thresholds = resolve_comparison_thresholds( + args.preset, args.config + ) except ValueError as exc: print(str(exc)) - sys.exit(1) + return 1 + + if args.dump_config: + print(dump_comparison_config(comparison_preset, comparison_thresholds), end="") + return 0 + + try: + filter_plan = build_benchmark_filter_plan(args.filter_actions) + reference_device_filter = parse_device_filter( + args.reference_devices, "--reference-devices" + ) + compare_device_filter = parse_device_filter( + args.compare_devices, "--compare-devices" + ) + except ValueError as exc: + print(str(exc)) + return 1 if len(files_or_dirs) != 2: parser.print_help() - sys.exit(1) + return 1 + + bulk_debug_output = ( + None + if args.bulk_debug_python is None + else BulkDebugOutput(args.bulk_debug_python) + ) + bulk_debug_rows: list[dict[str, Any]] | None = ( + [] if bulk_debug_output is not None else None + ) # if provided two directories, find all the exactly named files # in both and treat them as the reference and compare @@ -658,16 +2878,31 @@ def main(): else: to_compare = [(files_or_dirs[0], files_or_dirs[1])] + stats = ComparisonStats() + for ref, comp in to_compare: ref_root = reader.read_file(ref) cmp_root = reader.read_file(comp) - global all_ref_devices - global all_cmp_devices - all_ref_devices = ref_root["devices"] - all_cmp_devices = cmp_root["devices"] + try: + selected_ref_devices = select_devices( + ref_root["devices"], reference_device_filter, "--reference-devices" + ) + selected_cmp_devices = select_devices( + cmp_root["devices"], compare_device_filter, "--compare-devices" + ) + except ValueError as exc: + print(str(exc)) + return 1 + + if len(selected_ref_devices) != len(selected_cmp_devices): + print( + f"--reference-devices selected {len(selected_ref_devices)} device(s), " + f"but --compare-devices selected {len(selected_cmp_devices)} device(s)" + ) + return 1 - if ref_root["devices"] != cmp_root["devices"]: + if selected_ref_devices != selected_cmp_devices: warn_fore = Fore.YELLOW if args.ignore_devices else Fore.RED msg_text = "Device sections do not match" print(colorize(msg_text, warn_fore, Emoji.NONE, args.no_color), end="") @@ -675,30 +2910,72 @@ def main(): print( jsondiff.diff( - ref_root["devices"], cmp_root["devices"], syntax="symmetric" + selected_ref_devices, selected_cmp_devices, syntax="symmetric" ) ) - if not args.ignore_devices: - sys.exit(1) + if not args.ignore_devices and require_matching_device_sections( + reference_device_filter, compare_device_filter + ): + return 1 - compare_benches( - ref_root["benchmarks"], - cmp_root["benchmarks"], - args.threshold, - args.plot_along, - args.plot, - args.dark, - axis_filters, - args.benchmark, - args.no_color, + run_data = ComparisonRunData( + stats=stats, + ref_devices=tuple(selected_ref_devices), + cmp_devices=tuple(selected_cmp_devices), ) + try: + compare_benches( + run_data, + ref_root["benchmarks"], + cmp_root["benchmarks"], + threshold=args.threshold, + plot_along=args.plot_along, + plot=args.plot, + dark=args.dark, + filter_plan=filter_plan, + no_color=args.no_color, + reference_device_filter=reference_device_filter, + compare_device_filter=compare_device_filter, + ref_json_dir=os.path.dirname(ref), + cmp_json_dir=os.path.dirname(comp), + ref_json_path=ref, + cmp_json_path=comp, + comparison_thresholds=comparison_thresholds, + display=args.display, + bulk_debug_rows=bulk_debug_rows, + ) + except ValueError as exc: + print(str(exc)) + return 1 + print("# Summary\n") - print("- Total Matches: %d" % config_count) - print(" - Pass (diff <= min_noise): %d" % pass_count) - print(" - Unknown (infinite noise): %d" % unknown_count) - print(" - Failure (diff > min_noise): %d" % failure_count) - return failure_count + print(f"- Total Matches: {stats.config_count}") + print(f" - Pass (centers close and intervals overlap): {stats.pass_count}") + print(f" - Improvement (clear timing gap, %Diff < 0): {stats.improvement_count}") + print(f" - Regression (clear timing gap, %Diff > 0): {stats.regression_count}") + print( + f" - Undecided (comparison requires more evidence): {stats.undecided_count}" + ) + if stats.undecided_reasons: + print(" - Reasons:") + for code, reason_summary in sorted( + stats.undecided_reasons.items(), + key=lambda item: item[1].count, + reverse=True, + ): + print(f" - {code}: {reason_summary.count} ({reason_summary.message})") + if args.display == "explain" and stats.reason_legend: + legend_entries = format_reason_legend_entries(stats.reason_legend) + if legend_entries: + print(f" - Reason legend: {'; '.join(legend_entries)}") + print(f" - Unknown (infinite or unavailable noise): {stats.unknown_count}") + try: + write_bulk_debug_python(bulk_debug_output, bulk_debug_rows or []) + except OSError as exc: + print(f"failed to write bulk debug Python output: {exc}") + return 1 + return 0 if __name__ == "__main__": diff --git a/python/test/test_nvbench_compare.py b/python/test/test_nvbench_compare.py new file mode 100644 index 00000000..9f0fdee2 --- /dev/null +++ b/python/test/test_nvbench_compare.py @@ -0,0 +1,2017 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import importlib.util +import sys +import types +from pathlib import Path + +import numpy as np +import pytest + + +@pytest.fixture +def nvbench_compare(monkeypatch): + class DummyLine: + def get_color(self): + return "black" + + pyplot = types.ModuleType("matplotlib.pyplot") + pyplot.figure = lambda *args, **kwargs: None + pyplot.xscale = lambda *args, **kwargs: None + pyplot.yscale = lambda *args, **kwargs: None + pyplot.xlabel = lambda *args, **kwargs: None + pyplot.ylabel = lambda *args, **kwargs: None + pyplot.title = lambda *args, **kwargs: None + pyplot.plot = lambda *args, **kwargs: [DummyLine()] + pyplot.fill_between = lambda *args, **kwargs: None + pyplot.legend = lambda *args, **kwargs: None + pyplot.show = lambda *args, **kwargs: None + pyplot.close = lambda *args, **kwargs: None + + matplotlib = types.ModuleType("matplotlib") + matplotlib.pyplot = pyplot + monkeypatch.setitem(sys.modules, "matplotlib", matplotlib) + monkeypatch.setitem(sys.modules, "matplotlib.pyplot", pyplot) + monkeypatch.setitem( + sys.modules, + "seaborn", + types.SimpleNamespace(set_theme=lambda *args, **kwargs: None), + ) + monkeypatch.setitem( + sys.modules, "jsondiff", types.SimpleNamespace(diff=lambda *args, **kwargs: {}) + ) + monkeypatch.setitem( + sys.modules, + "tabulate", + types.SimpleNamespace( + __version__="0.8.10", tabulate=lambda *args, **kwargs: "" + ), + ) + monkeypatch.setitem( + sys.modules, + "colorama", + types.SimpleNamespace( + Fore=types.SimpleNamespace( + BLUE="", + GREEN="", + RED="", + RESET="", + YELLOW="", + ) + ), + ) + + module_path = Path(__file__).resolve().parents[1] / "scripts" / "nvbench_compare.py" + spec = importlib.util.spec_from_file_location("nvbench_compare", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def make_state( + nvbench_compare, name, *, mean="1.0", noise="0.01", axis_value=None, device=0 +): + return { + "name": name, + "device": device, + "axis_values": [] + if axis_value is None + else [{"name": "A", "type": "int64", "value": axis_value}], + "summaries": [ + { + "tag": nvbench_compare.GPU_TIME_MEAN_TAG, + "data": [{"name": "value", "type": "float64", "value": mean}], + }, + { + "tag": nvbench_compare.GPU_TIME_STDEV_RELATIVE_TAG, + "data": [{"name": "value", "type": "float64", "value": noise}], + }, + ], + } + + +def make_summary(nvbench_compare, tag, value): + return { + "tag": getattr(nvbench_compare, tag), + "data": [{"name": "value", "type": "float64", "value": value}], + } + + +def make_binary_summary(nvbench_compare, tag, filename, size): + return { + "tag": getattr(nvbench_compare, tag), + "data": [ + {"name": "filename", "type": "string", "value": filename}, + {"name": "size", "type": "int64", "value": str(size)}, + ], + } + + +def make_gpu_timing_data( + nvbench_compare, + *, + minimum=None, + maximum=None, + mean=1.0, + stdev=None, + stdev_relative=0.01, + first_quartile=None, + median=None, + third_quartile=None, + interquartile_range=None, + interquartile_range_relative=None, + sm_clock_rate_mean=None, + sample_values=None, + frequency_values=None, +): + return nvbench_compare.GpuTimingData( + minimum=minimum, + maximum=maximum, + mean=mean, + stdev=stdev, + stdev_relative=stdev_relative, + first_quartile=first_quartile, + median=median, + third_quartile=third_quartile, + interquartile_range=interquartile_range, + interquartile_range_relative=interquartile_range_relative, + sm_clock_rate_mean=sm_clock_rate_mean, + sample_source=None + if sample_values is None + else types.SimpleNamespace(values=np.asarray(sample_values, dtype=np.float32)), + frequency_source=None + if frequency_values is None + else types.SimpleNamespace( + values=np.asarray(frequency_values, dtype=np.float32) + ), + ) + + +def make_benchmark(states, *, name="bench"): + devices = [] + for state in states: + if state["device"] not in devices: + devices.append(state["device"]) + + return { + "name": name, + "devices": devices, + "axes": [{"name": "A", "type": "int64", "flags": ""}] + if any(state["axis_values"] for state in states) + else [], + "states": states, + } + + +def make_comparison_run_data(nvbench_compare, ref_devices=None, cmp_devices=None): + devices = [{"id": 0, "name": "Test GPU"}] + return nvbench_compare.ComparisonRunData( + stats=nvbench_compare.ComparisonStats(), + ref_devices=tuple(devices if ref_devices is None else ref_devices), + cmp_devices=tuple(devices if cmp_devices is None else cmp_devices), + ) + + +def make_filter_plan(nvbench_compare, filter_actions=None): + return nvbench_compare.build_benchmark_filter_plan(filter_actions or []) + + +def test_compare_benches_accepts_matching_duplicate_state_counts( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1", mean="1.005"), + make_state(nvbench_compare, "state1", mean="1.005"), + make_state(nvbench_compare, "state2", mean="1.005"), + ] + ) + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert run_data.stats.config_count == 3 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 3 + assert run_data.stats.unknown_count == 0 + + +def test_compare_benches_rejects_swapped_duplicate_state_counts( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state1"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + make_state(nvbench_compare, "state2"), + ] + ) + ] + + with pytest.raises(ValueError, match="mismatched state occurrences"): + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + +def test_compare_benches_matches_duplicate_states_after_axis_filter( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + ] + ) + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare, [("axis", "A=2")]), + no_color=True, + ) + + assert run_data.stats.config_count == 1 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 1 + assert run_data.stats.unknown_count == 0 + + +def test_compare_benches_skips_non_finite_centers(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "finite", mean="1.0"), + make_state(nvbench_compare, "nan", mean="nan"), + make_state(nvbench_compare, "inf", mean="inf"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "finite", mean="1.0"), + make_state(nvbench_compare, "nan", mean="1.0"), + make_state(nvbench_compare, "inf", mean="1.0"), + ] + ) + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert run_data.stats.config_count == 1 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 1 + assert run_data.stats.unknown_count == 0 + + +def test_gpu_timing_data_loads_samples_and_frequencies_lazily( + tmp_path, nvbench_compare +): + samples_dir = tmp_path / "result.json-bin" + freqs_dir = tmp_path / "result.json-freqs-bin" + samples_dir.mkdir() + freqs_dir.mkdir() + samples_file = samples_dir / "0.bin" + freqs_file = freqs_dir / "0.bin" + + np.array([1.0, 2.0, 4.0], dtype="= +7.7%" + ) + + +def test_format_change_only_reports_fast_and_slow_rows(nvbench_compare): + fast = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.FAST, + frac_diff_interval=(-0.3, -0.05), + ) + slow = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.SLOW, + frac_diff_interval=(0.07, 0.55), + ) + same = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.SAME, + frac_diff_interval=(-0.01, 0.01), + ) + undecided = types.SimpleNamespace( + status=nvbench_compare.ComparisonStatus.UNDECIDED, + frac_diff_interval=(-0.01, 0.01), + ) + + assert nvbench_compare.format_change(fast) == "<= -5.0%" + assert nvbench_compare.format_change(slow) == ">= +7.0%" + assert nvbench_compare.format_change(same) == "" + assert nvbench_compare.format_change(undecided) == "" + + +def test_format_timing_with_interval(nvbench_compare): + interval = nvbench_compare.TimingInterval( + lower=0.002237, upper=0.002389, center=0.0023 + ) + assert ( + nvbench_compare.format_timing_with_interval(0.0023, interval) + == "2.300 ms [-63, +89] us" + ) + + interval = nvbench_compare.TimingInterval( + lower=19.380e-6, upper=20.508e-6, center=19.944e-6 + ) + assert ( + nvbench_compare.format_timing_with_interval(19.944e-6, interval) + == "19.944 [-0.564, +0.564] us" + ) + + +def test_format_timing_with_explicit_interval(nvbench_compare): + interval = nvbench_compare.TimingInterval( + lower=0.001434, upper=0.001458, center=0.001446 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(0.001446, interval) + == "1.4[34 | 46 | 58] ms" + ) + + interval = nvbench_compare.TimingInterval( + lower=18.400e-6, upper=19.464e-6, center=18.736e-6 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(18.736e-6, interval) + == "[18.400 | 18.736 | 19.464] us" + ) + + interval = nvbench_compare.TimingInterval( + lower=19.380e-6, upper=20.508e-6, center=19.944e-6 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(19.944e-6, interval) + == "[19.380 | 19.944 | 20.508] us" + ) + + interval = nvbench_compare.TimingInterval( + lower=99.094e-6, upper=100.882e-6, center=99.988e-6 + ) + assert ( + nvbench_compare.format_timing_with_explicit_interval(99.988e-6, interval) + == "[ 99.094 | 99.988 | 100.882] us" + ) + + +def test_align_explain_interval_columns_pads_values_across_rows(nvbench_compare): + rows = [["", ""], ["", ""]] + comparisons = [ + types.SimpleNamespace( + ref_time=19.944e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=19.380e-6, center=19.944e-6, upper=20.508e-6 + ), + cmp_time=97.712e-6, + cmp_interval=nvbench_compare.TimingInterval( + lower=96.849e-6, center=97.712e-6, upper=98.574e-6 + ), + ), + types.SimpleNamespace( + ref_time=103.466e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=102.739e-6, center=103.466e-6, upper=104.193e-6 + ), + cmp_time=101.868e-6, + cmp_interval=nvbench_compare.TimingInterval( + lower=100.916e-6, center=101.868e-6, upper=102.819e-6 + ), + ), + ] + + nvbench_compare.align_explain_interval_columns(rows, comparisons, axis_count=0) + + assert rows[0][0] == "[ 19.380 | 19.944 | 20.508] us" + assert rows[1][0] == "[102.739 | 103.466 | 104.193] us" + assert rows[0][1] == "[ 96.849 | 97.712 | 98.574] us" + assert rows[1][1] == "[100.916 | 101.868 | 102.819] us" + + +def test_align_timing_interval_columns_reserves_missing_interval_slot(nvbench_compare): + rows = [["", ""], ["", ""]] + comparisons = [ + types.SimpleNamespace( + ref_time=19.944e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=19.380e-6, center=19.944e-6, upper=20.508e-6 + ), + cmp_time=18.736e-6, + cmp_interval=nvbench_compare.TimingInterval( + lower=18.400e-6, center=18.736e-6, upper=19.464e-6 + ), + ), + types.SimpleNamespace( + ref_time=20.390e-6, + ref_interval=nvbench_compare.TimingInterval( + lower=19.659e-6, center=20.390e-6, upper=21.121e-6 + ), + cmp_time=20.480e-6, + cmp_interval=None, + ), + ] + + nvbench_compare.align_timing_interval_columns(rows, comparisons, axis_count=0) + + cmp_interval_slot = len("[-0.336, +0.728]") + assert rows[0][1] == "18.736 [-0.336, +0.728] us" + assert rows[1][1] == f"20.480 {' ' * cmp_interval_slot} us" + + +def test_compare_gpu_timings_keeps_bulk_mismatch_undecided(nvbench_compare): + ref_timing = make_gpu_timing_data( + nvbench_compare, + minimum=1.0, + first_quartile=1.1, + median=1.2, + third_quartile=1.3, + mean=1.2, + interquartile_range_relative=0.01, + sample_values=[1.0, 1.0, 1.004, 1.004], + frequency_values=[100.0] * 4, + ) + cmp_timing = make_gpu_timing_data( + nvbench_compare, + minimum=1.02, + first_quartile=1.1, + median=1.204, + third_quartile=1.28, + mean=1.204, + interquartile_range_relative=0.01, + sample_values=[1.02, 1.02, 1.024, 1.024], + frequency_values=[100.0] * 4, + ) + + comparison = nvbench_compare.compare_gpu_timings(ref_timing, cmp_timing) + + assert comparison is not None + assert comparison.status == nvbench_compare.ComparisonStatus.UNDECIDED + assert comparison.reason.code == "bulk_time_support_mismatch" + assert "sample: min(ref=0.0%, cmp=0.0%) >= 99.0%" in comparison.reason.message + assert "support: min(ref=0.0%, cmp=0.0%) >= 80.0%" in comparison.reason.message + assert "99.0%" in comparison.reason.message + assert "80.0%" in comparison.reason.message + + +def test_compare_gpu_timings_requires_bulk_cycle_coverage(nvbench_compare): + ref_timing = make_gpu_timing_data( + nvbench_compare, + mean=1.0, + stdev_relative=0.01, + sample_values=[1.0, 1.0, 1.004, 1.004], + frequency_values=[100.0] * 4, + ) + cmp_timing = make_gpu_timing_data( + nvbench_compare, + mean=1.0, + stdev_relative=0.01, + sample_values=[1.0, 1.0, 1.004, 1.004], + frequency_values=[200.0] * 4, + ) + + comparison = nvbench_compare.compare_gpu_timings(ref_timing, cmp_timing) + + assert comparison is not None + assert comparison.status == nvbench_compare.ComparisonStatus.UNDECIDED + assert comparison.reason.code == "bulk_cycle_support_mismatch" + + +def test_bulk_same_reports_sample_weight_coverage_mismatch(nvbench_compare): + ref_values = [1.0, 1.001, 1.002, 1.003] + [1.02] * 100 + cmp_values = [1.0, 1.001, 1.002, 1.003] + + decision = nvbench_compare.compare_values_for_bulk_same( + ref_values, + cmp_values, + label="time", + thresholds=nvbench_compare.ComparisonThresholds(), + ) + + assert decision.status == nvbench_compare.ComparisonStatus.UNDECIDED + assert decision.reason.code == "bulk_time_support_mismatch" + assert "sample: min(ref=3.8%, cmp=100.0%) >= 99.0%" in decision.reason.message + assert "support: min(ref=80.0%, cmp=100.0%) >= 80.0%" in decision.reason.message + + +def test_bulk_same_filters_rare_values_from_support_coverage(nvbench_compare): + ref_values = [1.0] * 1000 + [1.02 + 0.01 * i for i in range(10)] + cmp_values = [1.0] + + decision = nvbench_compare.compare_values_for_bulk_same( + ref_values, + cmp_values, + label="time", + thresholds=nvbench_compare.ComparisonThresholds(), + ) + + assert decision.status == nvbench_compare.ComparisonStatus.SAME + assert decision.reason.code == "bulk_time_same" + + +def test_bulk_same_reports_unique_support_coverage_mismatch(nvbench_compare): + ref_values = [1.0] * 1000 + [1.02 + 0.01 * i for i in range(10)] + cmp_values = [1.0] + + decision = nvbench_compare.compare_values_for_bulk_same( + ref_values, + cmp_values, + label="time", + thresholds=nvbench_compare.ComparisonThresholds( + bulk_support_max_removed_sample_fraction=0.005 + ), + ) + + assert decision.status == nvbench_compare.ComparisonStatus.UNDECIDED + assert decision.reason.code == "bulk_time_support_mismatch" + assert "sample: min(ref=99.0%, cmp=100.0%) >= 99.0%" in decision.reason.message + assert "support: min(ref=9.1%, cmp=100.0%) >= 80.0%" in decision.reason.message + + +def test_bulk_same_retains_full_support_when_all_values_are_unique(nvbench_compare): + coverages = nvbench_compare.compute_nearest_neighbor_coverages( + [1.0, 1.02], + [1.0], + thresholds=nvbench_compare.ComparisonThresholds( + bulk_support_rare_sample_fraction=1.0, + bulk_support_max_removed_sample_fraction=1.0, + ), + ) + + assert coverages is not None + assert coverages["ref_sample"] == 0.5 + assert coverages["ref_support"] == 0.5 + assert coverages["ref_support_filter"] == nvbench_compare.SupportFilterInfo( + activated=False, + reason="all_values_unique", + removed_sample_fraction=0.0, + ) + + +def test_comparison_stats_records_undecided_status(nvbench_compare): + stats = nvbench_compare.ComparisonStats() + + stats.record(nvbench_compare.ComparisonStatus.UNDECIDED) + + assert stats.config_count == 1 + assert stats.pass_count == 0 + assert stats.improvement_count == 0 + assert stats.regression_count == 0 + assert stats.undecided_count == 1 + assert stats.unknown_count == 0 + + +def test_comparison_stats_records_undecided_reason(nvbench_compare): + stats = nvbench_compare.ComparisonStats() + less_severe_reason = nvbench_compare.DecisionReason( + code="test_reason", + message="less severe reason", + severity=1.0, + ) + more_severe_reason = nvbench_compare.DecisionReason( + code="test_reason", + message="more severe reason", + severity=2.0, + ) + + stats.record(nvbench_compare.ComparisonStatus.UNDECIDED, less_severe_reason) + stats.record(nvbench_compare.ComparisonStatus.UNDECIDED, more_severe_reason) + + summary = stats.undecided_reasons["test_reason"] + assert summary.count == 2 + assert summary.message == "more severe reason" + + +def test_reason_legend_omits_trivial_aliases(nvbench_compare): + reason_legend = { + "bulk-same": nvbench_compare.DecisionReasonSummary(canonical_code="bulk_same"), + "bt-sup-miss": nvbench_compare.DecisionReasonSummary( + canonical_code="bulk_time_support_mismatch" + ), + } + + assert nvbench_compare.format_reason_legend_entries(reason_legend) == [ + "bt-sup-miss = bulk_time_support_mismatch" + ] + + +@pytest.mark.parametrize("ref_time, cmp_time", [(None, 1.0), (1.0, None), (0.0, 1.0)]) +def test_compare_gpu_timings_rejects_unusable_centers( + nvbench_compare, ref_time, cmp_time +): + assert ( + nvbench_compare.compare_gpu_timings( + make_gpu_timing_data(nvbench_compare, mean=ref_time), + make_gpu_timing_data(nvbench_compare, mean=cmp_time), + ) + is None + ) + + +def test_compare_benches_reports_regression_when_robust_intervals_and_clock_confirm( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_state = make_state(nvbench_compare, "state", mean="1.0", noise="0.01") + ref_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "0.9"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "0.95"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.05"), + make_summary(nvbench_compare, "GPU_TIME_IQR_RELATIVE_TAG", "0.01"), + make_summary(nvbench_compare, "GPU_SM_CLOCK_RATE_MEAN_TAG", "100.0"), + ] + ) + cmp_state = make_state(nvbench_compare, "state", mean="1.0", noise="0.01") + cmp_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "1.15"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "1.18"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.2"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.25"), + make_summary(nvbench_compare, "GPU_TIME_IQR_RELATIVE_TAG", "0.01"), + make_summary(nvbench_compare, "GPU_SM_CLOCK_RATE_MEAN_TAG", "100.0"), + ] + ) + + nvbench_compare.compare_benches( + run_data, + [make_benchmark([ref_state])], + [make_benchmark([cmp_state])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert run_data.stats.config_count == 1 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 1 + assert run_data.stats.undecided_count == 0 + assert run_data.stats.unknown_count == 0 + + +def test_compare_benches_accepts_custom_comparison_thresholds( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + ref_state = make_state(nvbench_compare, "state", mean="1.0", noise="0.01") + ref_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "0.99"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "0.995"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.01"), + make_summary(nvbench_compare, "GPU_TIME_IQR_RELATIVE_TAG", "0.01"), + ] + ) + cmp_state = make_state(nvbench_compare, "state", mean="1.01", noise="0.01") + cmp_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "1.005"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.01"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.02"), + make_summary(nvbench_compare, "GPU_TIME_IQR_RELATIVE_TAG", "0.01"), + ] + ) + + nvbench_compare.compare_benches( + run_data, + [make_benchmark([ref_state])], + [make_benchmark([cmp_state])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + comparison_thresholds=nvbench_compare.ComparisonThresholds( + same_center_relative=0.02 + ), + ) + + assert run_data.stats.config_count == 1 + assert run_data.stats.pass_count == 1 + assert run_data.stats.undecided_count == 0 + + +def test_compare_benches_marks_unavailable_noise_undecided( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + + missing_noise_ref = make_state(nvbench_compare, "missing_noise") + missing_noise_ref["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.0") + ] + missing_noise_cmp = make_state(nvbench_compare, "missing_noise") + missing_noise_cmp["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.001") + ] + + null_noise_ref = make_state(nvbench_compare, "null_noise") + null_noise_ref["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_STDEV_RELATIVE_TAG", None), + ] + null_noise_cmp = make_state(nvbench_compare, "null_noise") + null_noise_cmp["summaries"] = [ + make_summary(nvbench_compare, "GPU_TIME_MEAN_TAG", "1.001"), + make_summary(nvbench_compare, "GPU_TIME_STDEV_RELATIVE_TAG", None), + ] + + nvbench_compare.compare_benches( + run_data, + [make_benchmark([missing_noise_ref, null_noise_ref])], + [make_benchmark([missing_noise_cmp, null_noise_cmp])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert run_data.stats.config_count == 2 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 2 + assert run_data.stats.unknown_count == 0 + + +def test_plot_along_skips_states_without_selected_axis(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "with_axis", axis_value=1), + make_state(nvbench_compare, "without_axis"), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "with_axis", axis_value=1), + make_state(nvbench_compare, "without_axis"), + ] + ) + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along="A", + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert run_data.stats.config_count == 2 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 2 + assert run_data.stats.unknown_count == 0 + + +def test_device_filter_parser_accepts_all_and_duplicate_ids(nvbench_compare): + assert nvbench_compare.parse_device_filter(" all ", "--reference-devices") is None + assert nvbench_compare.parse_device_filter("0", "--reference-devices") == [0] + assert nvbench_compare.parse_device_filter("0, 2,0", "--reference-devices") == [ + 0, + 2, + 0, + ] + + +@pytest.mark.parametrize( + "device_arg", + [ + "", + " ", + "gpu", + "-1", + "0,gpu", + "0,-1", + "0,", + ",0", + ], +) +def test_device_filter_parser_rejects_invalid_values(nvbench_compare, device_arg): + with pytest.raises(ValueError, match="must be 'all'"): + nvbench_compare.parse_device_filter(device_arg, "--reference-devices") + + +def test_explicit_device_filters_downgrade_device_mismatch_to_warning(nvbench_compare): + assert nvbench_compare.require_matching_device_sections(None, None) + assert not nvbench_compare.require_matching_device_sections([0], None) + assert not nvbench_compare.require_matching_device_sections(None, [1]) + assert not nvbench_compare.require_matching_device_sections([0], [1]) + + +def test_compare_benches_pairs_filtered_devices_by_position( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data( + nvbench_compare, + ref_devices=[ + {"id": 0, "name": "Reference GPU 0"}, + {"id": 1, "name": "Reference GPU 1"}, + ], + cmp_devices=[ + {"id": 0, "name": "Compare GPU 0"}, + {"id": 1, "name": "Compare GPU 1"}, + ], + ) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "Device=0", mean="1.0", device=0), + make_state(nvbench_compare, "Device=1", mean="9.0", device=1), + ] + ) + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "Device=0", mean="9.0", device=0), + make_state(nvbench_compare, "Device=1", mean="1.0", device=1), + ] + ) + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + reference_device_filter=[0], + compare_device_filter=[1], + ) + + assert run_data.stats.config_count == 1 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 1 + assert run_data.stats.unknown_count == 0 + + +def test_axis_filter_applies_to_most_recent_benchmark(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + + ref_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + cmp_benches = [ + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="1.0", axis_value=1), + make_state(nvbench_compare, "state", mean="2.0", axis_value=2), + ], + name="bench1", + ), + make_benchmark( + [ + make_state(nvbench_compare, "state", mean="3.0", axis_value=1), + make_state(nvbench_compare, "state", mean="4.0", axis_value=2), + ], + name="bench2", + ), + ] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan( + nvbench_compare, + [("benchmark", "bench1"), ("axis", "A=2"), ("benchmark", "bench2")], + ), + no_color=True, + ) + + assert run_data.stats.config_count == 3 + assert run_data.stats.pass_count == 0 + assert run_data.stats.improvement_count == 0 + assert run_data.stats.regression_count == 0 + assert run_data.stats.undecided_count == 3 + assert run_data.stats.unknown_count == 0 + + +def test_main_returns_success_exit_code_when_regressions_are_detected( + monkeypatch, capsys, nvbench_compare +): + devices = [{"id": 0, "name": "Test GPU"}] + ref_state = make_state(nvbench_compare, "state", mean="1.0") + ref_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "0.9"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "0.95"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.05"), + make_summary(nvbench_compare, "GPU_SM_CLOCK_RATE_MEAN_TAG", "100.0"), + ] + ) + cmp_state = make_state(nvbench_compare, "state", mean="1.2") + cmp_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "1.15"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "1.18"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.2"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.25"), + make_summary(nvbench_compare, "GPU_SM_CLOCK_RATE_MEAN_TAG", "100.0"), + ] + ) + ref_root = { + "devices": devices, + "benchmarks": [make_benchmark([ref_state])], + } + cmp_root = { + "devices": devices, + "benchmarks": [make_benchmark([cmp_state])], + } + + def read_file(path): + return ref_root if path == "ref.json" else cmp_root + + monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file) + monkeypatch.setattr(sys, "argv", ["nvbench_compare", "ref.json", "cmp.json"]) + + assert nvbench_compare.main() == 0 + assert "Regression (clear timing gap, %Diff > 0): 1" in capsys.readouterr().out + + +def test_main_prints_undecided_reason_summary(monkeypatch, capsys, nvbench_compare): + devices = [{"id": 0, "name": "Test GPU"}] + ref_root = { + "devices": devices, + "benchmarks": [ + make_benchmark([make_state(nvbench_compare, "state", noise="0.05")]) + ], + } + cmp_root = { + "devices": devices, + "benchmarks": [ + make_benchmark( + [make_state(nvbench_compare, "state", mean="1.01", noise="0.05")] + ) + ], + } + + def read_file(path): + return ref_root if path == "ref.json" else cmp_root + + monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file) + monkeypatch.setattr( + sys, "argv", ["nvbench_compare", "--display", "explain", "ref.json", "cmp.json"] + ) + + assert nvbench_compare.main() == 0 + output = capsys.readouterr().out + assert "Undecided (comparison requires more evidence): 1" in output + assert "noise_too_high: 1" in output + assert "Reason legend: noise-high = noise_too_high" in output + + +def test_get_comparison_thresholds_returns_named_presets(nvbench_compare): + default = nvbench_compare.get_comparison_thresholds("default") + strict = nvbench_compare.get_comparison_thresholds("strict") + permissive = nvbench_compare.get_comparison_thresholds("permissive") + + assert default == nvbench_compare.ComparisonThresholds( + **nvbench_compare.COMPARISON_THRESHOLD_PRESET_VALUES["default"] + ) + assert strict.clear_gap_relative > default.clear_gap_relative + assert strict.same_center_relative < default.same_center_relative + assert strict.bulk_same_sample_coverage > default.bulk_same_sample_coverage + assert permissive.clear_gap_relative < default.clear_gap_relative + assert permissive.same_center_relative > default.same_center_relative + assert permissive.bulk_same_support_coverage < default.bulk_same_support_coverage + + +def test_dump_comparison_config_uses_grouped_toml(nvbench_compare): + config = nvbench_compare.dump_comparison_config( + "default", nvbench_compare.get_comparison_thresholds("default") + ) + + assert "version = 1\n" in config + assert '[preset]\nname = "default"\n' in config + assert "[clear_gap]\nrelative = 0.005\n" in config + assert "[same]\n" in config + assert "[bulk]\n" in config + assert "sample_coverage = 0.97\n" in config + assert "[bulk.rare_support]\n" in config + + +def test_resolve_comparison_thresholds_applies_config_overrides( + monkeypatch, nvbench_compare +): + def read_config(_): + return ( + "strict", + { + "bulk_same_sample_coverage": 0.93, + "bulk_support_max_removed_sample_fraction": 0.02, + }, + ) + + monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config) + + preset, thresholds = nvbench_compare.resolve_comparison_thresholds( + None, "settings.toml" + ) + assert preset == "strict" + assert thresholds.clear_gap_relative == pytest.approx( + nvbench_compare.get_comparison_thresholds("strict").clear_gap_relative + ) + assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93) + assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(0.02) + + preset, thresholds = nvbench_compare.resolve_comparison_thresholds( + "permissive", "settings.toml" + ) + assert preset == "permissive" + assert thresholds.clear_gap_relative == pytest.approx( + nvbench_compare.get_comparison_thresholds("permissive").clear_gap_relative + ) + assert thresholds.bulk_same_sample_coverage == pytest.approx(0.93) + assert thresholds.bulk_support_max_removed_sample_fraction == pytest.approx(0.02) + + +def test_parse_comparison_config_data_validates_grouped_thresholds(nvbench_compare): + preset, overrides = nvbench_compare.parse_comparison_config_data( + { + "version": 1, + "preset": {"name": "strict"}, + "clear_gap": {"relative": 0.01}, + "same": { + "center_relative": 0.002, + "overlap_fraction": 0.75, + "relative_dispersion_ceiling": 0.02, + }, + "bulk": { + "sample_coverage": 0.99, + "support_coverage": 0.8, + "rare_support": { + "sample_fraction": 0.001, + "max_removed_sample_fraction": 0.01, + }, + }, + } + ) + + assert preset == "strict" + assert overrides == { + "clear_gap_relative": 0.01, + "same_center_relative": 0.002, + "same_overlap_fraction": 0.75, + "same_relative_dispersion_ceiling": 0.02, + "bulk_same_sample_coverage": 0.99, + "bulk_same_support_coverage": 0.8, + "bulk_support_rare_sample_fraction": 0.001, + "bulk_support_max_removed_sample_fraction": 0.01, + } + + +@pytest.mark.parametrize( + "config_data, match", + [ + ({}, "version"), + ({"version": 2}, "unsupported"), + ({"version": 1, "rare_support": {}}, "unknown top-level"), + ({"version": 1, "bulk": {"unknown": 0.1}}, r"\[bulk\]"), + ({"version": 1, "clear_gap": {"rare_support": {}}}, r"\[clear_gap\]"), + ({"version": 1, "bulk": {"sample_coverage": 1.5}}, "<= 1"), + ({"version": 1, "same": {"center_relative": "tight"}}, "finite number"), + ({"version": 1, "preset": {"name": "aggressive"}}, "unknown comparison preset"), + ], +) +def test_parse_comparison_config_data_rejects_invalid_config( + nvbench_compare, config_data, match +): + with pytest.raises(ValueError, match=match): + nvbench_compare.parse_comparison_config_data(config_data) + + +def test_read_comparison_config_file_parses_toml_when_parser_is_available( + tmp_path, nvbench_compare +): + parser_module = "tomllib" if sys.version_info >= (3, 11) else "tomli" + pytest.importorskip(parser_module) + config_path = tmp_path / "settings.toml" + config_path.write_text( + """ +version = 1 + +[preset] +name = "strict" + +[bulk] +sample_coverage = 0.93 +""", + encoding="utf-8", + ) + + preset, overrides = nvbench_compare.read_comparison_config_file(config_path) + + assert preset == "strict" + assert overrides == {"bulk_same_sample_coverage": 0.93} + + +def test_main_dump_config_does_not_require_input_files( + monkeypatch, capsys, nvbench_compare +): + def read_file(_): + raise AssertionError("dump-config should not read JSON files") + + monkeypatch.setattr(nvbench_compare.reader, "read_file", read_file) + monkeypatch.setattr( + sys, + "argv", + ["nvbench_compare", "--preset", "strict", "--dump-config"], + ) + + assert nvbench_compare.main() == 0 + output = capsys.readouterr().out + assert 'name = "strict"' in output + assert "[bulk.rare_support]" in output + + +def test_main_dump_config_merges_config_and_cli_preset( + monkeypatch, capsys, nvbench_compare +): + def read_config(_): + return ("strict", {"bulk_same_sample_coverage": 0.93}) + + monkeypatch.setattr(nvbench_compare, "read_comparison_config_file", read_config) + monkeypatch.setattr( + sys, + "argv", + [ + "nvbench_compare", + "--config", + "settings.toml", + "--preset", + "permissive", + "--dump-config", + ], + ) + + assert nvbench_compare.main() == 0 + output = capsys.readouterr().out + assert 'name = "permissive"' in output + assert "relative = 0.0025" in output + assert "sample_coverage = 0.93" in output + + +def test_main_prints_bulk_debug_python_to_stdout(monkeypatch, capsys, nvbench_compare): + devices = [{"id": 0, "name": "Test GPU"}] + root = { + "devices": devices, + "benchmarks": [], + } + + monkeypatch.setattr(nvbench_compare.reader, "read_file", lambda _: root) + + def fake_compare_benches(*args, **kwargs): + kwargs["bulk_debug_rows"].append( + { + "row_index": 0, + "status": "UNDECIDED", + "reference_sample_filename": None, + "reference_sample_count": None, + "reference_frequency_filename": None, + "reference_frequency_count": None, + "compare_sample_filename": None, + "compare_sample_count": None, + "compare_frequency_filename": None, + "compare_frequency_count": None, + } + ) + + monkeypatch.setattr(nvbench_compare, "compare_benches", fake_compare_benches) + monkeypatch.setattr( + sys, + "argv", + [ + "nvbench_compare", + "--bulk-debug-python", + "STDOUT", + "ref.json", + "cmp.json", + ], + ) + + assert nvbench_compare.main() == 0 + output = capsys.readouterr().out + assert "bulk_rows = [" in output + assert "'status': 'UNDECIDED'" in output + assert "def load_bulk_data(row):" in output + + +def test_compare_benches_defaults_to_interval_display(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + captured = {} + + def fake_tabulate(rows, headers, *args, **kwargs): + captured["rows"] = rows + captured["headers"] = headers + return "" + + monkeypatch.setattr(nvbench_compare.tabulate, "tabulate", fake_tabulate) + + ref_benches = [make_benchmark([make_state(nvbench_compare, "state", mean="1.0")])] + cmp_benches = [make_benchmark([make_state(nvbench_compare, "state", mean="1.01")])] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + ) + + assert captured["headers"][-4:] == ["Ref", "Cmp", "Change", "Status"] + row = captured["rows"][0] + assert row[-4].startswith("1.000 s") + assert row[-3].startswith("1.010 s") + assert row[-2] == "" + + +def test_compare_benches_legacy_display_uses_scalar_diff(monkeypatch, nvbench_compare): + run_data = make_comparison_run_data(nvbench_compare) + captured = {} + + def fake_tabulate(rows, headers, *args, **kwargs): + captured["rows"] = rows + captured["headers"] = headers + return "" + + monkeypatch.setattr(nvbench_compare.tabulate, "tabulate", fake_tabulate) + + ref_benches = [make_benchmark([make_state(nvbench_compare, "state", mean="1.0")])] + cmp_benches = [make_benchmark([make_state(nvbench_compare, "state", mean="1.01")])] + + nvbench_compare.compare_benches( + run_data, + ref_benches, + cmp_benches, + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + display="legacy", + ) + + assert captured["headers"][-7:] == [ + "Ref Time", + "Ref Noise", + "Cmp Time", + "Cmp Noise", + "Diff", + "%Diff", + "Status", + ] + row = captured["rows"][0] + assert row[-7] == "1.000 s" + assert row[-5] == "1.010 s" + assert row[-3] == "10.000 ms" + assert row[-2] == "1.00%" + + +def test_compare_benches_explain_display_uses_explicit_intervals( + monkeypatch, nvbench_compare +): + run_data = make_comparison_run_data(nvbench_compare) + captured = {} + + def fake_tabulate(rows, headers, *args, **kwargs): + captured["rows"] = rows + captured["headers"] = headers + return "" + + monkeypatch.setattr(nvbench_compare.tabulate, "tabulate", fake_tabulate) + + ref_state = make_state(nvbench_compare, "state", mean="1.0") + ref_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "1.0"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "1.01"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.02"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.03"), + make_summary(nvbench_compare, "GPU_SM_CLOCK_RATE_MEAN_TAG", "100.0"), + ] + ) + cmp_state = make_state(nvbench_compare, "state", mean="1.01") + cmp_state["summaries"].extend( + [ + make_summary(nvbench_compare, "GPU_TIME_MIN_TAG", "1.01"), + make_summary(nvbench_compare, "GPU_TIME_Q1_TAG", "1.02"), + make_summary(nvbench_compare, "GPU_TIME_MEDIAN_TAG", "1.03"), + make_summary(nvbench_compare, "GPU_TIME_Q3_TAG", "1.04"), + make_summary(nvbench_compare, "GPU_SM_CLOCK_RATE_MEAN_TAG", "100.0"), + ] + ) + + nvbench_compare.compare_benches( + run_data, + [make_benchmark([ref_state])], + [make_benchmark([cmp_state])], + threshold=0.0, + plot_along=None, + plot=False, + dark=False, + filter_plan=make_filter_plan(nvbench_compare), + no_color=True, + display="explain", + ) + + assert captured["headers"][-7:] == [ + "Ref [Lo | Ce | Hi]", + "Cmp [Lo | Ce | Hi]", + "Ref Noise", + "Cmp Noise", + "Reason", + "Change", + "Status", + ] + row = captured["rows"][0] + assert row[-7] == "1.0[00 | 20 | 30] s" + assert row[-6] == "1.0[10 | 30 | 40] s" + assert row[-3] == "centers-far" + assert row[-2] == "" + + +def test_main_passes_selected_preset_to_compare_benches(monkeypatch, nvbench_compare): + devices = [{"id": 0, "name": "Test GPU"}] + root = { + "devices": devices, + "benchmarks": [], + } + captured = {} + + monkeypatch.setattr(nvbench_compare.reader, "read_file", lambda _: root) + + def fake_compare_benches(*args, **kwargs): + captured["comparison_thresholds"] = kwargs["comparison_thresholds"] + captured["display"] = kwargs["display"] + + monkeypatch.setattr(nvbench_compare, "compare_benches", fake_compare_benches) + monkeypatch.setattr( + sys, + "argv", + [ + "nvbench_compare", + "--preset", + "strict", + "--display", + "explain", + "ref.json", + "cmp.json", + ], + ) + + assert nvbench_compare.main() == 0 + assert captured[ + "comparison_thresholds" + ] == nvbench_compare.get_comparison_thresholds("strict") + assert captured["display"] == "explain"