diff --git a/config/evaluate/eval_config_diffusion.yml b/config/evaluate/eval_config_diffusion.yml index 15a567af2..bb55bf133 100644 --- a/config/evaluate/eval_config_diffusion.yml +++ b/config/evaluate/eval_config_diffusion.yml @@ -18,14 +18,14 @@ # vmax: 40 evaluation: - metrics : ["rmse", "mae"] - regions: ["global", "nhem"] + metrics : ["rmse", "mae", "spread", "ssr","ssr_adj", "crps"] #spread, ssr, ssr_adj and crps are ensemble (probabilistic) metrics + regions: ["global"] summary_plots : true ratio_plots : false heat_maps : false summary_dir: "./plots/" plot_ensemble: "members" #supported: false, "std", "minmax", "members" - plot_score_maps: false #plot scores on a 2D maps. it slows down score computation + plot_score_maps: true #plot scores on 2D maps (incl. ensemble spread map). it slows down score computation print_summary: false #print out score values on screen. it can be verbose log_scale: false add_grid: false diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py index 1f99a3728..e11f52d1d 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/line_plots.py @@ -256,6 +256,7 @@ def plot( print_summary: bool = False, title: str | None = None, colors: list[str | None] | None = None, + line: float | None = None, ) -> None: """ Plot a line graph comparing multiple datasets. @@ -274,6 +275,9 @@ def plot( Name of the dimension to be used for the y-axis. print_summary: If True, print a summary of the values from the graph. + line: + If provided, draw a horizontal reference line at the given y-value + (e.g. the optimal value of a metric).
Returns ------- None @@ -321,7 +325,9 @@ def plot( # TODO: generalise this for other x_dims by introducing a "units" # entry in the function if needed xunits = "hr" if x_dim == "lead_time" else None - self._plot_base(fig, name, x_dim, y_dim, print_summary, xunits=xunits, title=title) + self._plot_base( + fig, name, x_dim, y_dim, print_summary, line=line, xunits=xunits, title=title + ) def _plot_base( self, diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index f32056e47..4ded7b0d3 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -177,6 +177,8 @@ def _plot_score_maps_per_stream( if not valid: return + ens_metrics = {m for m, r in valid if "ens" in r.dims} + plot_metrics = xr.concat( [r for _, r in valid], dim="metric", @@ -189,14 +191,15 @@ def _plot_score_maps_per_stream( metric=[m for m, _ in valid], ).compute() - if "ens" in preds.dims: - plot_metrics["ens"] = preds.ens + if "ens" in plot_metrics.dims: + plot_metrics = plot_metrics.assign_coords(ens=preds.ens.values) - has_ens = "ens" in plot_metrics.coords - ens_values = plot_metrics.coords["ens"].values if has_ens else [None] + all_ens = plot_metrics.coords["ens"].values if "ens" in plot_metrics.dims else [None] plot_tasks: list[dict] = [] for metric in plot_metrics.coords["metric"].values: + metric_has_ens = str(metric) in ens_metrics and "ens" in plot_metrics.dims + ens_values = all_ens if metric_has_ens else [None] for ens_val in ens_values: tag = f"score_maps_{metric}_fstep_{fstep}" + ( f"_ens_{ens_val}" if ens_val is not None else "" @@ -205,7 +208,10 @@ def _plot_score_maps_per_stream( sel = {"metric": metric, "channel": channel} if ens_val is not None: sel["ens"] = ens_val - data = plot_metrics.sel(**sel).squeeze() + data = plot_metrics.sel(**sel) + if ens_val is None and "ens" 
in data.dims: + data = data.isel(ens=0, drop=True) + data = data.squeeze() title = f"{metric} - {channel}: fstep {fstep}" + ( f", ens {ens_val}" if ens_val is not None else "" ) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py index dfb2b3e8f..c9f4052a4 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py @@ -390,6 +390,9 @@ def plot_metric_region( title = f"{metric.upper()} | {stream} | {ch}" + ref_line_dict = {"ssr_adj": 1.0} + line = ref_line_dict.get(metric) + plotter.plot( selected_data, labels, @@ -399,6 +402,7 @@ def plot_metric_region( print_summary=print_summary, title=title, colors=colors, + line=line, ) diff --git a/packages/evaluate/src/weathergen/evaluate/scores/score.py b/packages/evaluate/src/weathergen/evaluate/scores/score.py index 6f8faafc7..cdadf0e8e 100755 --- a/packages/evaluate/src/weathergen/evaluate/scores/score.py +++ b/packages/evaluate/src/weathergen/evaluate/scores/score.py @@ -199,9 +199,11 @@ def __init__( } self.prob_metrics_dict = { "ssr": self.calc_ssr, + "ssr_adj": self.calc_ssr_adj, "crps": self.calc_crps, "rank_histogram": self.calc_rank_histogram, "spread": self.calc_spread, + "spread_adj": self.calc_spread_adj, } def get_score( @@ -254,12 +256,15 @@ def get_score( f = self.det_metrics_dict[score_name] _logger.debug(f"Using deterministic metric: {score_name}") elif score_name in self.prob_metrics_dict.keys(): - assert self.ens_dim in data.prediction.dims, ( - f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} " - "not found in prediction data. Skipping score calculation." - ) - return None + if self._ens_dim not in data.prediction.dims: + _logger.warning( + f"Probabilistic score '{score_name}' chosen, but ensemble dimension " + f"'{self._ens_dim}' not found in prediction data dims " + f"{data.prediction.dims}. 
Skipping score calculation." + ) + return None f = self.prob_metrics_dict[score_name] + _logger.debug(f"Using probabilistic metric: {score_name}") else: raise ValueError( f"Unknown score chosen. Supported scores: { @@ -1258,6 +1263,33 @@ def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray: return self._mean(np.sqrt(ens_std**2)) + def calc_spread_adj(self, p: xr.DataArray, **kwargs) -> xr.DataArray: + """ + Calculate the (unbiased) ensemble spread following the GenCast convention. + + Defined as the square root of the mean *unbiased* (``ddof=1``) ensemble variance, matching + the spread of GenCast (Price et al., https://arxiv.org/pdf/2312.15796, Eq. A.6). The + finite-ensemble inflation factor ``sqrt((M + 1) / M)`` is *not* applied here; following + GenCast it is applied in the spread-skill ratio instead (see ``calc_ssr_adj``). + + Unlike the unadjusted ``calc_spread`` (which uses the biased ``ddof=0`` standard + deviation), this uses the unbiased variance, as required for the GenCast spread-skill + relation to hold. The unadjusted ``calc_spread`` is left unchanged. + + Parameters + ---------- + p: xr.DataArray + Forecast data array with ensemble dimension + + Returns + ------- + xr.DataArray + Unbiased ensemble spread (GenCast convention) + """ + ens_var = p.var(dim=self._ens_dim, ddof=1) + + return np.sqrt(self._mean(ens_var)) + def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: """ Calculate the Spread-Skill Ratio (SSR) of the forecast ensemble data w.r.t. 
reference data @@ -1273,10 +1305,44 @@ def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: xr.DataArray Spread-Skill Ratio (SSR) """ - ssr = self.calc_spread(p) / self.calc_rmse(p, gt) # spread/rmse + ens_mean = p.mean(dim=self._ens_dim) + ssr = self.calc_spread(p) / self.calc_rmse(ens_mean, gt) return ssr + def calc_ssr_adj(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray: + """ + Calculate the ensemble-size-adjusted Spread-Skill Ratio (SSR) of the forecast ensemble + data w.r.t. reference data. + + Following GenCast (Price et al., https://arxiv.org/pdf/2312.15796, Eq. A.9), this is + + SSR_adj = sqrt((M + 1) / M) * spread / RMSE(ensemble_mean) + + where ``spread`` is the unbiased ensemble spread (``calc_spread_adj``) and M is the + ensemble size. For a perfectly reliable ensemble of finite size M, the RMSE of the + ensemble mean is inflated relative to the spread by exactly ``sqrt((M + 1) / M)`` + (Fortin et al., 2014), so the ``sqrt((M + 1) / M)`` factor applied here makes a perfectly + calibrated ensemble yield an adjusted SSR of exactly 1. The original ``calc_spread`` / + ``calc_ssr`` are left unchanged. + + Parameters + ---------- + p: xr.DataArray + Forecast data array with ensemble dimension + gt: xr.DataArray + Ground truth data array + Returns + ------- + xr.DataArray + Ensemble-size-adjusted Spread-Skill Ratio (SSR). Optimal value is 1. 
+ """ + ens_size = p.sizes[self._ens_dim] + correction = np.sqrt((ens_size + 1) / ens_size) + ens_mean = p.mean(dim=self._ens_dim) + + return correction * self.calc_spread_adj(p) / self.calc_rmse(ens_mean, gt) + def calc_crps( self, p: xr.DataArray, diff --git a/src/weathergen/model/diffusion.py b/src/weathergen/model/diffusion.py index d5f971b2f..473748712 100644 --- a/src/weathergen/model/diffusion.py +++ b/src/weathergen/model/diffusion.py @@ -70,10 +70,10 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast f"'{self.conditioning}' (got offset={_offset})" ) _input_num_steps = self.cf.get("training_config", {}).get("model_input", {}).get("forecasting", {}).get("num_steps_input", 0) - # assert self.conditioning != "forecast" or _input_num_steps == 2, ( - # f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " - # f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" - # ) + assert self.conditioning != "forecast" or _input_num_steps == 2, ( + f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is " + f"'{self.conditioning}' (got input_num_steps={_input_num_steps})" + ) assert self.conditioning not in ["date_time", "date", "time"] or _input_num_steps == 1, ( f"forecast.input_num_steps must be 1 when fe_diffusion_model_conditioning is " f"'{self.conditioning}' (got input_num_steps={_input_num_steps})"