def _dispatch_timeseries_plots(
    da_preds: dict,
    da_tars: dict,
    output_dir: str,
    stream: str,
    regions: list[str],
    run_id: str,
    samples: dict,
    channels: dict,
    ensemble: list,
    n_workers: int,
) -> None:
    """Queue one timeseries plot per (channel, sample, ens, region) combination.

    A single ``Timeseries`` object is built from ``da_preds`` / ``da_tars`` and
    its ``plot_single_timeseries`` method is wrapped in ``delayed`` once for
    every combination. The resulting calls are handed to ``dispatch_parallel``.

    Parameters
    ----------
    da_preds, da_tars:
        Prediction and target data passed unchanged to ``Timeseries``.
    output_dir:
        Forwarded to every plot call as ``output_dir``.
    stream:
        Stream name; forwarded to every plot call and used in the progress
        description.
    regions:
        Regions to plot; innermost loop variable.
    run_id:
        Only used in the progress description.
    samples, channels:
        Iterated to obtain the sample and channel of each plot. Channels are
        converted with ``str`` before being forwarded.
    ensemble:
        Ensemble members to plot. Ignored (replaced by a single ``None``) when
        none of the prediction entries has an ``"ens"`` dimension.
    n_workers:
        Forwarded to ``dispatch_parallel``.
    """
    ts_plotter = Timeseries(da_preds, da_tars)

    # Only iterate over ensemble members if at least one prediction entry
    # actually carries an "ens" dimension; otherwise plot once with ens=None.
    ens_present = False
    for pred in da_preds.values():
        if "ens" in pred.dims:
            ens_present = True
            break
    members = ensemble if ens_present else [None]

    jobs = []
    for channel in channels:
        for sample in samples:
            for member in members:
                for region in regions:
                    jobs.append(
                        delayed(ts_plotter.plot_single_timeseries)(
                            output_dir=output_dir,
                            channel=str(channel),
                            sample=sample,
                            stream=stream,
                            region=region,
                            ens=member,
                        )
                    )

    dispatch_parallel(
        jobs,
        n_workers=n_workers,
        backend="loky",
        desc=f"Timeseries {run_id} - {stream}",
    )
