From 61fc25d1ff0b557c79673baacb8676077b5540ce Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Wed, 10 Jun 2026 14:43:08 +0200 Subject: [PATCH 1/7] reactivated assert --- src/weathergen/model/diffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index d5f971b2f..473748712 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -70,10 +70,10 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"'{self.conditioning}' (got offset={_offset})" ) _input_num_steps = self.cf.get("training_config", {}).get("model_input", {}).get("forecasting", {}).get("num_steps_input", 0) - # assert self.conditioning != "forecast" or _input_num_steps == 2, ( - # f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " - # f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" - # ) + assert self.conditioning != "forecast" or _input_num_steps == 2, ( + f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) assert self.conditioning not in ["date_time", "date", "time"] or _input_num_steps == 1, ( f"forecast.input_num_steps must be 1 when fe_diffusion_model_conditioning is " f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" From 15b06ba3d5699066f16e825312471269cd9b9015 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 11 Jun 2026 16:50:24 +0200 Subject: [PATCH 2/7] score fixes --- config/evaluate/eval_config_diffusion.yml | 6 +- .../evaluate/docs/ensemble_metrics_notes.md | 73 +++++++++++++++++++ .../evaluate/plotting/plot_orchestration.py | 26 +++++-- .../src/weathergen/evaluate/scores/score.py | 21 ++++-- 4 files changed, 112 insertions(+), 14 deletions(-) create mode 100644 packages/evaluate/docs/ensemble_metrics_notes.md diff --git a/config/evaluate/eval_config_diffusion.yml b/config/evaluate/eval_config_diffusion.yml index 15a567af2..7c0e9700c 100644 --- a/config/evaluate/eval_config_diffusion.yml +++ b/config/evaluate/eval_config_diffusion.yml @@ -18,14 +18,14 @@ # vmax: 40 evaluation: - metrics : ["rmse", "mae"] - regions: ["global", "nhem"] + metrics : ["ssr"] #spread + ssr are ensemble (probabilistic) metrics + regions: ["global"] summary_plots : true ratio_plots : false heat_maps : false summary_dir: "./plots/" plot_ensemble: "members" #supported: false, "std", "minmax", "members" - plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + plot_score_maps: true #plot scores on a 2D maps (incl. ensemble spread map). it slows down score computation print_summary: false #print out score values on screen. it can be verbose log_scale: false add_grid: false diff --git a/packages/evaluate/docs/ensemble_metrics_notes.md b/packages/evaluate/docs/ensemble_metrics_notes.md new file mode 100644 index 000000000..18b6f34ab --- /dev/null +++ b/packages/evaluate/docs/ensemble_metrics_notes.md @@ -0,0 +1,73 @@ +# Ensemble spread-skill metrics: implementation notes & divergence from GenCast + +This note documents the **current** behaviour of the ensemble spread / spread-skill metrics in +`packages/evaluate` and how they differ from the GenCast definitions. The metrics were enabled +(they were previously dead code, see below) **without changing their mathematical definitions** — +this document is the basis for a later decision on whether to align them with GenCast. + +## What was enabled + +The probabilistic metrics (`spread`, `ssr`, `crps`, `rank_histogram`) are registered in +`Scores.prob_metrics_dict` but were unreachable: `Scores.get_score` referenced an undefined +`self.ens_dim` and unconditionally `return None`ed before dispatching. The dispatch branch was +fixed to skip cleanly when the ensemble dim is absent and otherwise call the metric. No metric +definition was changed. + +Enabled in `config/evaluate/eval_config_diffusion.yml`: +- `evaluation.metrics: [..., "spread", "ssr"]` → spread-skill line plots per `channel` + (variable+level) vs `lead_time`. +- `evaluation.plot_score_maps: true` → ensemble spread **map** per `channel`/`forecast_step` + (the score-map path computes each metric with `agg_dims="sample"`, keeping `ipoint`). + +## Current definitions (in `scores/score.py`) + +```python +def calc_spread(self, p): + ens_std = p.std(dim="ens") # xarray default ddof=0 + return self._mean(np.sqrt(ens_std**2)) # sqrt(std**2) == std; mean over agg_dims + +def calc_ssr(self, p, gt): + return self.calc_spread(p) / self.calc_rmse(p, gt) +``` + +`self._mean` is an **unweighted** `mean` over `agg_dims` (default `ipoint` in the summary path, +`sample` in the score-map path). `calc_rmse` is `sqrt(mean((p - gt)**2))`. + +## GenCast reference (arXiv:2312.15796) + +For forecast times `k`, grid points `i` with latitude weights `a_i`, ensemble members `m`, +ensemble mean `x̄`, truth `y`: + +- Spread = `sqrt( mean_k [ Σ_i a_i · (1/(M-1)) Σ_m (x_{i,k}^m - x̄_{i,k})² ] )` +- Skill (ensemble-mean RMSE) = `sqrt( mean_k [ Σ_i a_i · (x̄_{i,k} - y_{i,k})² ] )` +- Spread/skill ratio = Spread / Skill; for a calibrated M-member ensemble this equals + `sqrt((M+1)/M)`. + +## Divergences (current → GenCast) + +1. **Spread aggregation order.** Current returns `mean_i( std_m )` (spatial mean of the per-point + std). GenCast uses `sqrt( mean_i( var_m ) )`. In general `mean(std) ≠ sqrt(mean(var))`. +2. **Unbiased variance.** Current `std`/`var` use xarray default `ddof=0`; GenCast uses `1/(M-1)` + (`ddof=1`). +3. **Latitude weighting.** Current uses an unweighted mean over `ipoint`; GenCast applies + `cos(lat)` weights `a_i`. (Deliberately kept unweighted for consistency with the existing + `rmse`/`mae` in this repo.) +4. **sqrt vs forecast-time averaging order.** The summary path keeps the `sample` dimension and + averages it at plot time (`LinePlots` averages over all non-x dims), i.e. *sqrt then average + over forecast times*. GenCast averages over forecast times *inside* the sqrt. +5. **SSR denominator.** *(Aligned with GenCast.)* `calc_ssr` now divides by + `calc_rmse(p.mean("ens"), gt)` — the RMSE of the ensemble **mean** (the skill), so SSR is a + single value per variable-level-fstep with the standard calibration interpretation + (under-/over-dispersion around the `sqrt((M+1)/M)`-corrected target). Previously it divided by + the per-member RMSE (`calc_rmse(p, gt)` with the ensemble still present), which produced one + ratio per ensemble member. +6. **No `sqrt((M+1)/M)` correction.** SSR is the raw spread/skill ratio, so the calibrated target + line is `sqrt((M+1)/M)`, not 1. + +## Spread map + +The spread map is the same `calc_spread` evaluated with `agg_dims="sample"` (reduces `ens` and +`sample`, keeps `ipoint`), routed through `plot_score_maps_per_stream` → +`Plotter.scatter_plot`. Per point it is `mean_sample( std_m )` rather than +`sqrt( mean_sample( var_m ) )` — the same divergences (1)–(2) apply. No latitude weighting is +relevant here since it is a spatial field, not a spatial average. diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index f32056e47..2c87fac9d 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -177,6 +177,11 @@ def _plot_score_maps_per_stream( if not valid: return + # Metrics that retain a per-member ensemble dimension (e.g. rmse, mae) vs those that + # reduce it (e.g. spread, crps). Mixing both in the concat below broadcasts the reduced + # ones across the ens dim, so we track membership to plot each metric correctly. + ens_metrics = {m for m, r in valid if "ens" in r.dims} + plot_metrics = xr.concat( [r for _, r in valid], dim="metric", @@ -189,14 +194,20 @@ def _plot_score_maps_per_stream( metric=[m for m, _ in valid], ).compute() - if "ens" in preds.dims: - plot_metrics["ens"] = preds.ens + # Restore the ensemble member labels only when the concatenated result actually keeps the + # ens dimension (i.e. at least one metric is per-member). Guarding on plot_metrics rather + # than preds avoids a CoordinateValidationError when every valid metric reduced the ens dim. + if "ens" in plot_metrics.dims: + plot_metrics = plot_metrics.assign_coords(ens=preds.ens.values) - has_ens = "ens" in plot_metrics.coords - ens_values = plot_metrics.coords["ens"].values if has_ens else [None] + all_ens = plot_metrics.coords["ens"].values if "ens" in plot_metrics.dims else [None] plot_tasks: list[dict] = [] for metric in plot_metrics.coords["metric"].values: + # Per-member maps only for metrics that kept the ens dim; ens-reduced metrics (spread, + # crps, ...) get a single map even when concat broadcast them across ens. + metric_has_ens = str(metric) in ens_metrics and "ens" in plot_metrics.dims + ens_values = all_ens if metric_has_ens else [None] for ens_val in ens_values: tag = f"score_maps_{metric}_fstep_{fstep}" + ( f"_ens_{ens_val}" if ens_val is not None else "" @@ -205,7 +216,12 @@ def _plot_score_maps_per_stream( sel = {"metric": metric, "channel": channel} if ens_val is not None: sel["ens"] = ens_val - data = plot_metrics.sel(**sel).squeeze() + data = plot_metrics.sel(**sel) + if ens_val is None and "ens" in data.dims: + # Collapse the broadcast ens dim for an ens-reduced metric (values are + # identical across members) to a single field. + data = data.isel(ens=0, drop=True) + data = data.squeeze() title = f"{metric} - {channel}: fstep {fstep}" + ( f", ens {ens_val}" if ens_val is not None else "" ) diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 6f8faafc7..afde4fe2f 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -254,12 +254,15 @@ def get_score( f = self.det_metrics_dict[score_name] _logger.debug(f"Using deterministic metric: {score_name}") elif score_name in self.prob_metrics_dict.keys(): - assert self.ens_dim in data.prediction.dims, ( - f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} " - "not found in prediction data. Skipping score calculation." - ) - return None + if self._ens_dim not in data.prediction.dims: + _logger.warning( + f"Probabilistic score '{score_name}' chosen, but ensemble dimension " + f"'{self._ens_dim}' not found in prediction data dims " + f"{data.prediction.dims}. Skipping score calculation." + ) + return None f = self.prob_metrics_dict[score_name] + _logger.debug(f"Using probabilistic metric: {score_name}") else: raise ValueError( f"Unknown score chosen. Supported scores: { @@ -1273,7 +1276,13 @@ def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: xr.DataArray Spread-Skill Ratio (SSR) """ - ssr = self.calc_spread(p) / self.calc_rmse(p, gt) # spread/rmse + # Spread-skill ratio = spread / skill, where the "skill" is the RMSE of the ensemble + # MEAN (GenCast / WeatherBench2 convention), not the per-member RMSE. Reducing the + # ensemble dim in the denominator (to match the already ens-reduced spread numerator) + # yields a single value per variable-level-fstep with a clean calibration interpretation, + # rather than one ratio per ensemble member. + ens_mean = p.mean(dim=self._ens_dim) + ssr = self.calc_spread(p) / self.calc_rmse(ens_mean, gt) return ssr From 26a22e9ef6219b2f42d196b9532c692d8c3f3f16 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 11 Jun 2026 17:01:14 +0200 Subject: [PATCH 3/7] more metrics --- config/evaluate/eval_config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/evaluate/eval_config_diffusion.yml b/config/evaluate/eval_config_diffusion.yml index 7c0e9700c..e94982561 100644 --- a/config/evaluate/eval_config_diffusion.yml +++ b/config/evaluate/eval_config_diffusion.yml @@ -18,7 +18,7 @@ # vmax: 40 evaluation: - metrics : ["ssr"] #spread + ssr are ensemble (probabilistic) metrics + metrics : ["rmse", "mae", "spread", "ssr", "crps"] #spread + ssr are ensemble (probabilistic) metrics regions: ["global"] summary_plots : true ratio_plots : false From 451e9b2abe9db4732879eb89085e7cc22d81f88f Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Sun, 14 Jun 2026 11:53:40 +0200 Subject: [PATCH 4/7] remove comments --- .../weathergen/evaluate/plotting/plot_orchestration.py | 10 ---------- .../evaluate/src/weathergen/evaluate/scores/score.py | 5 ----- 2 files changed, 15 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 2c87fac9d..4ded7b0d3 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -177,9 +177,6 @@ def _plot_score_maps_per_stream( if not valid: return - # Metrics that retain a per-member ensemble dimension (e.g. rmse, mae) vs those that - # reduce it (e.g. spread, crps). Mixing both in the concat below broadcasts the reduced - # ones across the ens dim, so we track membership to plot each metric correctly. ens_metrics = {m for m, r in valid if "ens" in r.dims} plot_metrics = xr.concat( @@ -194,9 +191,6 @@ def _plot_score_maps_per_stream( metric=[m for m, _ in valid], ).compute() - # Restore the ensemble member labels only when the concatenated result actually keeps the - # ens dimension (i.e. at least one metric is per-member). Guarding on plot_metrics rather - # than preds avoids a CoordinateValidationError when every valid metric reduced the ens dim. if "ens" in plot_metrics.dims: plot_metrics = plot_metrics.assign_coords(ens=preds.ens.values) @@ -204,8 +198,6 @@ def _plot_score_maps_per_stream( plot_tasks: list[dict] = [] for metric in plot_metrics.coords["metric"].values: - # Per-member maps only for metrics that kept the ens dim; ens-reduced metrics (spread, - # crps, ...) get a single map even when concat broadcast them across ens. metric_has_ens = str(metric) in ens_metrics and "ens" in plot_metrics.dims ens_values = all_ens if metric_has_ens else [None] for ens_val in ens_values: @@ -218,8 +210,6 @@ def _plot_score_maps_per_stream( sel["ens"] = ens_val data = plot_metrics.sel(**sel) if ens_val is None and "ens" in data.dims: - # Collapse the broadcast ens dim for an ens-reduced metric (values are - # identical across members) to a single field. data = data.isel(ens=0, drop=True) data = data.squeeze() title = f"{metric} - {channel}: fstep {fstep}" + ( diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index afde4fe2f..5b18db6bf 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -1276,11 +1276,6 @@ def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: xr.DataArray Spread-Skill Ratio (SSR) """ - # Spread-skill ratio = spread / skill, where the "skill" is the RMSE of the ensemble - # MEAN (GenCast / WeatherBench2 convention), not the per-member RMSE. Reducing the - # ensemble dim in the denominator (to match the already ens-reduced spread numerator) - # yields a single value per variable-level-fstep with a clean calibration interpretation, - # rather than one ratio per ensemble member. ens_mean = p.mean(dim=self._ens_dim) ssr = self.calc_spread(p) / self.calc_rmse(ens_mean, gt) From 68964622caf9e36a0464bcface387cfd68da7324 Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Sun, 14 Jun 2026 12:14:38 +0200 Subject: [PATCH 5/7] remove docs --- .../evaluate/docs/ensemble_metrics_notes.md | 73 ------------------- 1 file changed, 73 deletions(-) delete mode 100644 packages/evaluate/docs/ensemble_metrics_notes.md diff --git a/packages/evaluate/docs/ensemble_metrics_notes.md b/packages/evaluate/docs/ensemble_metrics_notes.md deleted file mode 100644 index 18b6f34ab..000000000 --- a/packages/evaluate/docs/ensemble_metrics_notes.md +++ /dev/null @@ -1,73 +0,0 @@ -# Ensemble spread-skill metrics: implementation notes & divergence from GenCast - -This note documents the **current** behaviour of the ensemble spread / spread-skill metrics in -`packages/evaluate` and how they differ from the GenCast definitions. The metrics were enabled -(they were previously dead code, see below) **without changing their mathematical definitions** — -this document is the basis for a later decision on whether to align them with GenCast. - -## What was enabled - -The probabilistic metrics (`spread`, `ssr`, `crps`, `rank_histogram`) are registered in -`Scores.prob_metrics_dict` but were unreachable: `Scores.get_score` referenced an undefined -`self.ens_dim` and unconditionally `return None`ed before dispatching. The dispatch branch was -fixed to skip cleanly when the ensemble dim is absent and otherwise call the metric. No metric -definition was changed. - -Enabled in `config/evaluate/eval_config_diffusion.yml`: -- `evaluation.metrics: [..., "spread", "ssr"]` → spread-skill line plots per `channel` - (variable+level) vs `lead_time`. -- `evaluation.plot_score_maps: true` → ensemble spread **map** per `channel`/`forecast_step` - (the score-map path computes each metric with `agg_dims="sample"`, keeping `ipoint`). - -## Current definitions (in `scores/score.py`) - -```python -def calc_spread(self, p): - ens_std = p.std(dim="ens") # xarray default ddof=0 - return self._mean(np.sqrt(ens_std**2)) # sqrt(std**2) == std; mean over agg_dims - -def calc_ssr(self, p, gt): - return self.calc_spread(p) / self.calc_rmse(p, gt) -``` - -`self._mean` is an **unweighted** `mean` over `agg_dims` (default `ipoint` in the summary path, -`sample` in the score-map path). `calc_rmse` is `sqrt(mean((p - gt)**2))`. - -## GenCast reference (arXiv:2312.15796) - -For forecast times `k`, grid points `i` with latitude weights `a_i`, ensemble members `m`, -ensemble mean `x̄`, truth `y`: - -- Spread = `sqrt( mean_k [ Σ_i a_i · (1/(M-1)) Σ_m (x_{i,k}^m - x̄_{i,k})² ] )` -- Skill (ensemble-mean RMSE) = `sqrt( mean_k [ Σ_i a_i · (x̄_{i,k} - y_{i,k})² ] )` -- Spread/skill ratio = Spread / Skill; for a calibrated M-member ensemble this equals - `sqrt((M+1)/M)`. - -## Divergences (current → GenCast) - -1. **Spread aggregation order.** Current returns `mean_i( std_m )` (spatial mean of the per-point - std). GenCast uses `sqrt( mean_i( var_m ) )`. In general `mean(std) ≠ sqrt(mean(var))`. -2. **Unbiased variance.** Current `std`/`var` use xarray default `ddof=0`; GenCast uses `1/(M-1)` - (`ddof=1`). -3. **Latitude weighting.** Current uses an unweighted mean over `ipoint`; GenCast applies - `cos(lat)` weights `a_i`. (Deliberately kept unweighted for consistency with the existing - `rmse`/`mae` in this repo.) -4. **sqrt vs forecast-time averaging order.** The summary path keeps the `sample` dimension and - averages it at plot time (`LinePlots` averages over all non-x dims), i.e. *sqrt then average - over forecast times*. GenCast averages over forecast times *inside* the sqrt. -5. **SSR denominator.** *(Aligned with GenCast.)* `calc_ssr` now divides by - `calc_rmse(p.mean("ens"), gt)` — the RMSE of the ensemble **mean** (the skill), so SSR is a - single value per variable-level-fstep with the standard calibration interpretation - (under-/over-dispersion around the `sqrt((M+1)/M)`-corrected target). Previously it divided by - the per-member RMSE (`calc_rmse(p, gt)` with the ensemble still present), which produced one - ratio per ensemble member. -6. **No `sqrt((M+1)/M)` correction.** SSR is the raw spread/skill ratio, so the calibrated target - line is `sqrt((M+1)/M)`, not 1. - -## Spread map - -The spread map is the same `calc_spread` evaluated with `agg_dims="sample"` (reduces `ens` and -`sample`, keeps `ipoint`), routed through `plot_score_maps_per_stream` → -`Plotter.scatter_plot`. Per point it is `mean_sample( std_m )` rather than -`sqrt( mean_sample( var_m ) )` — the same divergences (1)–(2) apply. No latitude weighting is -relevant here since it is a spatial field, not a spatial average. From a26fded62e59bfbf9cfdcf9c0ff11fc4a3c0524c Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Mon, 15 Jun 2026 19:49:41 +0200 Subject: [PATCH 6/7] implemented spread adj and ssr adj --- .../evaluate/plotting/line_plots.py | 8 ++- .../evaluate/plotting/plot_utils.py | 4 ++ .../src/weathergen/evaluate/scores/score.py | 62 +++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index 1f99a3728..e11f52d1d 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -256,6 +256,7 @@ def plot( print_summary: bool = False, title: str | None = None, colors: list[str | None] | None = None, + line: float | None = None, ) -> None: """ Plot a line graph comparing multiple datasets. @@ -274,6 +275,9 @@ def plot( Name of the dimension to be used for the y-axis. print_summary: If True, print a summary of the values from the graph. + line: + If provided, draw a horizontal reference line at the given y-value + (e.g. the optimal value of a metric). Returns ------- None @@ -321,7 +325,9 @@ def plot( # TODO: generalise this for other x_dims by introducing a "units" # entry in the function if needed xunits = "hr" if x_dim == "lead_time" else None - self._plot_base(fig, name, x_dim, y_dim, print_summary, xunits=xunits, title=title) + self._plot_base( + fig, name, x_dim, y_dim, print_summary, line=line, xunits=xunits, title=title + ) def _plot_base( self, diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index dfb2b3e8f..c9f4052a4 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -390,6 +390,9 @@ def plot_metric_region( title = f"{metric.upper()} | {stream} | {ch}" + ref_line_dict = {"ssr_adj": 1.0} + line = ref_line_dict.get(metric) + plotter.plot( selected_data, labels, @@ -399,6 +402,7 @@ def plot_metric_region( print_summary=print_summary, title=title, colors=colors, + line=line, ) diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 5b18db6bf..cdadf0e8e 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -199,9 +199,11 @@ def __init__( } self.prob_metrics_dict = { "ssr": self.calc_ssr, + "ssr_adj": self.calc_ssr_adj, "crps": self.calc_crps, "rank_histogram": self.calc_rank_histogram, "spread": self.calc_spread, + "spread_adj": self.calc_spread_adj, } def get_score( @@ -1261,6 +1263,33 @@ def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray: return self._mean(np.sqrt(ens_std**2)) + def calc_spread_adj(self, p: xr.DataArray, **kwargs) -> xr.DataArray: + """ + Calculate the (unbiased) ensemble spread following the GenCast convention. + + Defined as the square root of the mean *unbiased* (``ddof=1``) ensemble variance, matching + the spread of GenCast (Price et al., https://arxiv.org/pdf/2312.15796, Eq. A.6). The + finite-ensemble inflation factor ``sqrt((M + 1) / M)`` is *not* applied here; following + GenCast it is applied in the spread-skill ratio instead (see ``calc_ssr_adj``). + + Unlike the unadjusted ``calc_spread`` (which uses the biased ``ddof=0`` standard + deviation), this uses the unbiased variance, as required for the GenCast spread-skill + relation to hold. The unadjusted ``calc_spread`` is left unchanged. + + Parameters + ---------- + p: xr.DataArray + Forecast data array with ensemble dimension + + Returns + ------- + xr.DataArray + Unbiased ensemble spread (GenCast convention) + """ + ens_var = p.var(dim=self._ens_dim, ddof=1) + + return np.sqrt(self._mean(ens_var)) + def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate the Spread-Skill Ratio (SSR) of the forecast ensemble data w.r.t. reference data @@ -1281,6 +1310,39 @@ def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: return ssr + def calc_ssr_adj(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: + """ + Calculate the ensemble-size-adjusted Spread-Skill Ratio (SSR) of the forecast ensemble + data w.r.t. reference data. + + Following GenCast (Price et al., https://arxiv.org/pdf/2312.15796, Eq. A.9), this is + + SSR_adj = sqrt((M + 1) / M) * spread / RMSE(ensemble_mean) + + where ``spread`` is the unbiased ensemble spread (``calc_spread_adj``) and M is the + ensemble size. For a perfectly reliable ensemble of finite size M, the RMSE of the + ensemble mean is inflated relative to the spread by exactly ``sqrt((M + 1) / M)`` + (Fortin et al., 2014), so the ``sqrt((M + 1) / M)`` factor applied here makes a perfectly + calibrated ensemble yield an adjusted SSR of exactly 1. The original ``calc_spread`` / + ``calc_ssr`` are left unchanged. + + Parameters + ---------- + p: xr.DataArray + Forecast data array with ensemble dimension + gt: xr.DataArray + Ground truth data array + Returns + ------- + xr.DataArray + Ensemble-size-adjusted Spread-Skill Ratio (SSR). Optimal value is 1. + """ + ens_size = p.sizes[self._ens_dim] + correction = np.sqrt((ens_size + 1) / ens_size) + ens_mean = p.mean(dim=self._ens_dim) + + return correction * self.calc_spread_adj(p) / self.calc_rmse(ens_mean, gt) + def calc_crps( self, p: xr.DataArray, From 9ebac949047696fb6ef0fed251434f181595285b Mon Sep 17 00:00:00 2001 From: moritzhauschulz Date: Thu, 18 Jun 2026 10:48:37 +0200 Subject: [PATCH 7/7] ssr_adj in config eval --- config/evaluate/eval_config_diffusion.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/evaluate/eval_config_diffusion.yml b/config/evaluate/eval_config_diffusion.yml index e94982561..bb55bf133 100644 --- a/config/evaluate/eval_config_diffusion.yml +++ b/config/evaluate/eval_config_diffusion.yml @@ -18,7 +18,7 @@ # vmax: 40 evaluation: - metrics : ["rmse", "mae", "spread", "ssr", "crps"] #spread + ssr are ensemble (probabilistic) metrics + metrics : ["rmse", "mae", "spread", "ssr","ssr_adj", "crps"] #spread + ssr are ensemble (probabilistic) metrics regions: ["global"] summary_plots : true ratio_plots : false