Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions config/evaluate/eval_config_diffusion.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
# vmax: 40

evaluation:
metrics : ["rmse", "mae"]
regions: ["global", "nhem"]
metrics : ["rmse", "mae", "spread", "ssr","ssr_adj", "crps"] #spread + ssr are ensemble (probabilistic) metrics
regions: ["global"]
summary_plots : true
ratio_plots : false
heat_maps : false
summary_dir: "./plots/"
plot_ensemble: "members" #supported: false, "std", "minmax", "members"
plot_score_maps: false #plot scores on 2D maps. This slows down score computation
plot_score_maps: true #plot scores on 2D maps (incl. ensemble spread map). This slows down score computation
print_summary: false #print out score values on screen. it can be verbose
log_scale: false
add_grid: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def plot(
print_summary: bool = False,
title: str | None = None,
colors: list[str | None] | None = None,
line: float | None = None,
) -> None:
"""
Plot a line graph comparing multiple datasets.
Expand All @@ -274,6 +275,9 @@ def plot(
Name of the dimension to be used for the y-axis.
print_summary:
If True, print a summary of the values from the graph.
line:
If provided, draw a horizontal reference line at the given y-value
(e.g. the optimal value of a metric).
Returns
-------
None
Expand Down Expand Up @@ -321,7 +325,9 @@ def plot(
# TODO: generalise this for other x_dims by introducing a "units"
# entry in the function if needed
xunits = "hr" if x_dim == "lead_time" else None
self._plot_base(fig, name, x_dim, y_dim, print_summary, xunits=xunits, title=title)
self._plot_base(
fig, name, x_dim, y_dim, print_summary, line=line, xunits=xunits, title=title
)

def _plot_base(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def _plot_score_maps_per_stream(
if not valid:
return

ens_metrics = {m for m, r in valid if "ens" in r.dims}

plot_metrics = xr.concat(
[r for _, r in valid],
dim="metric",
Expand All @@ -189,14 +191,15 @@ def _plot_score_maps_per_stream(
metric=[m for m, _ in valid],
).compute()

if "ens" in preds.dims:
plot_metrics["ens"] = preds.ens
if "ens" in plot_metrics.dims:
plot_metrics = plot_metrics.assign_coords(ens=preds.ens.values)

has_ens = "ens" in plot_metrics.coords
ens_values = plot_metrics.coords["ens"].values if has_ens else [None]
all_ens = plot_metrics.coords["ens"].values if "ens" in plot_metrics.dims else [None]

plot_tasks: list[dict] = []
for metric in plot_metrics.coords["metric"].values:
metric_has_ens = str(metric) in ens_metrics and "ens" in plot_metrics.dims
ens_values = all_ens if metric_has_ens else [None]
for ens_val in ens_values:
tag = f"score_maps_{metric}_fstep_{fstep}" + (
f"_ens_{ens_val}" if ens_val is not None else ""
Expand All @@ -205,7 +208,10 @@ def _plot_score_maps_per_stream(
sel = {"metric": metric, "channel": channel}
if ens_val is not None:
sel["ens"] = ens_val
data = plot_metrics.sel(**sel).squeeze()
data = plot_metrics.sel(**sel)
if ens_val is None and "ens" in data.dims:
data = data.isel(ens=0, drop=True)
data = data.squeeze()
title = f"{metric} - {channel}: fstep {fstep}" + (
f", ens {ens_val}" if ens_val is not None else ""
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def plot_metric_region(

title = f"{metric.upper()} | {stream} | {ch}"

ref_line_dict = {"ssr_adj": 1.0}
line = ref_line_dict.get(metric)

plotter.plot(
selected_data,
labels,
Expand All @@ -399,6 +402,7 @@ def plot_metric_region(
print_summary=print_summary,
title=title,
colors=colors,
line=line,
)


Expand Down
78 changes: 72 additions & 6 deletions packages/evaluate/src/weathergen/evaluate/scores/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,11 @@ def __init__(
}
self.prob_metrics_dict = {
"ssr": self.calc_ssr,
"ssr_adj": self.calc_ssr_adj,
"crps": self.calc_crps,
"rank_histogram": self.calc_rank_histogram,
"spread": self.calc_spread,
"spread_adj": self.calc_spread_adj,
}

def get_score(
Expand Down Expand Up @@ -254,12 +256,15 @@ def get_score(
f = self.det_metrics_dict[score_name]
_logger.debug(f"Using deterministic metric: {score_name}")
elif score_name in self.prob_metrics_dict.keys():
assert self.ens_dim in data.prediction.dims, (
f"Probablistic score {score_name} chosen, but ensemble dimension {self.ens_dim} "
"not found in prediction data. Skipping score calculation."
)
return None
if self._ens_dim not in data.prediction.dims:
_logger.warning(
f"Probabilistic score '{score_name}' chosen, but ensemble dimension "
f"'{self._ens_dim}' not found in prediction data dims "
f"{data.prediction.dims}. Skipping score calculation."
)
return None
f = self.prob_metrics_dict[score_name]
_logger.debug(f"Using probabilistic metric: {score_name}")
else:
raise ValueError(
f"Unknown score chosen. Supported scores: {
Expand Down Expand Up @@ -1258,6 +1263,33 @@ def calc_spread(self, p: xr.DataArray, **kwargs) -> xr.DataArray:

return self._mean(np.sqrt(ens_std**2))

def calc_spread_adj(self, p: xr.DataArray, **kwargs) -> xr.DataArray:
    """
    Compute the unbiased ensemble spread (GenCast convention).

    The spread is the square root of the averaged *unbiased* (``ddof=1``) ensemble
    variance, following GenCast (Price et al., https://arxiv.org/pdf/2312.15796, Eq. A.6).
    Note the order of operations: the variance is averaged first and the square root is
    taken afterwards.

    The finite-ensemble inflation factor ``sqrt((M + 1) / M)`` is deliberately *not*
    applied here; it is applied in the adjusted spread-skill ratio (``calc_ssr_adj``).

    Parameters
    ----------
    p: xr.DataArray
        Forecast data array with ensemble dimension
    **kwargs:
        Accepted for call-signature compatibility; not used by this method.

    Returns
    -------
    xr.DataArray
        Unbiased ensemble spread (GenCast convention)
    """
    # Per-point sample variance across the ensemble members (ddof=1 -> unbiased).
    unbiased_variance = p.var(dim=self._ens_dim, ddof=1)
    # Aggregate the variance first; the square root is applied to the aggregate.
    mean_variance = self._mean(unbiased_variance)

    return np.sqrt(mean_variance)

def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray:
"""
Calculate the Spread-Skill Ratio (SSR) of the forecast ensemble data w.r.t. reference data
Expand All @@ -1273,10 +1305,44 @@ def calc_ssr(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray:
xr.DataArray
Spread-Skill Ratio (SSR)
"""
ssr = self.calc_spread(p) / self.calc_rmse(p, gt) # spread/rmse
ens_mean = p.mean(dim=self._ens_dim)
ssr = self.calc_spread(p) / self.calc_rmse(ens_mean, gt)

return ssr

def calc_ssr_adj(self, p: xr.DataArray, gt: xr.DataArray) -> xr.DataArray:
    """
    Compute the ensemble-size-adjusted Spread-Skill Ratio (SSR) of the forecast ensemble
    w.r.t. reference data.

    Following GenCast (Price et al., https://arxiv.org/pdf/2312.15796, Eq. A.9):

        SSR_adj = sqrt((M + 1) / M) * spread / RMSE(ensemble_mean)

    Here ``spread`` is the unbiased ensemble spread returned by ``calc_spread_adj`` and
    M is the number of ensemble members. For a perfectly reliable ensemble of finite size
    M, the RMSE of the ensemble mean exceeds the spread by a factor of
    ``sqrt((M + 1) / M)`` (Fortin et al., 2014); multiplying by that factor makes a
    perfectly calibrated ensemble score exactly 1.

    Parameters
    ----------
    p: xr.DataArray
        Forecast data array with ensemble dimension
    gt: xr.DataArray
        Ground truth data array

    Returns
    -------
    xr.DataArray
        Ensemble-size-adjusted Spread-Skill Ratio (SSR). Optimal value is 1.
    """
    # NOTE(review): presumably needs at least two members, since the spread uses
    # ddof=1 -- confirm how single-member input should be handled.
    n_members = p.sizes[self._ens_dim]
    # Finite-ensemble inflation factor sqrt((M + 1) / M).
    inflation = np.sqrt((n_members + 1) / n_members)

    # Skill is measured on the ensemble mean, not on individual members.
    member_mean = p.mean(dim=self._ens_dim)
    spread = self.calc_spread_adj(p)
    skill = self.calc_rmse(member_mean, gt)

    return inflation * spread / skill

def calc_crps(
self,
p: xr.DataArray,
Expand Down
8 changes: 4 additions & 4 deletions src/weathergen/model/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def __init__(self, cf: Config, num_healpix_cells: int, forecast_engine: Forecast
f"'{self.conditioning}' (got offset={_offset})"
)
_input_num_steps = self.cf.get("training_config", {}).get("model_input", {}).get("forecasting", {}).get("num_steps_input", 0)
# assert self.conditioning != "forecast" or _input_num_steps == 2, (
# f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is "
# f"'{self.conditioning}' (got input_num_steps={_input_num_steps})"
# )
assert self.conditioning != "forecast" or _input_num_steps == 2, (
f"forecast.input_num_steps must be 2 when fe_diffusion_model_conditioning is "
f"'{self.conditioning}' (got input_num_steps={_input_num_steps})"
)
assert self.conditioning not in ["date_time", "date", "time"] or _input_num_steps == 1, (
f"forecast.input_num_steps must be 1 when fe_diffusion_model_conditioning is "
f"'{self.conditioning}' (got input_num_steps={_input_num_steps})"
Expand Down
Loading