class Timeseries:
    """
    Build and plot spatially averaged prediction/target timeseries.

    Parameters
    ----------
    da_preds:
        Dictionary of prediction data. Its values are iterated in order and
        concatenated along a ``forecast_step`` dimension.
    da_tars:
        Dictionary of target data, iterated in lockstep with ``da_preds``.
    """

    # NOTE(review): the values are annotated as xr.Dataset, but
    # plot_single_timeseries passes ``.values`` straight to ``ax.plot``, which
    # suggests they are really xr.DataArray objects — confirm and tighten the
    # annotations.
    def __init__(self, da_preds: dict[str, xr.Dataset], da_tars: dict[str, xr.Dataset]):
        self.da_preds = da_preds
        self.da_tars = da_tars

    def get_preds_tars_per_region_sample_channel(
        self, region: str, sample: int | str, channel: str
    ) -> tuple[xr.Dataset, xr.Dataset]:
        """Get spatially averaged preds/tars for one region, sample and channel.

        For every entry of ``da_preds`` / ``da_tars`` the requested sample and
        channel are selected, the region mask is applied (unless the region is
        ``"global"``), and the mean over ``ipoint`` is taken. The per-entry
        results are concatenated along ``forecast_step``.

        Parameters
        ----------
        region: str
            The region for which to extract data. ``"global"`` skips masking.
        sample: int | str
            The sample for which to extract data.
        channel: str
            The channel for which to extract data.

        Returns
        -------
        tuple[xr.Dataset, xr.Dataset]
            The prediction and target series for the given sample and channel.
            Both carry a ``valid_time`` coordinate taken from the first point
            of the (masked) prediction data.
        """
        # The bounding box depends only on the region, so resolve it once
        # instead of once per forecast step.
        bbox = None if region == "global" else RegionBoundingBox.from_region_name(region)

        preds_steps, tars_steps = [], []
        for da_p, da_t in zip(self.da_preds.values(), self.da_tars.values(), strict=False):
            # Select sample and channel first so lat/lon become 1D (ipoint,)
            da_p = da_p.sel(sample=sample, channel=channel)
            da_t = da_t.sel(sample=sample, channel=channel)
            if bbox is not None:
                da_p = bbox.apply_mask(da_p)
                da_t = bbox.apply_mask(da_t)
            # valid_time is read from the first point of the predictions and
            # attached to both series after averaging.
            vt = da_p.valid_time.isel(ipoint=0).drop_vars("ipoint")
            preds_steps.append(da_p.mean(dim="ipoint").assign_coords(valid_time=vt))
            tars_steps.append(da_t.mean(dim="ipoint").assign_coords(valid_time=vt))

        da_preds_ts = xr.concat(preds_steps, dim="forecast_step")
        da_tars_ts = xr.concat(tars_steps, dim="forecast_step")
        return da_preds_ts, da_tars_ts

    def plot_single_timeseries(
        self,
        output_dir: str,
        channel: str,
        sample: int | str,
        stream: str,
        region: str,
        ens: str | int | None = None,
    ) -> None:
        """Plot and save a timeseries figure for one (channel, sample, region[, ens]).

        The figure is written to
        ``<output_dir>/plots/<stream>/timeseries/timeseries_<region>_<channel>_sample_<sample>.png``,
        with ``_ens_<ens>`` appended to the file stem when an ensemble member
        is selected.

        Parameters
        ----------
        output_dir: str
            Base directory; the ``plots/<stream>/timeseries`` sub-directory is
            created if missing.
        channel: str
            Channel to plot; also used as the y-axis label.
        sample: int | str
            Sample to plot.
        stream: str
            Stream name; used in the output path and as the legend label of
            the target line.
        region: str
            Region to average over; shown capitalised in the figure title.
        ens: str | int | None
            Ensemble member to select. No member is selected when this is
            ``None`` or ``"mean"``, or when the prediction series has no
            ``"ens"`` dimension.
        """
        da_preds_ts, da_tars_ts = self.get_preds_tars_per_region_sample_channel(
            region, sample, channel
        )
        has_ens = ens is not None and "ens" in da_preds_ts.dims and ens != "mean"
        if has_ens:
            da_preds_ts = da_preds_ts.sel(ens=ens)
        valid_times = da_tars_ts.valid_time.values

        pred_label = "Prediction" if not has_ens else f"Prediction (ens {ens})"

        # Non-interactive backend so figures can be rendered without a display.
        matplotlib.use("Agg")
        fig, ax = plt.subplots(figsize=(15, 7))
        try:
            ax.plot(valid_times, da_preds_ts.values, label=pred_label)
            ax.plot(valid_times, da_tars_ts.values, label=stream, linestyle="--")
            fig.suptitle(
                f"Timeseries Average - {region.capitalize()}",
                fontsize=13,
                fontweight="bold",
            )
            ax.set_ylabel(channel)
            ax.set_xlabel("Valid Time")
            ax.legend()
            # Widen the day spacing of the major ticks as the series gets
            # longer (roughly max_ticks labels are targeted).
            max_ticks = 20
            day_interval = max(1, len(valid_times) // max_ticks)
            ax.xaxis.set_major_locator(matplotlib.dates.DayLocator(interval=day_interval))
            ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d"))
            ax.grid(True, linestyle="--", alpha=0.5)
            fig.autofmt_xdate()
            out_path = Path(output_dir) / "plots" / stream / "timeseries"
            out_path.mkdir(parents=True, exist_ok=True)
            fname = f"timeseries_{region}_{channel}_sample_{sample}"
            if has_ens:
                fname += f"_ens_{ens}"
            fig.savefig(out_path / f"{fname}.png", bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving fails, so
            # repeated calls in the same process do not accumulate open figures.
            plt.close(fig)