Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/evaluate/eval_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ evaluation:
plot_ensemble: "members" #supported: false, "std", "minmax", "members"
plot_score_maps: false #plot scores on 2D maps. it slows down score computation
plot_score_animations: false #plot animations of score maps across forecast steps. it slows down score computation
plot_score_timeseries: true #plot timeseries of scores across forecast steps.
plot_score_init_time_series: true #plot timeseries of scores across init hours (source window end hour) for each forecast step.
print_summary: false #print out score values on screen. it can be verbose
log_scale: false
add_grid: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,18 +165,6 @@ def run_score_timeseries_pipeline(
f"{len(fsteps)} fsteps × {len(unique_hours)} init hours."
)

# --- Parallel plotting ---
_plot_timeseries_parallel(
reader,
stream,
scores_by_hour,
unique_hours,
fsteps,
da_tars,
global_plotting_options,
n_workers,
)

return scores_by_hour


Expand Down Expand Up @@ -213,119 +201,6 @@ def _compute_timeseries_scores_for_fstep(
return fstep, region, metric_scores


def _plot_single_timeseries(
    output_dir: str,
    run_id: str,
    metric_name: str,
    region: str,
    channel: str | None,
    fstep: int,
    lt_label: str,
    score: xr.DataArray,
    unique_hours: list[int],
    image_format: str,
    dpi_val: int,
) -> None:
    """Plot a single timeseries figure (parallelisable worker).

    Draws the score values against their ``source_end_hour`` coordinate and
    writes the figure to
    ``<output_dir>/<metric>_<channel>_<region>_lead_<lt_label>_by_source_end_hour.<ext>``.

    Parameters
    ----------
    output_dir
        Directory the figure is written into.
    run_id
        Used as the legend label of the plotted line.
    metric_name
        Metric name; upper-cased for the title and the y-axis label.
    region
        Region name, used in the title and the file name.
    channel
        Channel to select from ``score``. ``None`` plots ``score`` unchanged
        and labels it ``"all"``.
    fstep
        Forecast step of the score. Not used in the body; kept so the
        function can be called with the full task dictionary.
    lt_label
        Human-readable lead-time label, used in the title and the file name.
    score
        Score values carrying a ``source_end_hour`` coordinate.
    unique_hours
        Hours used for the x-axis limits and ticks. When empty, the default
        axis limits and ticks are kept.
    image_format
        File extension of the saved figure.
    dpi_val
        Resolution of the figure.
    """
    # Non-interactive backend: the figure is only ever written to disk.
    matplotlib.use("Agg")

    score_vals = score.sel(channel=channel) if channel is not None else score
    hours = score_vals.coords["source_end_hour"].values
    values = score_vals.values.flatten()

    ch_label = channel if channel is not None else "all"
    plot_path = Path(output_dir) / (
        f"{metric_name}_{ch_label}_{region}_lead_{lt_label}"
        f"_by_source_end_hour.{image_format}"
    )

    fig = plt.figure(figsize=(10, 6), dpi=dpi_val)
    try:
        plt.plot(hours, values, marker="o", linewidth=2, label=run_id)

        title = f"{metric_name.upper()} vs source end hour | {ch_label} | {lt_label} | {region}"
        plt.title(title)
        plt.xlabel("Source window end hour [UTC]")
        plt.ylabel(metric_name.upper())
        # min()/max() raise ValueError on an empty sequence, so only pin the
        # axis when there is at least one hour to show.
        if len(unique_hours) > 0:
            plt.xlim(min(unique_hours) - 0.5, max(unique_hours) + 0.5)
            plt.xticks(unique_hours)
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()

        plt.savefig(plot_path, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails, so
        # figures do not pile up in the process that runs this worker.
        plt.close(fig)


def _plot_timeseries_parallel(
    reader: Reader,
    stream: str,
    scores_by_hour: dict[str, dict[str, dict[int, xr.DataArray]]],
    unique_hours: list[int],
    fsteps: list[int],
    da_tars: dict[int, xr.DataArray],
    global_plotting_options: dict | None,
    n_workers: int | None = None,
) -> None:
    """Dispatch timeseries plotting tasks in parallel.

    Builds one plotting task per (metric, region, channel, forecast step)
    and hands them to ``dispatch_parallel``. Figures are written below
    ``reader.runplot_dir / "plots" / stream / "score_timeseries"``.

    Parameters
    ----------
    reader
        Provides ``runplot_dir`` (output root) and ``run_id``.
    stream
        Stream name, used in the output path and in log messages.
    scores_by_hour
        Scores nested as ``metric -> region -> fstep -> score``.
    unique_hours
        Hours forwarded to every plotting task.
    fsteps
        Forecast steps for which a lead-time label is built.
    da_tars
        Arrays indexed by forecast step; an optional ``lead_time``
        coordinate is read from them to label the plots.
    global_plotting_options
        Optional plotting options; ``image_format`` (default ``"png"``) and
        ``dpi_val`` (default ``150``) are read from it.
    n_workers
        Forwarded to ``dispatch_parallel``.
    """

    output_dir = reader.runplot_dir / "plots" / stream / "score_timeseries"
    output_dir.mkdir(parents=True, exist_ok=True)

    run_id = reader.run_id
    plot_cfg = global_plotting_options or {}
    image_format = plot_cfg.get("image_format", "png")
    dpi_val = plot_cfg.get("dpi_val", 150)

    # Build fstep → lead_time label mapping
    lead_time_by_fstep: dict[int, str] = {}
    for fstep in fsteps:
        da = da_tars[fstep]
        if "lead_time" in da.coords:
            # NOTE(review): int() needs a single value here, so this assumes a
            # scalar lead_time coordinate per forecast step — confirm.
            lt = da.coords["lead_time"].values
            hours = int(lt.astype("timedelta64[h]").astype(int))
            lead_time_by_fstep[fstep] = f"lead time {hours}h"
        else:
            lead_time_by_fstep[fstep] = f"fstep {fstep}"

    # Build plot tasks
    plot_tasks: list[dict] = []
    for metric_name, region_dict in scores_by_hour.items():
        for region, fstep_dict in region_dict.items():
            if not fstep_dict:
                # Nothing to sample the channel list from; skip instead of
                # failing with StopIteration on next(iter(...)).
                _logger.warning(
                    f"RUN {run_id} - {stream}: No timeseries scores for metric "
                    f"{metric_name} in region {region}; skipping plots."
                )
                continue

            # The channel list is taken from the first score of this region
            # and applied to all of its forecast steps.
            sample_score = next(iter(fstep_dict.values()))
            channels = (
                list(sample_score.coords["channel"].values)
                if "channel" in sample_score.dims
                else [None]
            )
            for channel in channels:
                for fstep, score in fstep_dict.items():
                    plot_tasks.append(
                        {
                            "output_dir": str(output_dir),
                            "run_id": run_id,
                            "metric_name": metric_name,
                            "region": region,
                            "channel": channel,
                            "fstep": fstep,
                            "lt_label": lead_time_by_fstep[fstep],
                            "score": score,
                            "unique_hours": unique_hours,
                            "image_format": image_format,
                            "dpi_val": dpi_val,
                        }
                    )

    if not plot_tasks:
        # Avoid dispatching an empty job list and reporting plots as saved
        # when none were produced.
        _logger.info(f"RUN {run_id} - {stream}: No score timeseries figures to plot.")
        return

    _logger.info(
        f"RUN {run_id} - {stream}: Plotting {len(plot_tasks)} timeseries figures "
        f"with up to {n_workers} worker(s)."
    )

    calls = [delayed(_plot_single_timeseries)(**t) for t in plot_tasks]
    dispatch_parallel(calls, n_workers=n_workers, backend="loky", desc=f"Timeseries plots {stream}")

    _logger.info(f"RUN {run_id} - {stream}: Score timeseries plots saved to {output_dir}.")


# ---------------------------------------------------------------------------
# Score maps
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1144,7 +1019,7 @@ def plot_timeseries_summary(
image_format = plt_opt.get("image_format", "png")
dpi_val = plt_opt.get("dpi_val", 150)

ts_dir = summary_dir / "score_timeseries"
ts_dir = summary_dir / "score_init_time_series"
ts_dir.mkdir(parents=True, exist_ok=True)

for metric_name, region_dict in timeseries_scores.items():
Expand Down Expand Up @@ -1176,7 +1051,8 @@ def plot_timeseries_summary(
)
hours = score_vals.coords["source_end_hour"].values
values = score_vals.values.flatten()
label = runs[run_id].get("label", run_id)
run_label = runs[run_id].get("label", run_id)
label = f"{run_label} ({run_id})"
color = runs[run_id].get("color", None)
plt.plot(
hours,
Expand All @@ -1202,8 +1078,10 @@ def plot_timeseries_summary(
plt.legend()
plt.tight_layout()

run_ids_str = "_".join(sorted(run_dict.keys()))
plot_path = (
ts_dir / f"{metric_name}_{ch_label}_{region}_{stream}"
f"_{run_ids_str}"
f"_fstep_{fstep}_by_source_end_hour.{image_format}"
)
plt.savefig(plot_path, bbox_inches="tight")
Expand Down
10 changes: 5 additions & 5 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,16 +248,16 @@ def _process_stream(
return run_id, stream, {}, {}

plot_score_maps = plot_score_options.get("plot_score_maps", False) and type_ == "zarr"
plot_score_timeseries = (
plot_score_options.get("plot_score_timeseries", False) and type_ == "zarr"
plot_score_init_time_series = (
plot_score_options.get("plot_score_init_time_series", False) and type_ == "zarr"
)

stream_loaded_scores, recomputable_metrics = reader.load_scores(stream, regions, metrics)
scores_dict = stream_loaded_scores
if recomputable_metrics:
metrics_to_compute = recomputable_metrics
regions_to_compute = list(set(recomputable_metrics.keys()))
elif plot_score_maps or plot_score_timeseries:
elif plot_score_maps or plot_score_init_time_series:
metrics_to_compute = {r: metrics for r in regions}
regions_to_compute = regions
else:
Expand All @@ -284,7 +284,7 @@ def _process_stream(
plot_score_options=plot_score_options,
)

if plot_score_timeseries:
if plot_score_init_time_series:
ts_scores = run_score_timeseries_pipeline(
reader,
stream,
Expand Down Expand Up @@ -320,7 +320,7 @@ def evaluate_from_config(cfg: dict, mlflow_client: MlflowClient | None) -> None:
plot_score_options = {
"plot_score_maps": cfg.evaluation.get("plot_score_maps", False),
"plot_score_animations": cfg.evaluation.get("plot_score_animations", False),
"plot_score_timeseries": cfg.evaluation.get("plot_score_timeseries", False),
"plot_score_init_time_series": cfg.evaluation.get("plot_score_init_time_series", False),
}

global_plotting_opts = cfg.get("global_plotting_options", {})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,5 +268,5 @@ def needs_climatology(metrics_dict: dict) -> bool:
True if any metric requires climatology, False otherwise
"""
metrics = [m for metrics in metrics_dict.values() for m in metrics.keys()]
req_clim = ["acc", "rps", "rpss"]
req_clim = ["acc", "rps", "rpss"]
return any(m in req_clim for m in metrics)
Loading