Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from weathergen.evaluate.plotting.plotter import Plotter
from weathergen.evaluate.plotting.quantile_plots import QuantilePlots
from weathergen.evaluate.plotting.score_cards import ScoreCards
from weathergen.evaluate.plotting.timeseries import Timeseries
from weathergen.evaluate.scores.score import VerifiedData, get_score
from weathergen.evaluate.utils.array_utils import bias_ranges, common_ranges
from weathergen.evaluate.utils.clim_utils import get_climatology, needs_climatology
Expand Down Expand Up @@ -758,6 +759,50 @@ def _dispatch_score_map_animations(
return [p for r in results if r for p in r]


# ---------------------------------------------------------------------------
# Timeseries plots
# ---------------------------------------------------------------------------


def _dispatch_timeseries_plots(
    da_preds: dict,
    da_tars: dict,
    output_dir: str,
    stream: str,
    regions: list[str],
    run_id: str,
    samples: dict,
    channels: dict,
    ensemble: list,
    n_workers: int,
) -> None:
    """Queue one timeseries plot per (channel, sample, ens, region) and run them in parallel.

    Parameters
    ----------
    da_preds, da_tars:
        Prediction and target data handed to ``Timeseries``.
    output_dir:
        Directory forwarded to each plot task.
    stream:
        Stream name forwarded to each plot task and shown in the progress description.
    regions:
        Regions to plot.
    run_id:
        Only used in the progress description.
    samples, channels:
        Iterables of the samples / channels to plot; channels are passed on as ``str``.
    ensemble:
        Ensemble members to plot; ignored when no prediction has an ``ens`` dimension.
    n_workers:
        Worker count forwarded to ``dispatch_parallel``.
    """
    plotter_ts = Timeseries(da_preds, da_tars)

    # Without an "ens" dimension in any prediction there is a single un-labelled member.
    if any("ens" in pred.dims for pred in da_preds.values()):
        members = ensemble
    else:
        members = [None]

    plot_fn = plotter_ts.plot_single_timeseries
    jobs = []
    for channel in channels:
        channel_name = str(channel)
        for sample in samples:
            for member in members:
                for region in regions:
                    jobs.append(
                        delayed(plot_fn)(
                            output_dir=output_dir,
                            channel=channel_name,
                            sample=sample,
                            stream=stream,
                            region=region,
                            ens=member,
                        )
                    )

    dispatch_parallel(
        jobs,
        n_workers=n_workers,
        backend="loky",
        desc=f"Timeseries {run_id} - {stream}",
    )


# ---------------------------------------------------------------------------
# Per-sample map / histogram plots
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -899,7 +944,7 @@ def plot_data(
stream_cfg = reader.get_stream(stream)
plot_settings = stream_cfg.get("plotting", {})

plot_keys = ("plot_maps", "plot_histograms", "plot_animations")
plot_keys = ("plot_maps", "plot_histograms", "plot_animations", "plot_timeseries")
if not plot_settings or not any(plot_settings.get(k, False) for k in plot_keys):
return

Expand Down Expand Up @@ -933,6 +978,9 @@ def plot_data(
plot_target = plot_settings.get("plot_target", True)
if not isinstance(plot_target, bool):
raise TypeError("plot_target must be a boolean.")
plot_timeseries = plot_settings.get("plot_timeseries", False)
if not isinstance(plot_timeseries, bool):
raise TypeError("plot_timeseries must be a boolean.")
plot_histograms = plot_settings.get("plot_histograms", False)
if not isinstance(plot_histograms, bool) and plot_histograms not in {
"across-samples",
Expand Down Expand Up @@ -1080,6 +1128,20 @@ def plot_data(
desc=f"Across-samples plots {run_id} - {stream}",
)

if plot_timeseries:
_dispatch_timeseries_plots(
da_preds=da_preds,
da_tars=da_tars,
output_dir=output_dir,
stream=stream,
regions=plotter.regions,
run_id=run_id,
samples=plot_sample_set,
channels=plot_channel_set,
ensemble=list(available_data.ensemble),
n_workers=num_plot_workers,
)

if plot_animations:
last_fstep = list(da_tars.keys())[-1]
last_preds = da_preds[last_fstep]
Expand Down
108 changes: 108 additions & 0 deletions packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import xarray as xr

from weathergen.evaluate.utils.regions import RegionBoundingBox


class Timeseries:
    """
    Build and plot point-averaged timeseries of predictions against targets.

    Parameters
    ----------
    da_preds:
        Dictionary of prediction datasets.
    da_tars:
        Dictionary of target datasets.

    Notes
    -----
    Entries of ``da_preds`` and ``da_tars`` are paired by position (iteration order),
    and each pair contributes one step along ``forecast_step``.
    """

    def __init__(self, da_preds: dict[str, xr.Dataset], da_tars: dict[str, xr.Dataset]):
        self.da_preds = da_preds
        self.da_tars = da_tars

    def get_preds_tars_per_region_sample_channel(
        self, region: str, sample: int | str, channel: str
    ) -> tuple[xr.Dataset, xr.Dataset]:
        """Get point-averaged preds/tars for the given region/sample/channel.

        For every (prediction, target) pair the sample and channel are selected,
        the region mask is applied (unless ``region`` is ``"global"``), and the
        result is averaged over ``ipoint``. The per-step results are concatenated
        along ``forecast_step``.

        Parameters
        ----------
        region: str
            The region for which to extract data. ``"global"`` skips masking.
        sample: int | str
            The sample for which to extract data.
        channel: str
            The channel for which to extract data.

        Returns
        -------
        tuple[xr.Dataset, xr.Dataset]
            The prediction and target timeseries for the given region, sample
            and channel, each carrying a ``valid_time`` coordinate.
        """
        # The bounding box depends only on the region, so resolve it once
        # instead of once per forecast step.
        bbox = None if region == "global" else RegionBoundingBox.from_region_name(region)

        preds_steps, tars_steps = [], []
        for da_p, da_t in zip(self.da_preds.values(), self.da_tars.values(), strict=False):
            # Select sample and channel first so lat/lon become 1D (ipoint,)
            da_p = da_p.sel(sample=sample, channel=channel)
            da_t = da_t.sel(sample=sample, channel=channel)
            if bbox is not None:
                da_p = bbox.apply_mask(da_p)
                da_t = bbox.apply_mask(da_t)
            # The valid time of the first prediction point labels the whole step;
            # the same value is attached to the target so both share one time axis.
            vt = da_p.valid_time.isel(ipoint=0).drop_vars("ipoint")
            preds_steps.append(da_p.mean(dim="ipoint").assign_coords(valid_time=vt))
            tars_steps.append(da_t.mean(dim="ipoint").assign_coords(valid_time=vt))

        da_preds_ts = xr.concat(preds_steps, dim="forecast_step")
        da_tars_ts = xr.concat(tars_steps, dim="forecast_step")
        return da_preds_ts, da_tars_ts

    def plot_single_timeseries(
        self,
        output_dir: str,
        channel: str,
        sample: int | str,
        stream: str,
        region: str,
        ens: str | int | None = None,
    ) -> None:
        """Plot and save a timeseries figure for one (channel, sample, region[, ens]) combination.

        The figure is written to
        ``<output_dir>/plots/<stream>/timeseries/timeseries_<region>_<channel>_sample_<sample>[_ens_<ens>].png``.

        Parameters
        ----------
        output_dir: str
            Root directory under which the plot is stored.
        channel: str
            Channel to plot; also used as the y-axis label.
        sample: int | str
            Sample to plot.
        stream: str
            Stream name; used in the output path and as the label of the target line.
        region: str
            Region to average over.
        ens: str | int | None
            Ensemble member to select. ``None`` and ``"mean"`` perform no selection.
        """
        # Imported explicitly (and aliased, so the module-level ``matplotlib`` name
        # is not shadowed) rather than relying on another module having loaded it.
        import matplotlib.dates as mdates

        da_preds_ts, da_tars_ts = self.get_preds_tars_per_region_sample_channel(
            region, sample, channel
        )
        has_ens = ens is not None and "ens" in da_preds_ts.dims and ens != "mean"
        if has_ens:
            da_preds_ts = da_preds_ts.sel(ens=ens)
        # NOTE(review): for ens=None or ens="mean" an existing "ens" dimension is kept
        # as is, so every member is drawn as its own line - confirm this is intended.
        valid_times = da_tars_ts.valid_time.values

        pred_label = "Prediction" if not has_ens else f"Prediction (ens {ens})"

        matplotlib.use("Agg")
        fig, ax = plt.subplots(figsize=(15, 7))
        try:
            ax.plot(valid_times, da_preds_ts.values, label=pred_label)
            ax.plot(valid_times, da_tars_ts.values, label=stream, linestyle="--")
            fig.suptitle(
                f"Timeseries Average - {region.capitalize()}",
                fontsize=13,
                fontweight="bold",
            )
            ax.set_ylabel(channel)
            ax.set_xlabel("Valid Time")
            ax.legend()
            # Widen the tick spacing as the series grows to keep labels readable.
            max_ticks = 20
            day_interval = max(1, len(valid_times) // max_ticks)
            ax.xaxis.set_major_locator(mdates.DayLocator(interval=day_interval))
            ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
            ax.grid(True, linestyle="--", alpha=0.5)
            fig.autofmt_xdate()
            out_path = Path(output_dir) / "plots" / stream / "timeseries"
            out_path.mkdir(parents=True, exist_ok=True)
            fname = f"timeseries_{region}_{channel}_sample_{sample}"
            if has_ens:
                fname += f"_ens_{ens}"
            fig.savefig(out_path / f"{fname}.png", bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
Original file line number Diff line number Diff line change
Expand Up @@ -268,5 +268,5 @@ def needs_climatology(metrics_dict: dict) -> bool:
True if any metric requires climatology, False otherwise
"""
metrics = [m for metrics in metrics_dict.values() for m in metrics.keys()]
req_clim = ["acc", "rps", "rpss"]
req_clim = ["acc", "rps", "rpss"]
return any(m in req_clim for m in metrics)
Loading