diff --git a/config/forecasters-ich1-oper-fixed.yaml b/config/forecasters-ich1-oper-fixed.yaml index 8a332955..1c8c2519 100644 --- a/config/forecasters-ich1-oper-fixed.yaml +++ b/config/forecasters-ich1-oper-fixed.yaml @@ -70,6 +70,8 @@ experiment: # - init_hour # - region - season + score_maps: + enabled: false locations: output_root: output/ diff --git a/config/forecasters-ich1-oper.yaml b/config/forecasters-ich1-oper.yaml index a2956738..b860f434 100644 --- a/config/forecasters-ich1-oper.yaml +++ b/config/forecasters-ich1-oper.yaml @@ -67,6 +67,8 @@ experiment: # - init_hour # - region - season + score_maps: + enabled: false locations: output_root: output/ diff --git a/config/forecasters-ich1.yaml b/config/forecasters-ich1.yaml index 1bf6f7be..576ce764 100644 --- a/config/forecasters-ich1.yaml +++ b/config/forecasters-ich1.yaml @@ -79,6 +79,8 @@ experiment: # - init_hour # - region - season + score_maps: + enabled: false locations: output_root: output/ diff --git a/config/temporal-downscalers-ich1.yaml b/config/temporal-downscalers-ich1.yaml index 51ec4ed9..21dfd5ec 100644 --- a/config/temporal-downscalers-ich1.yaml +++ b/config/temporal-downscalers-ich1.yaml @@ -98,6 +98,8 @@ experiment: - "V_10M:RMSE,R2,ETS" - "T_2M:RMSE,R2,ETS" - "TOT_PREC:RMSE,R2,ETS" + score_maps: + enabled: false showcase: params: diff --git a/src/data_input/__init__.py b/src/data_input/__init__.py index 6c14005a..96b3e476 100644 --- a/src/data_input/__init__.py +++ b/src/data_input/__init__.py @@ -180,6 +180,12 @@ def _discover_icon_member_ids( def load_from_grib_file(file: str | list[str], sel_kwargs): + # Coerce Path objects to str: earthkit-data unwraps a single-element list + # into one File source without converting, and then fails on non-str paths. + if isinstance(file, (list, tuple)): + file = [str(f) for f in file] + else: + file = str(file) fieldlist = ekd.from_source("file", file, lazily=True).to_fieldlist() return fieldlist_to_xarray(fieldlist.sel(**sel_kwargs)) diff --git a/src/evalml/cli.py b/src/evalml/cli.py index 51a9ed45..f3df3bf4 100644 --- a/src/evalml/cli.py +++ b/src/evalml/cli.py @@ -146,7 +146,7 @@ def execute_workflow( if report and not dry_run: command += ["--report-after-run", "--report", str(report)] - command.append(target) + command += [target] command += list(extra_smk_args) if not verbose: command += ["--quiet", "rules"] # reduce verobosity of snakemake output @@ -165,7 +165,15 @@ def cli(): ) @workflow_options def experiment( - configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args + configfile, + cores, + verbose, + dry_run, + unlock, + report, + dag, + rulegraph, + extra_smk_args, ): execute_workflow( configfile, diff --git a/src/evalml/config.py b/src/evalml/config.py index e101afc3..b5b42402 100644 --- a/src/evalml/config.py +++ b/src/evalml/config.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Any, ClassVar, FrozenSet, Optional +from typing import Dict, List, Any, ClassVar, FrozenSet, Literal, Optional from pydantic import BaseModel, Field, RootModel, field_validator @@ -219,6 +219,49 @@ class BaselineItem(BaseModel): baseline: BaselineConfig +class ScoreMapsConfig(BaseModel): + """Parameters controlling which score map plots are produced.""" + + enabled: bool = Field( + default=False, + description="Whether to produce score maps (computationally intensive).", + ) + params: List[str] = Field( + default=["T_2M"], + description=( + "List of parameters to plot. Supported values: T_2M, TD_2M, U_10M, V_10M, " + "PS, PMSL, TOT_PREC (native), and SP_10M (derived wind speed from U_10M/V_10M)." + ), + ) + leadtimes: List[int] | Literal["all"] = Field( + default=[6, 24], + description=( + "List of lead times (hours) to plot, or the literal string 'all' " + "to expand to the union of step lists from all configured runs " + "and baselines." + ), + ) + scores: List[str] = Field( + default=["BIAS"], + description="List of verification scores to plot. Supported: BIAS, RMSE, MAE.", + ) + regions: List[str] = Field( + default=["switzerland"], + description="List of regions to plot (e.g. switzerland, centraleurope).", + ) + seasons: List[str] = Field( + default=["all"], + description="List of seasons to plot ('all', 'DJF', 'MAM', 'JJA', 'SON').", + ) + init_hours: List[str] = Field( + default=["all"], + description=( + "List of initialization hours to plot. Use 'all' for the unstratified " + "view, or zero-padded hour strings like '00', '06', '12', '18'." + ), + ) + + class DomainConfig(BaseModel): """A custom map domain defined by name, extent, and projection.""" @@ -380,6 +423,10 @@ class ExperimentConfig(BaseModel): default=None, description="Scorecard generation configuration. Omit or set enabled: false to disable.", ) + score_maps: ScoreMapsConfig = Field( + default_factory=ScoreMapsConfig, + description="Score map plot configuration. Set enabled: true to produce score maps.", + ) @field_validator("thresholds") @classmethod diff --git a/src/plotting/__init__.py b/src/plotting/__init__.py index 163bf74b..524834c6 100644 --- a/src/plotting/__init__.py +++ b/src/plotting/__init__.py @@ -39,7 +39,7 @@ def get_projection(name: str) -> "ccrs.Projection": "projection": _PROJECTIONS["orthographic"], }, "centraleurope": { - "extent": [-2.6, 19.5, 40.2, 52.3], + "extent": [-1.5, 18, 41.5, 51], "projection": _PROJECTIONS["orthographic"], }, "icon-ch": { diff --git a/src/plotting/colormap_defaults.py b/src/plotting/colormap_defaults.py index 88c065e6..ca30fec2 100644 --- a/src/plotting/colormap_defaults.py +++ b/src/plotting/colormap_defaults.py @@ -4,6 +4,7 @@ from matplotlib import pyplot as plt import warnings from .colormap_loader import load_ncl_colormap +import numpy as np def _fallback(): @@ -129,6 +130,135 @@ def _fallback(): 120.0, ], }, + # Sequential Reds for RMSE and MAE: error is non-negative, larger ⇒ darker. + # Levels start at 0 so saturation maps directly to error magnitude; + # discrete levels make absolute values readable from the colour bar. + # RMSE: + "U_10M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "V_10M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "SP_10M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "TD_2M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "T_2M.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "PMSL.RMSE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "PS.RMSE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "TOT_PREC.RMSE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 1, 1.5, 2, 3, 4], + } + | {"units": "mm"}, + # MAE: + "U_10M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "V_10M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "SP_10M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "m/s"}, + "TD_2M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "T_2M.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 0.5, 1, 1.5, 2, 2.5, 3], + } + | {"units": "°C"}, + "PMSL.MAE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "PS.MAE.map": { + "cmap": plt.get_cmap("Reds", 7), + "levels": [0, 50, 100, 150, 200, 250, 300, 350], + } + | {"units": "Pa"}, + "TOT_PREC.MAE.map": { + "cmap": plt.get_cmap("Reds", 6), + "levels": [0, 1, 1.5, 2, 3, 4], + } + | {"units": "mm"}, + # the levels for precip are a bit on the bright side, but still worth keeping consistent with RMSE. + # Bias: + # diverging colour scheme for the Bias to reflect the nature of the data (can be positive or negative, symmetric). + # Red-Blue colour scheme for all variables except precipitation, where a Brown-Green scheme is more suggestive. + "U_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "V_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "SP_10M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 9), + "levels": np.arange(start=-2.25, stop=2.26, step=0.5), + } + | {"units": "m/s"}, + "TD_2M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-2.75, stop=2.76, step=0.5), + } + | {"units": "°C"}, + "T_2M.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-2.75, stop=2.76, step=0.5), + } + | {"units": "°C"}, + "PMSL.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-110, stop=111, step=20), + } + | {"units": "Pa"}, + "PS.BIAS.map": { + "cmap": plt.get_cmap("RdBu_r", 11), + "levels": np.arange(start=-110, stop=111, step=20), + } + | {"units": "Pa"}, + "TOT_PREC.BIAS.map": { + "cmap": plt.get_cmap("BrBG", 9), + "levels": [-1, -0.5, -0.25, -0.1, 0.1, 0.25, 0.5, 1], + } + | {"units": "mm"}, } CMAP_DEFAULTS = defaultdict(_fallback, _CMAP_DEFAULTS) diff --git a/workflow/Snakefile b/workflow/Snakefile index 1e9ea971..24214458 100644 --- a/workflow/Snakefile +++ b/workflow/Snakefile @@ -130,6 +130,36 @@ onerror: # ----------------------------------------------------- +SCORE_MAPS_CONFIG = config["experiment"]["score_maps"] + + +rule score_maps_all: + """Target rule for score maps (opt-in via experiment.score_maps.enabled).""" + input: + expand( + rules.plot_score_maps.output, + run_id=collect_all_candidates(), + leadtime=resolve_leadtimes(SCORE_MAPS_CONFIG["leadtimes"]), + score=SCORE_MAPS_CONFIG["scores"], + param=SCORE_MAPS_CONFIG["params"], + region=SCORE_MAPS_CONFIG["regions"], + season=SCORE_MAPS_CONFIG["seasons"], + init_hour=SCORE_MAPS_CONFIG["init_hours"], + experiment=EXPERIMENT_NAME, + ), + expand( + rules.plot_score_maps_baseline.output, + baseline_id=list(BASELINE_CONFIGS), + leadtime=resolve_leadtimes(SCORE_MAPS_CONFIG["leadtimes"]), + score=SCORE_MAPS_CONFIG["scores"], + param=SCORE_MAPS_CONFIG["params"], + region=SCORE_MAPS_CONFIG["regions"], + season=SCORE_MAPS_CONFIG["seasons"], + init_hour=SCORE_MAPS_CONFIG["init_hours"], + experiment=EXPERIMENT_NAME, + ), + + rule experiment_all: """Target rule for experiment workflow.""" input: @@ -158,6 +188,7 @@ rule experiment_all: ) else [] ), + (rules.score_maps_all.input if SCORE_MAPS_CONFIG["enabled"] else []), rule showcase_all: diff --git a/workflow/rules/common.smk b/workflow/rules/common.smk index dc7fa0fc..a62ff46b 100644 --- a/workflow/rules/common.smk +++ b/workflow/rules/common.smk @@ -353,3 +353,20 @@ _scorecard = config.get("experiment", {}).get("scorecards") or {} SCORECARD_CONFIGS = ( _scorecard.get("sections", {}) if _scorecard.get("enabled", True) else {} ) + + +def resolve_leadtimes(spec): + """Resolve a lead-time specification from config. + + Accepts: + - a list of ints — returned verbatim. + - the literal string "all" — expanded to the union of step lists + from all configured runs and baselines. + """ + if spec != "all": + return spec + all_steps = set() + for cfg in (*RUN_CONFIGS.values(), *BASELINE_CONFIGS.values()): + start, end, step = map(int, cfg["steps"].split("/")) + all_steps.update(range(start, end + 1, step)) + return sorted(all_steps) diff --git a/workflow/rules/plot.smk b/workflow/rules/plot.smk index 91ff1080..1392a53d 100644 --- a/workflow/rules/plot.smk +++ b/workflow/rules/plot.smk @@ -164,3 +164,49 @@ rule make_forecast_animation: """ convert -delay {params.delay} -loop 0 {input} {output} """ + + +rule plot_score_maps: + # localrule: True + input: + script="workflow/scripts/plot_score_maps.mo.py", + verif_file=OUT_ROOT / "data/runs/{run_id}/score_maps/{param}_{leadtime}.nc", + output: + OUT_ROOT + / "results/{experiment}/score_maps/runs/{run_id}/{param}_{score}_{region}_{season}_{init_hour}_{leadtime}.png", + log: + OUT_ROOT + / "logs/plot_score_maps/{experiment}/{run_id}-{param}-{score}-{region}-{season}-{init_hour}-{leadtime}.log", + wildcard_constraints: + leadtime=r"\d+", # only digits + init_hour=r"all|\d{1,2}", + resources: + slurm_partition="postproc", + cpus_per_task=1, + runtime="10m", + shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) + uv run python {input.script} \ + --input {input.verif_file} --outfn {output[0]} --region {wildcards.region} \ + --param {wildcards.param} --leadtime {wildcards.leadtime} --score {wildcards.score} \ + --season {wildcards.season} --init_hour {wildcards.init_hour} >{log} 2>&1 + # interactive editing (needs to set localrule: True and use only one core) + # marimo edit {input.script} -- \ + # --input {input.verif_file} --outfn {output[0]} --region {wildcards.region} \ + # --param {wildcards.param} --leadtime {wildcards.leadtime} --score {wildcards.score} \ + # --season {wildcards.season} --init_hour {wildcards.init_hour} + """ + + +use rule plot_score_maps as plot_score_maps_baseline with: + input: + script="workflow/scripts/plot_score_maps.mo.py", + verif_file=OUT_ROOT + / f"data/baselines/{{baseline_id}}/{config['truth']['label']}/score_maps/{{param}}_{{leadtime}}.nc", + output: + OUT_ROOT + / "results/{experiment}/score_maps/baselines/{baseline_id}/{param}_{score}_{region}_{season}_{init_hour}_{leadtime}.png", + log: + OUT_ROOT + / "logs/plot_score_maps/{experiment}/{baseline_id}-{param}-{score}-{region}-{season}-{init_hour}-{leadtime}.log", diff --git a/workflow/rules/verification.smk b/workflow/rules/verification.smk index 394eaf87..52032864 100644 --- a/workflow/rules/verification.smk +++ b/workflow/rules/verification.smk @@ -169,3 +169,80 @@ rule verification_metrics_plot: """ uv run {input.script} {input.verif} --output_dir {output} >{log} 2>&1 """ + + +rule verification_score_maps: + input: + "src/verification/__init__.py", + "src/data_input/__init__.py", + script="workflow/scripts/verification_score_maps.py", + inference_okfiles=lambda wc: expand( + rules.inference_execute.output.okfile, + init_time=_restrict_reftimes_to_hours(REFTIMES), + allow_missing=True, + ), + truth=config["truth"]["root"], + output: + OUT_ROOT / "data/runs/{run_id}/score_maps/{param}_{leadtime}.nc", + log: + OUT_ROOT / "logs/verification_score_maps/{run_id}-{param}-{leadtime}.log", + resources: + cpus_per_task=24, + mem_mb=50_000, + runtime="60m", + # wildcard_constraints: + # run_id="^" # to avoid ambiguitiy with run_baseline_verif + # TODO: implement logic to use experiment name instead of run_id as wildcard + params: + fcst_label=lambda wc: RUN_CONFIGS[wc.run_id].get("label"), + fcst_steps=lambda wc: RUN_CONFIGS[wc.run_id]["steps"], + truth_label=config["truth"]["label"], + reftimes=" ".join(t.strftime("%Y%m%d%H%M") for t in REFTIMES), + shell: + """ + uv run {input.script} \ + --run_root output/data/runs/{wildcards.run_id} \ + --reftimes {params.reftimes} \ + --truth {input.truth} \ + --step {wildcards.leadtime} \ + --steps "{params.fcst_steps}" \ + --param {wildcards.param} \ + --output {output} >{log} 2>&1 + """ + + +rule verification_score_maps_baseline: + input: + "src/verification/__init__.py", + "src/data_input/__init__.py", + script="workflow/scripts/verification_score_maps.py", + forecast=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["root"], + truth=config["truth"]["root"], + eckit_grids=rules.data_download_eckit_geo_grids.output, + output: + OUT_ROOT + / f"data/baselines/{{baseline_id}}/{config['truth']['label']}/score_maps/{{param}}_{{leadtime}}.nc", + log: + OUT_ROOT + / f"logs/verification_score_maps_baseline/{{baseline_id}}-{config['truth']['label']}-{{param}}-{{leadtime}}.log", + resources: + cpus_per_task=24, + mem_mb=50_000, + runtime="60m", + params: + baseline_steps=lambda wc: BASELINE_CONFIGS[wc.baseline_id]["steps"], + member=lambda wc: BASELINE_CONFIGS[wc.baseline_id].get("member", "000"), + reftimes=" ".join(t.strftime("%Y%m%d%H%M") for t in REFTIMES), + shell: + """ + export ECCODES_DEFINITION_PATH=$(realpath .venv/share/eccodes-cosmo-resources/definitions) + uv run {input.script} \ + --baseline_root {input.forecast} \ + --reftimes {params.reftimes} \ + --truth {input.truth} \ + --step {wildcards.leadtime} \ + --steps "{params.baseline_steps}" \ + --param {wildcards.param} \ + --member "{params.member}" \ + --output {output} >{log} 2>&1 + """ diff --git a/workflow/scripts/plot_score_maps.mo.py b/workflow/scripts/plot_score_maps.mo.py new file mode 100644 index 00000000..c484ad3e --- /dev/null +++ b/workflow/scripts/plot_score_maps.mo.py @@ -0,0 +1,222 @@ +import marimo + +__generated_with = "0.19.4" +app = marimo.App(width="medium") + + +@app.cell +def _(): + import logging + from argparse import ArgumentParser + from pathlib import Path + + import earthkit.plots as ekp + import numpy as np + import xarray as xr + + from plotting import DOMAINS + from plotting import StatePlotter + from plotting.colormap_defaults import CMAP_DEFAULTS + + return ( + ArgumentParser, + CMAP_DEFAULTS, + DOMAINS, + Path, + StatePlotter, + ekp, + logging, + np, + xr, + ) + + +@app.cell +def _(logging): + LOG = logging.getLogger(__name__) + LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + logging.basicConfig(level=logging.INFO, format=LOG_FMT) + return (LOG,) + + +@app.cell +def _(ArgumentParser, Path, np): + parser = ArgumentParser() + + parser.add_argument( + "--input", + type=str, + default=None, + help="Directory to .nc data containing the error fields", + ) + parser.add_argument("--outfn", type=str, help="output filename") + parser.add_argument("--leadtime", type=str, help="leadtime") + parser.add_argument("--param", type=str, help="parameter") + parser.add_argument("--region", type=str, help="name of region") + parser.add_argument( + "--score", + type=str, + help="Evaluation Score. So far Bias, RMSE, MAE or STDE are implemented.", + ) + parser.add_argument("--season", type=str, default="all", help="season filter") + parser.add_argument( + "--init_hour", type=str, default="all", help="initialization hour filter" + ) + + args = parser.parse_args() + verif_file = Path(args.input) + outfn = Path(args.outfn) + lead_time = args.leadtime + param = args.param + region = args.region + season = args.season + init_hour = args.init_hour + score = args.score + + if isinstance(init_hour, str): + if init_hour == "all": + init_hour = -999 + else: + try: + init_hour = int(init_hour) + except ValueError as exc: + raise ValueError("init_hour must be 'all' or an integer hour") from exc + + lead_time = np.timedelta64(lead_time, "h") + return ( + init_hour, + lead_time, + outfn, + param, + region, + score, + season, + verif_file, + ) + + +@app.cell +def _(LOG, init_hour, param, score, season, verif_file, xr): + ds = xr.open_dataset(verif_file) + LOG.info("Opened dataset: %s", ds) + var = f"{param}.{score}" + LOG.info( + "Selecting variable '%s' for season '%s', init_hour=%s", var, season, init_hour + ) + ds = ds[var].sel(season=season, init_hour=init_hour) + LOG.info( + "Selected DataArray: dims=%s, shape=%s, dtype=%s", ds.dims, ds.shape, ds.dtype + ) + LOG.info( + "Value range: min=%.4g, max=%.4g, n_nan=%d", + float(ds.min()), + float(ds.max()), + int(ds.isnull().sum()), + ) + return (ds,) + + +@app.cell +def _(CMAP_DEFAULTS, ekp): + def get_style(param, score, units_override=None): + """Get style and colormap settings for the plot. + Needed because cmap/norm does not work in Style(colors=cmap), + still needs to be passed as arguments to tripcolor()/tricontourf(). + """ + score_key = f"{param}.{score}.map" + cfg = ( + CMAP_DEFAULTS[score_key] + if score_key in CMAP_DEFAULTS + else CMAP_DEFAULTS.get(param, {}) + ) + units = units_override if units_override is not None else cfg.get("units", "") + return { + "style": ekp.styles.Style( + levels=cfg.get("bounds", cfg.get("levels", None)), + extend="both", + units=units, + colors=cfg.get("colors", None), + ), + "norm": cfg.get("norm", None), + "cmap": cfg.get("cmap", None), + "levels": cfg.get("levels", None), + "vmin": cfg.get("vmin", None), + "vmax": cfg.get("vmax", None), + "colors": cfg.get("colors", None), + } + + return (get_style,) + + +@app.cell +def _( + DOMAINS, + LOG, + StatePlotter, + ds, + get_style, + init_hour, + lead_time, + np, + outfn, + param, + region, + score, + season, +): + # plot individual fields + + plotter = StatePlotter( + ds["lon"].values.ravel(), + ds["lat"].values.ravel(), + outfn.parent, + ) + fig = plotter.init_geoaxes( + nrows=1, + ncols=1, + projection=DOMAINS[region]["projection"], + bbox=DOMAINS[region]["extent"], + name=region, + size=(6, 6), + ) + subplot = fig.add_map(row=0, column=0) + + plot_vals = ds.values.ravel() + + style_kwargs = get_style(param, score) + LOG.info("style_kwargs: %s", style_kwargs) + + if np.all(np.isnan(plot_vals)): + LOG.warning( + "All values are NaN for %s %s season=%s — plotting empty map.", + param, + score, + season, + ) + import matplotlib.patches as mpatches + + subplot.ax.set_facecolor("#cccccc") + subplot.standard_layers() + grey_patch = mpatches.Patch(color="#cccccc", label="No data") + subplot.ax.legend(handles=[grey_patch], loc="lower left", fontsize=8) + else: + plotter.plot_field(subplot, plot_vals, **style_kwargs) + + # black coast lines and country borders for better visibility + # grey is hardly visible, especially when the shading colours are intense. + subplot.coastlines(edgecolor="black", linewidth=1.0, zorder=5) + subplot.borders(edgecolor="black", linewidth=0.5, zorder=5) + + init_hour_lbl = "all" if init_hour == -999 else f"{init_hour:02d}" + fig.title( + f"{score} of {param}, Season: {season}, " + f"Init hour: {init_hour_lbl}, Lead Time: {lead_time}" + ) + + fig.save(outfn, bbox_inches="tight", dpi=200) + LOG.info(f"saved: {outfn}") + return + + +if __name__ == "__main__": + app.run() diff --git a/workflow/scripts/verification_score_maps.py b/workflow/scripts/verification_score_maps.py new file mode 100644 index 00000000..e4496ee8 --- /dev/null +++ b/workflow/scripts/verification_score_maps.py @@ -0,0 +1,681 @@ +"""Compute spatial maps of temporally-aggregated forecast errors. + +For a fixed lead time and variable, iterates over all initialisation times +(discovered under a run directory, or taken from --reftimes for baselines), +loads the corresponding forecast field and the matching truth slice from a +reference zarr, maps the forecast onto the truth grid, and accumulates running +error statistics without ever holding the full time series in memory. The +final BIAS / RMSE / MAE / STDE maps are written to a NetCDF file. + +Forecasts load through data_input.load_forecast_data, which routes by source: +ML run directories (GRIB files), INCA (NetCDF archive), or otherwise the ICON +operational GRIB archive. Baselines (--baseline_root) use the latter two paths; +init times are not discovered from the archive but taken from --reftimes, with +unavailable dates skipped at load time. + +Design note: one Snakemake job per (run, param, lead time), each loading only the +step(s) it needs. We deliberately do not load all lead times at once: per-job +memory and output disk scale with N_leadtimes x grid size, which is infeasible at +interpolator (1 h) and nowcasting (10 min) resolutions; that cost is independent +of GRIB read speed, so it does not improve as loading gets faster. For TOT_PREC +the loader (data_input._tot_prec_handling) de-accumulates over the requested +[step - period, step] window, so we just select the target step. + +Usage +----- + uv run workflow/scripts/verification_score_maps.py \\ + output/data/runs/ \\ + --truth /path/to/truth.zarr \\ + --step 24 \\ + --param T_2M +""" + +import logging +from argparse import ArgumentParser, Namespace +from datetime import datetime, timedelta +from pathlib import Path + +import numpy as np +import xarray as xr + +from data_input import load_forecast_data, parse_steps +from verification.spatial import map_forecast_to_truth + +LOG = logging.getLogger(__name__) +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +DATETIME_FMT = "%Y%m%d%H%M" + +SEASONS = ["DJF", "MAM", "JJA", "SON", "all"] +# Init hour buckets. -999 is the "all" sentinel (matches verification_aggregation.py). +INIT_HOURS = [0, 6, 12, 18, -999] + + +def _season_of(dt: datetime) -> str: + """Return the meteorological season string for a given datetime.""" + month = dt.month + if month in (12, 1, 2): + return "DJF" + if month in (3, 4, 5): + return "MAM" + if month in (6, 7, 8): + return "JJA" + return "SON" + + +# Maps from standard parameter names to zarr variable names. +# COSMO-2e zarrs use short CF names; COSMO-1e zarrs keep the COSMO names. +_PARAMS_MAP_CO2 = { + "T_2M": "2t", + "TD_2M": "2d", + "U_10M": "10u", + "V_10M": "10v", + "PS": "sp", + "PMSL": "msl", + "TOT_PREC": "tp", +} +# Derived variables and the components they require. +_DERIVED = { + "SP_10M": ("U_10M", "V_10M"), +} + +# Params whose GRIB/zarr values are cumulative-from-start accumulations and must +# be de-accumulated over a [step - period, step] window before verification. +_ACCUMULATED_PARAMS = {"TOT_PREC"} + + +def _params_map(truth_root: Path, accum_h: int | None = None) -> dict[str, str]: + """Map canonical parameter names to truth-zarr variable names. + + COSMO-2e zarrs use short CF names. COSMO-1e / ICON zarrs store precip as + period accumulations named ``TOT_PREC_H``, where N is the accumulation + length in hours (matching the verification step spacing); pass it via + ``accum_h``. + """ + if "co2" in truth_root.name: + return _PARAMS_MAP_CO2 + suffix = f"TOT_PREC_{accum_h}H" if accum_h else "TOT_PREC_6H" + return {k: k.replace("TOT_PREC", suffix) for k in _PARAMS_MAP_CO2} + + +def _compute_derived(ds: xr.Dataset, param: str) -> xr.DataArray: + """Compute a derived variable from its components already present in *ds*.""" + if param == "SP_10M": + return (ds["U_10M"] ** 2 + ds["V_10M"] ** 2) ** 0.5 + raise ValueError(f"No recipe for derived variable '{param}'") + + +# --------------------------------------------------------------------------- +# Truth loading +# --------------------------------------------------------------------------- +# TODO: consolidate with src/data_input/__init__.py as part of the +# refactor/data-io branch. _open_zarr_component below duplicates +# ~80% of load_analysis_data_from_zarr but returns a lazy DataArray +# rather than a time-sliced Dataset, which is what our streaming +# aggregation needs. The right end-state is a shared lazy-open primitive +# in data_input that both consumers use; not introduced here to avoid +# conflicting with the data-io refactor. Until then this opener must +# mirror the loader's conventions (notably the m -> mm precip conversion +# from MRB-820). + + +def _open_zarr_component( + root: Path, param: str, accum_h: int | None = None +) -> xr.DataArray: + """Open a single native zarr variable lazily as a DataArray.""" + zarr_param = _params_map(root, accum_h)[param] + + ds = xr.open_zarr(root, consolidated=False) + ds = ds.set_index(time="dates") + + # Extract lat/lon before selecting on variable (they live on cell only). + spatial_dim = "cell" + lat = ds["latitudes"] if "latitudes" in ds else None + lon = ds["longitudes"] if "longitudes" in ds else None + + ds = ds.assign_coords(variable=ds.attrs["variables"]) + ds = ds.sel(variable=zarr_param).squeeze("ensemble", drop=True) + + # Recover 2-D spatial shape when stored as a flat cell dimension. + if len(ds.attrs["field_shape"]) == 2: + ny, nx = ds.attrs["field_shape"] + y_idx, x_idx = np.unravel_index(np.arange(ny * nx), (ny, nx)) + ds = ds.assign_coords(y=(spatial_dim, y_idx), x=(spatial_dim, x_idx)) + ds = ds.set_index(**{spatial_dim: ("y", "x")}).unstack(spatial_dim) + spatial_dim = None # now (y, x) + + da = ds["data"].rename(param).drop_vars("variable", errors="ignore") + + # Truth zarrs store precip in m (anemoi convention); all forecast loaders + # deliver canonical mm (kg m-2) since MRB-820, which put this conversion in + # load_analysis_data_from_zarr. Mirror it here until this opener is + # consolidated into data_input (refactor/data-io). Stays lazy (dask). + if param in _ACCUMULATED_PARAMS: + da = da * 1000 + + # Attach latitude/longitude as coordinates on the spatial dimension(s). + # Use the full names to match the forecast loader (load_forecast_data) and + # map_forecast_to_truth, which key on `latitude`/`longitude`. + if lat is not None and lon is not None: + if spatial_dim is not None: + # flat 1-D case: cell/values dim + da = da.assign_coords( + latitude=(spatial_dim, lat.values), + longitude=(spatial_dim, lon.values), + ) + else: + # 2-D case: lat/lon still on original flat index — attach via unstack + da = da.assign_coords( + latitude=(["y", "x"], lat.values.reshape(ny, nx)), + longitude=(["y", "x"], lon.values.reshape(ny, nx)), + ) + + return da + + +def open_truth_zarr(root: Path, param: str, accum_h: int | None = None) -> xr.DataArray: + """Open the truth zarr lazily and return a DataArray for *param*. + + For derived variables (e.g. SP_10M) the required components are loaded and + the derivation is applied on the fly. The returned DataArray has dimensions + ``(time, y, x)`` or ``(time, values)`` and always exposes ``latitude``/``longitude``. + ``accum_h`` selects the precip accumulation length (TOT_PREC_H). + """ + if param in _DERIVED: + components = { + c: _open_zarr_component(root, c, accum_h).drop_vars( + "variable", errors="ignore" + ) + for c in _DERIVED[param] + } + ds = xr.Dataset(components) + return _compute_derived(ds, param) + return _open_zarr_component(root, param, accum_h) + + +# --------------------------------------------------------------------------- +# Init-time discovery +# --------------------------------------------------------------------------- + + +def iter_init_dirs(run_root: Path) -> list[tuple[datetime, Path]]: + """Return ``(reftime, grib_dir)`` pairs for every complete init time. + + Expects subdirectories named ``YYYYMMDDHHMI`` directly under *run_root*. + GRIB files may live either directly in the init-time directory or inside a + ``grib/`` subdirectory. + """ + result = [] + for d in sorted(run_root.iterdir()): + if not d.is_dir(): + continue + try: + reftime = datetime.strptime(d.name, DATETIME_FMT) + except ValueError: + continue + grib_dir = d / "grib" if (d / "grib").is_dir() else d + if not any(grib_dir.glob("*.grib")): + LOG.debug("No GRIB files in %s, skipping", grib_dir) + continue + result.append((reftime, grib_dir)) + return result + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(args: Namespace) -> None: + LOG.info("=" * 60) + LOG.info("Spatial verification param=%s step=%dh", args.param, args.step) + LOG.info("Run root : %s", args.run_root) + LOG.info("Truth : %s", args.truth) + LOG.info("Output : %s", args.output) + LOG.info("=" * 60) + + # Accumulated params (TOT_PREC) are stored cumulative-from-start, while the + # truth is a period accumulation whose length equals the verification step + # spacing (e.g. 6h for steps "0/120/6"). Derive that period so we can (a) + # request the matching [step - period, step] window from the forecast loader + # and (b) read the matching TOT_PREC_H truth variable. We do not + # assume a fixed period; it follows the configured --steps. + accum_h: int | None = None + if args.param in _ACCUMULATED_PARAMS: + if not args.steps: + raise ValueError( + f"--steps is required for accumulated param '{args.param}' " + "(used to derive the accumulation period)." + ) + spacing = np.diff(parse_steps(args.steps)) + if spacing.size == 0: + raise ValueError( + f"Cannot derive an accumulation period from --steps '{args.steps}'." + ) + accum_h = int(spacing.min()) + if args.step < accum_h: + raise ValueError( + f"Lead time {args.step}h is smaller than the {accum_h}h " + f"accumulation period; cannot form a [step - period, step] " + f"window for '{args.param}'." + ) + req_steps = [args.step - accum_h, args.step] + LOG.info("Accumulation period: %dh (forecast window %s)", accum_h, req_steps) + + # INCA delivers native 1h precip sums and (unlike the GRIB paths, where + # the cumulative-from-start diff adapts to the requested window) cannot + # re-aggregate to a coarser period: the value at the target step would + # stay a 1h sum while the truth read is TOT_PREC_H — a silent + # mismatch. Re-aggregation in the loader is a planned follow-up. + if args.baseline_root and "INCA" in args.baseline_root.parts and accum_h != 1: + raise ValueError( + f"INCA provides native 1h accumulations only, but the step " + f"spacing of --steps '{args.steps}' implies a {accum_h}h " + f"accumulation period for '{args.param}'. Use 1h-spaced steps " + f"for INCA score maps." + ) + else: + req_steps = [args.step] + + # Open the truth zarr once; individual time slices are loaded on demand. + truth_da = open_truth_zarr(args.truth, args.param, accum_h) + # Normalise to datetime64[ns] so membership checks work regardless of zarr precision. + truth_da = truth_da.assign_coords( + time=truth_da.time.values.astype("datetime64[ns]") + ) + # Rename flat spatial dim to 'values' if the zarr uses 'cell'. + if "cell" in truth_da.dims: + truth_da = truth_da.rename({"cell": "values"}) + truth_times = set( + truth_da.time.values + ) # keep as datetime64, tolist() yields ints for ns precision + LOG.info("Truth opened lazily: %s", truth_da) + + if args.baseline_root: + # The operational archive is too large to enumerate up front; the + # experiment's configured init times define the work list, and dates + # missing from the archive are skipped at load time below. + init_items = [ + (rt, None) + for rt in sorted(datetime.strptime(s, DATETIME_FMT) for s in args.reftimes) + ] + LOG.info("Using %d baseline init times from --reftimes", len(init_items)) + else: + init_items = iter_init_dirs(args.run_root) + LOG.info("Found %d init time directories", len(init_items)) + + # Restrict to the experiment's configured init times if provided. + if args.reftimes: + wanted = {datetime.strptime(s, DATETIME_FMT) for s in args.reftimes} + init_items = [(rt, d) for rt, d in init_items if rt in wanted] + LOG.info("Filtered to %d init times matching --reftimes", len(init_items)) + + step_td = timedelta(hours=args.step) + + # Running accumulators keyed by (season, init_hour) – initialised on the + # first successfully processed sample so that we can infer the spatial + # shape from the data. Each entry is a numpy array over the spatial + # dimension(s). + bucket_keys = [(s, h) for s in SEASONS for h in INIT_HOURS] + accum_n: dict[tuple[str, int], np.ndarray | None] = {k: None for k in bucket_keys} + accum_sum_e: dict[tuple[str, int], np.ndarray | None] = { + k: None for k in bucket_keys + } + accum_sum_se: dict[tuple[str, int], np.ndarray | None] = { + k: None for k in bucket_keys + } + accum_sum_ae: dict[tuple[str, int], np.ndarray | None] = { + k: None for k in bucket_keys + } + ref_truth_slice: xr.DataArray | None = None # kept for output coordinates + + n_ok = 0 + n_skip = 0 + + for reftime, grib_dir in init_items: + valid_time = np.datetime64(reftime + step_td).astype("datetime64[ns]") + + if valid_time not in truth_times: + LOG.debug("Valid time %s not in truth, skipping %s", valid_time, reftime) + n_skip += 1 + continue + + LOG.info( + "Processing reftime=%s valid=%s", + reftime.strftime(DATETIME_FMT), + valid_time, + ) + + first_iter = n_ok == 0 + + # --- load forecast --- + fct_params = ( + list(_DERIVED[args.param]) if args.param in _DERIVED else [args.param] + ) + + try: + # For accumulated params (TOT_PREC) req_steps is the [step - period, + # step] window; for GRIB sources (runs and the ICON archive) the + # loader de-accumulates the cumulative-from-start field over the + # requested steps (diff over `step`), so the target step holds the + # period accumulation; INCA returns native 1h sums, matching the + # period because accum_h == 1 is enforced above. Instantaneous + # params request a single step. The target step is selected just + # below. + # + # NOTE: data_input._tot_prec_handling treats the FIRST loaded step + # positionally as the forecast initial condition and zero-fills it + # when it is present-but-all-NaN. For our window the first step is + # `step - period` (non-zero except at the first lead time); a + # corrupt/NaN field there would be silently zeroed, turning the diff + # into a cumulative-from-start value. A missing GRIB instead yields + # NaN and is skipped downstream. The principled fix belongs in + # _tot_prec_handling (gate the IC zero-fill on step 0 being + # requested) — see the port note. This applies to runs and ICON + # baselines alike. + src_root = args.baseline_root if args.baseline_root else grib_dir + fcst = load_forecast_data( + src_root, reftime, req_steps, fct_params, member=args.member + ) + except Exception as exc: + LOG.warning("Could not load forecast for %s: %s", reftime, exc) + n_skip += 1 + continue + + # Select the target step. The earthkit loader returns forecasts over the + # requested steps with a `step` (timedelta64) dimension; for TOT_PREC the + # loader has already de-accumulated over the window, so the target step + # holds the period accumulation, and for instantaneous params only the + # single requested step is present. + if "step" in fcst.dims: + fcst = fcst.sel(step=np.timedelta64(args.step, "h")) + + # Compute derived variable if needed. + if args.param in _DERIVED: + fcst = fcst.assign({args.param: _compute_derived(fcst, args.param)}) + + if first_iter: + LOG.info("fcst (after step selection): %s", fcst) + fcst_raw = fcst[args.param].values if args.param in fcst else None + if fcst_raw is not None: + n_nan_fcst = int(np.isnan(fcst_raw).sum()) + LOG.info( + "fcst[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d", + args.param, + fcst_raw.shape, + float(np.nanmin(fcst_raw)) + if n_nan_fcst < fcst_raw.size + else float("nan"), + float(np.nanmax(fcst_raw)) + if n_nan_fcst < fcst_raw.size + else float("nan"), + n_nan_fcst, + ) + + # --- load truth slice --- + truth_slice = truth_da.sel(time=valid_time).compute() + # For derived variables truth_da is already the derived DataArray, + # so wrap it in a Dataset for map_forecast_to_truth compatibility. + truth_ds = ( + truth_slice.to_dataset(name=args.param) + if isinstance(truth_slice, xr.DataArray) + else truth_slice + ) + + if first_iter: + truth_raw = truth_slice.values + n_nan_truth = int(np.isnan(truth_raw).sum()) + LOG.info( + "truth_slice[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d", + args.param, + truth_raw.shape, + float(np.nanmin(truth_raw)) + if n_nan_truth < truth_raw.size + else float("nan"), + float(np.nanmax(truth_raw)) + if n_nan_truth < truth_raw.size + else float("nan"), + n_nan_truth, + ) + + # --- map forecast onto truth grid --- + try: + fcst_mapped = map_forecast_to_truth(fcst, truth_ds) + except Exception as exc: + LOG.warning("Spatial mapping failed for %s: %s", reftime, exc) + n_skip += 1 + continue + + fcst_param = fcst_mapped[args.param] + # Squeeze size-1 non-spatial dims so the error array is purely spatial. + # The earthkit loader keeps `number` (ensemble), `z` (vertical) and + # `forecast_reference_time` as size-1 dims for a deterministic surface run. + for dim in ["eps", "ensemble", "number", "z", "forecast_reference_time"]: + if dim in fcst_param.dims and fcst_param.sizes[dim] == 1: + fcst_param = fcst_param.squeeze(dim, drop=True) + fcst_vals = fcst_param.values + truth_vals = truth_slice.values + error = fcst_vals - truth_vals # shape: spatial dims of truth + + if first_iter: + n_nan_mapped = int(np.isnan(fcst_vals).sum()) + LOG.info( + "fcst_mapped[%s]: shape=%s, min=%.4g, max=%.4g, n_nan=%d", + args.param, + fcst_vals.shape, + float(np.nanmin(fcst_vals)) + if n_nan_mapped < fcst_vals.size + else float("nan"), + float(np.nanmax(fcst_vals)) + if n_nan_mapped < fcst_vals.size + else float("nan"), + n_nan_mapped, + ) + n_nan_err = int(np.isnan(error).sum()) + LOG.info( + "error: shape=%s, min=%.4g, max=%.4g, n_nan=%d / %d", + error.shape, + float(np.nanmin(error)) if n_nan_err < error.size else float("nan"), + float(np.nanmax(error)) if n_nan_err < error.size else float("nan"), + n_nan_err, + error.size, + ) + + n_nan_error = int(np.isnan(error).sum()) + if n_nan_error == error.size: + LOG.warning( + "reftime=%s: error is all-NaN (%d points) — nothing accumulated.", + reftime.strftime(DATETIME_FMT), + error.size, + ) + + # --- initialise accumulators on first valid sample --- + if accum_n[("all", -999)] is None: + for k in bucket_keys: + accum_n[k] = np.zeros(error.shape, dtype=np.int64) + accum_sum_e[k] = np.zeros(error.shape, dtype=np.float64) + accum_sum_se[k] = np.zeros(error.shape, dtype=np.float64) + accum_sum_ae[k] = np.zeros(error.shape, dtype=np.float64) + ref_truth_slice = truth_slice + + # --- accumulate into matching (season, init_hour) buckets, plus the + # "all" rows/cols on each axis (NaN-safe) --- + season = _season_of(reftime) + ih = reftime.hour + valid = ~np.isnan(error) + for s in (season, "all"): + for h in (ih, -999): + accum_n[(s, h)][valid] += 1 + accum_sum_e[(s, h)][valid] += error[valid] + accum_sum_se[(s, h)][valid] += error[valid] ** 2 + accum_sum_ae[(s, h)][valid] += np.abs(error[valid]) + n_ok += 1 + + LOG.info("Finished: %d init times processed, %d skipped", n_ok, n_skip) + + if n_ok == 0: + LOG.error("No data could be processed – no output written.") + return + + # --- compute aggregate maps per (season, init_hour), then stack --- + spatial_coords = { + c: ref_truth_slice[c] + for c in ref_truth_slice.coords + if set(ref_truth_slice[c].dims).issubset(set(ref_truth_slice.dims)) + and c != "time" + } + spatial_dims = list(ref_truth_slice.dims) + out_dims = ["season", "init_hour"] + spatial_dims + out_coords = {"season": SEASONS, "init_hour": INIT_HOURS, **spatial_coords} + + def _strat_da(compute_fn) -> xr.DataArray: + """Stack per-(season, init_hour) arrays into a (season, init_hour, *spatial) DataArray.""" + out_shape = (len(SEASONS), len(INIT_HOURS)) + ref_truth_slice.shape + arr = np.empty(out_shape, dtype=np.float32) + for i, s in enumerate(SEASONS): + for j, h in enumerate(INIT_HOURS): + n = accum_n[(s, h)] + with np.errstate(invalid="ignore", divide="ignore"): + arr[i, j] = compute_fn(n, s, h).astype(np.float32) + return xr.DataArray(arr, dims=out_dims, coords=out_coords) + + out = xr.Dataset( + { + f"{args.param}.BIAS": _strat_da( + lambda n, s, h: np.where(n > 0, accum_sum_e[(s, h)] / n, np.nan) + ), + f"{args.param}.RMSE": _strat_da( + lambda n, s, h: np.where( + n > 0, np.sqrt(accum_sum_se[(s, h)] / n), np.nan + ) + ), + f"{args.param}.MAE": _strat_da( + lambda n, s, h: np.where(n > 0, accum_sum_ae[(s, h)] / n, np.nan) + ), + f"{args.param}.STDE": _strat_da( + lambda n, s, h: np.where( + n > 0, + np.sqrt( + np.maximum( + accum_sum_se[(s, h)] / n - (accum_sum_e[(s, h)] / n) ** 2, + 0.0, + ) + ), + np.nan, + ) + ), + f"{args.param}.N": _strat_da(lambda n, s, h: np.where(n > 0, n, np.nan)), + }, + attrs={ + "param": args.param, + "step_h": args.step, + # Accumulation period of the verified quantity (accumulated params + # only) — lets consumers tell a 1h INCA map from a 6h ICON map. + "accum_h": accum_h if accum_h is not None else "n/a", + "member": args.member, + "source": str(args.baseline_root if args.baseline_root else args.run_root), + "n_processed": n_ok, + "n_skipped": n_skip, + }, + ) + + LOG.info("Output dataset:\n%s", out) + args.output.parent.mkdir(parents=True, exist_ok=True) + out.to_netcdf(args.output) + LOG.info("Saved to %s", args.output) + + +if __name__ == "__main__": + parser = ArgumentParser( + description=( + "Compute spatial maps of temporally-aggregated forecast errors. " + "Supports model runs (GRIB) and baselines (ICON GRIB archive or " + "INCA NetCDF archive). " + "Exactly one of --run_root or --baseline_root must be provided." + ) + ) + parser.add_argument( + "--run_root", + type=Path, + default=None, + help="Root directory of a model run (e.g. output/data/runs/).", + ) + parser.add_argument( + "--baseline_root", + type=Path, + default=None, + help=( + "Root directory of a baseline archive (e.g. the ICON-CH1/CH2-EPS " + "operational GRIB archive, or an INCA NetCDF archive). Requires " + "--reftimes." + ), + ) + parser.add_argument( + "--member", + type=str, + default="000", + help=( + "Ensemble member to load for ICON baselines: '000' for control, " + "'median' for the pre-computed median, 'mean' to average all " + "members, or any 3-digit member ID. Ignored for runs and INCA." + ), + ) + parser.add_argument( + "--truth", + type=Path, + required=True, + help="Path to the reference zarr dataset.", + ) + parser.add_argument( + "--step", + type=int, + required=True, + help="Forecast lead time in hours (e.g. 24).", + ) + parser.add_argument( + "--param", + type=str, + required=True, + help="Variable to verify (e.g. T_2M, TD_2M, U_10M).", + ) + parser.add_argument( + "--steps", + type=str, + default=None, + help=( + "Forecast step spec 'start/stop/step' (e.g. '0/120/6'). Required for " + "accumulated params (TOT_PREC): the accumulation period is the step " + "spacing, the forecast is accumulated over [step - period, step], and " + "the matching TOT_PREC_H truth variable is read. Ignored for " + "instantaneous params." + ), + ) + parser.add_argument( + "--reftimes", + nargs="+", + default=None, + help=( + "List of init times (YYYYMMDDHHMM). For runs: optional restriction of " + "the discovered init-time directories. For baselines: required; " + "defines the init times to load from the archive." + ), + ) + parser.add_argument( + "--output", + type=Path, + required=True, + help="Output NetCDF file.", + ) + args = parser.parse_args() + + if bool(args.run_root) == bool(args.baseline_root): + parser.error("Exactly one of --run_root or --baseline_root must be provided.") + if args.baseline_root and not args.reftimes: + parser.error( + "--reftimes is required with --baseline_root: init times cannot be " + "discovered from the operational archive." + ) + + main(args) diff --git a/workflow/tools/config.schema.json b/workflow/tools/config.schema.json index d5760f85..8e677694 100644 --- a/workflow/tools/config.schema.json +++ b/workflow/tools/config.schema.json @@ -261,6 +261,10 @@ ], "default": null, "description": "Scorecard generation configuration. Omit or set enabled: false to disable." + }, + "score_maps": { + "$ref": "#/$defs/ScoreMapsConfig", + "description": "Score map plot configuration. Set enabled: true to produce score maps." } }, "required": [ @@ -566,6 +570,94 @@ "title": "Profile", "type": "object" }, + "ScoreMapsConfig": { + "description": "Parameters controlling which score map plots are produced.", + "properties": { + "enabled": { + "default": false, + "description": "Whether to produce score maps (computationally intensive).", + "title": "Enabled", + "type": "boolean" + }, + "params": { + "default": [ + "T_2M" + ], + "description": "List of parameters to plot. Supported values: T_2M, TD_2M, U_10M, V_10M, PS, PMSL, TOT_PREC (native), and SP_10M (derived wind speed from U_10M/V_10M).", + "items": { + "type": "string" + }, + "title": "Params", + "type": "array" + }, + "leadtimes": { + "anyOf": [ + { + "items": { + "type": "integer" + }, + "type": "array" + }, + { + "const": "all", + "type": "string" + } + ], + "default": [ + 6, + 24 + ], + "description": "List of lead times (hours) to plot, or the literal string 'all' to expand to the union of step lists from all configured runs and baselines.", + "title": "Leadtimes" + }, + "scores": { + "default": [ + "BIAS" + ], + "description": "List of verification scores to plot. Supported: BIAS, RMSE, MAE.", + "items": { + "type": "string" + }, + "title": "Scores", + "type": "array" + }, + "regions": { + "default": [ + "switzerland" + ], + "description": "List of regions to plot (e.g. switzerland, centraleurope).", + "items": { + "type": "string" + }, + "title": "Regions", + "type": "array" + }, + "seasons": { + "default": [ + "all" + ], + "description": "List of seasons to plot ('all', 'DJF', 'MAM', 'JJA', 'SON').", + "items": { + "type": "string" + }, + "title": "Seasons", + "type": "array" + }, + "init_hours": { + "default": [ + "all" + ], + "description": "List of initialization hours to plot. Use 'all' for the unstratified view, or zero-padded hour strings like '00', '06', '12', '18'.", + "items": { + "type": "string" + }, + "title": "Init Hours", + "type": "array" + } + }, + "title": "ScoreMapsConfig", + "type": "object" + }, "ScorecardConfig": { "additionalProperties": false, "description": "Configuration for a single named scorecard.",