From c22529da3df362e61c6f42557e46367e4380438c Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Tue, 2 Jun 2026 14:14:06 +0200 Subject: [PATCH 1/6] Add timeseries class, edit plot_orchestration --- .../evaluate/plotting/plot_orchestration.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 2a44c1993..f8b49d0b4 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -21,6 +21,7 @@ from joblib import delayed from PIL import Image from tqdm import tqdm +from weathergen.evaluate.plotting.timeseries import Timeseries from weathergen.evaluate.io.data.io_orchestration import dispatch_parallel, get_num_workers from weathergen.evaluate.io.io_reader import Reader, ReaderOutput @@ -758,6 +759,40 @@ def _dispatch_score_map_animations( return [p for r in results if r for p in r] +# --------------------------------------------------------------------------- +# Timeseries plots +# --------------------------------------------------------------------------- + + +def _dispatch_timeseries_plots( + da_preds: dict, + da_tars: dict, + output_dir: str, + stream: str, + run_id: str, + num_plot_workers: int, +) -> None: + """Build and dispatch timeseries plot tasks for all (channel, sample) pairs.""" + data_ts = Timeseries(da_preds, da_tars) + ts_tasks = [ + { + "output_dir": output_dir, + "channel": str(channel), + "sample": sample, + "stream": stream, + } + for channel in data_ts.get_channels() + for sample in data_ts.get_samples() + ] + calls = [delayed(data_ts.plot_single_timeseries)(**t) for t in ts_tasks] + dispatch_parallel( + calls, + n_workers=num_plot_workers, + backend="loky", + desc=f"Timeseries {run_id} - {stream}", + ) + + # --------------------------------------------------------------------------- # Per-sample map / histogram plots # --------------------------------------------------------------------------- @@ -899,7 +934,7 @@ def plot_data( stream_cfg = reader.get_stream(stream) plot_settings = stream_cfg.get("plotting", {}) - plot_keys = ("plot_maps", "plot_histograms", "plot_animations") + plot_keys = ("plot_maps", "plot_histograms", "plot_animations", "plot_timeseries") if not plot_settings or not any(plot_settings.get(k, False) for k in plot_keys): return @@ -933,6 +968,9 @@ def plot_data( plot_target = plot_settings.get("plot_target", True) if not isinstance(plot_target, bool): raise TypeError("plot_target must be a boolean.") + plot_timeseries = plot_settings.get("plot_timeseries", False) + if not isinstance(plot_timeseries, bool): + raise TypeError("plot_timeseries must be a boolean.") plot_histograms = plot_settings.get("plot_histograms", False) if not isinstance(plot_histograms, bool) and plot_histograms not in { "across-samples", @@ -1080,6 +1118,16 @@ def plot_data( desc=f"Across-samples plots {run_id} - {stream}", ) + if plot_timeseries: + _dispatch_timeseries_plots( + da_preds=da_preds, + da_tars=da_tars, + output_dir=output_dir, + stream=stream, + run_id=run_id, + num_plot_workers=num_plot_workers, + ) + if plot_animations: last_fstep = list(da_tars.keys())[-1] last_preds = da_preds[last_fstep] From 0a89b01f48e18aa1b183218138a734f6e0a79ffb Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Tue, 2 Jun 2026 14:25:36 +0200 Subject: [PATCH 2/6] add timeseries.py --- .../evaluate/plotting/timeseries.py | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py new file mode 100644 index 000000000..388fefd71 --- /dev/null +++ b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py @@ -0,0 +1,132 @@ +from pathlib import Path + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import xarray as xr + + +class Timeseries: + """ + Initialize the Timeseries class. + + Parameters + ---------- + da_preds: + Dictionary of prediction datasets. + da_tars: + Dictionary of target datasets. + """ + + def __init__(self, da_preds: dict[str, xr.Dataset], da_tars: dict[str, xr.Dataset]): + self.da_preds = da_preds + self.da_tars = da_tars + + preds_steps, tars_steps = [], [] + for da_p, da_t in zip(self.da_preds.values(), self.da_tars.values(), strict=False): + vt = da_p.valid_time.isel(ipoint=0).drop_vars("ipoint") + preds_steps.append(da_p.mean(dim="ipoint").assign_coords(valid_time=vt)) + tars_steps.append(da_t.mean(dim="ipoint").assign_coords(valid_time=vt)) + self.da_preds_ts, self.da_tars_ts = ( + xr.concat(preds_steps, dim="forecast_step"), + xr.concat(tars_steps, dim="forecast_step"), + ) + + def get_preds_tars_per_sample_channel( + self, sample: int | str, channel: str + ) -> tuple[xr.Dataset, xr.Dataset]: + """Get preds/tars for the given sample/channel from the timeseries data. + Parameters + ---------- + sample: int | str + The sample for which to extract data. + channel: str + The channel for which to extract data. + + Returns + ------- + tuple[xr.Dataset, xr.Dataset] + The prediction and target datasets for the given sample and channel. + """ + da_preds_slice, da_tars_slice = ( + self.da_preds_ts.sel(sample=sample, channel=channel), + self.da_tars_ts.sel(sample=sample, channel=channel), + ) + return da_preds_slice, da_tars_slice + + def get_valid_times_per_sample_channel( + self, sample: int | str, channel: str + ) -> npt.NDArray[np.datetime64]: + """Get valid times for the given sample/channel from the timeseries data. + Parameters + ---------- + sample: int | str + The sample for which to extract data. + channel: str + The channel for which to extract data. + + Returns + ------- + npt.NDArray[np.datetime64] + The array of valid times for the given sample and channel. + """ + return self.da_tars_ts.sel(sample=sample, channel=channel).valid_time.values + + def get_channels(self) -> list[str]: + """Get the list of channels from the timeseries data.""" + da_tmp = next(iter(self.da_tars.values())) + return da_tmp.channel.values + + def get_samples(self) -> list[str]: + """Get the list of samples from the timeseries data.""" + da_tmp = next(iter(self.da_tars.values())) + return da_tmp.sample.values + + def plot_single_timeseries( + self, output_dir: str, channel: str, sample: int | str, stream: str + ) -> None: + """Plot and save a timeseries figure for one (channel, sample) pair.""" + da_preds_slice, da_tars_slice = self.get_preds_tars_per_sample_channel(sample, channel) + valid_times = self.get_valid_times_per_sample_channel(sample, channel) + + matplotlib.use("Agg") + fig, ax = plt.subplots(figsize=(15, 7)) + ax.plot(valid_times, da_preds_slice.values, label="Prediction") + ax.plot(valid_times, da_tars_slice.values, label=stream, linestyle="--") + fig.suptitle( + f"Timeseries Average \u2013 {self.region_label()}", + fontsize=13, + fontweight="bold", + ) + ax.set_ylabel(channel) + ax.set_xlabel("Valid Time") + ax.legend() + max_ticks = max(4, 15) + day_interval = max(1, len(valid_times) // max_ticks) + ax.xaxis.set_major_locator(matplotlib.dates.DayLocator(interval=day_interval)) + ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d")) + ax.grid(True, linestyle="--", alpha=0.5) + fig.autofmt_xdate() + out_path = Path(output_dir) / "plots" / stream / "timeseries" + out_path.mkdir(parents=True, exist_ok=True) + fig.savefig( + out_path / f"timeseries_{channel}_sample_{sample}.png", + bbox_inches="tight", + ) + plt.close(fig) + + def region_label(self) -> str: + """Get a human-readable label for the region based on the lat/lon bounds.""" + + _first_da = next(iter(self.da_tars.values())) + lat_min = float(_first_da.lat.min()) + lat_max = float(_first_da.lat.max()) + lon_min = float(_first_da.lon.min()) + lon_max = float(_first_da.lon.max()) + region_label = ( + f"({lat_min:.2f}\u00b0 \u2013 {lat_max:.2f}\u00b0 N, " + f"{lon_min:.2f}\u00b0 \u2013 {lon_max:.2f}\u00b0 E)" + ) + + return region_label From 28cb87bb1480d60f4fb273a84cb2e9edb8367d69 Mon Sep 17 00:00:00 2001 From: Savvas Melidonis Date: Tue, 2 Jun 2026 15:22:34 +0200 Subject: [PATCH 3/6] Take care of ensembles --- .../evaluate/plotting/plot_orchestration.py | 8 ++++++- .../evaluate/plotting/timeseries.py | 24 +++++++++++++------ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index f8b49d0b4..5852b212f 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -771,18 +771,23 @@ def _dispatch_timeseries_plots( stream: str, run_id: str, num_plot_workers: int, + ensemble: list, ) -> None: - """Build and dispatch timeseries plot tasks for all (channel, sample) pairs.""" + """Build and dispatch timeseries plot tasks for all (channel, sample[, ens]) triples.""" data_ts = Timeseries(da_preds, da_tars) + has_ens = any("ens" in v.dims for v in da_preds.values()) + ens_members = ensemble if has_ens else [None] ts_tasks = [ { "output_dir": output_dir, "channel": str(channel), "sample": sample, "stream": stream, + "ens": ens, } for channel in data_ts.get_channels() for sample in data_ts.get_samples() + for ens in ens_members ] calls = [delayed(data_ts.plot_single_timeseries)(**t) for t in ts_tasks] dispatch_parallel( @@ -1126,6 +1131,7 @@ def plot_data( stream=stream, run_id=run_id, num_plot_workers=num_plot_workers, + ensemble=list(available_data.ensemble), ) if plot_animations: diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py index 388fefd71..a96f046bb 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py @@ -84,15 +84,25 @@ def get_samples(self) -> list[str]: return da_tmp.sample.values def plot_single_timeseries( - self, output_dir: str, channel: str, sample: int | str, stream: str + self, + output_dir: str, + channel: str, + sample: int | str, + stream: str, + ens: str | int | None = None, ) -> None: - """Plot and save a timeseries figure for one (channel, sample) pair.""" + """Plot and save a timeseries figure for one (channel, sample[, ens]) triple.""" da_preds_slice, da_tars_slice = self.get_preds_tars_per_sample_channel(sample, channel) + has_ens = ens is not None and "ens" in da_preds_slice.dims and ens != "mean" + if has_ens: + da_preds_slice = da_preds_slice.sel(ens=ens) valid_times = self.get_valid_times_per_sample_channel(sample, channel) + pred_label = "Prediction" if not has_ens else f"Prediction (ens {ens})" + matplotlib.use("Agg") fig, ax = plt.subplots(figsize=(15, 7)) - ax.plot(valid_times, da_preds_slice.values, label="Prediction") + ax.plot(valid_times, da_preds_slice.values, label=pred_label) ax.plot(valid_times, da_tars_slice.values, label=stream, linestyle="--") fig.suptitle( f"Timeseries Average \u2013 {self.region_label()}", @@ -110,10 +120,10 @@ def plot_single_timeseries( fig.autofmt_xdate() out_path = Path(output_dir) / "plots" / stream / "timeseries" out_path.mkdir(parents=True, exist_ok=True) - fig.savefig( - out_path / f"timeseries_{channel}_sample_{sample}.png", - bbox_inches="tight", - ) + fname = f"timeseries_{channel}_sample_{sample}" + if has_ens: + fname += f"_ens_{ens}" + fig.savefig(out_path / f"{fname}.png", bbox_inches="tight") plt.close(fig) def region_label(self) -> str: From 8245ce47166de7a2936a02ad5becd9e7cf95b6a2 Mon Sep 17 00:00:00 2001 From: melidonis1 Date: Thu, 11 Jun 2026 11:57:28 +0200 Subject: [PATCH 4/6] Changes to reply to reviews. Remove functions for retrieving channels and samples --- .../evaluate/plotting/plot_orchestration.py | 14 +++++--- .../evaluate/plotting/timeseries.py | 34 ++----------------- 2 files changed, 11 insertions(+), 37 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 5852b212f..6a12e2536 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -770,8 +770,10 @@ def _dispatch_timeseries_plots( output_dir: str, stream: str, run_id: str, - num_plot_workers: int, + samples: dict, + channels: dict, ensemble: list, + n_workers: int, ) -> None: """Build and dispatch timeseries plot tasks for all (channel, sample[, ens]) triples.""" data_ts = Timeseries(da_preds, da_tars) @@ -785,14 +787,14 @@ def _dispatch_timeseries_plots( "stream": stream, "ens": ens, } - for channel in data_ts.get_channels() - for sample in data_ts.get_samples() + for channel in channels + for sample in samples for ens in ens_members ] calls = [delayed(data_ts.plot_single_timeseries)(**t) for t in ts_tasks] dispatch_parallel( calls, - n_workers=num_plot_workers, + n_workers=n_workers, backend="loky", desc=f"Timeseries {run_id} - {stream}", ) @@ -1130,8 +1132,10 @@ def plot_data( output_dir=output_dir, stream=stream, run_id=run_id, - num_plot_workers=num_plot_workers, + samples=plot_sample_set, + channels=plot_channel_set, ensemble=list(available_data.ensemble), + n_workers=num_plot_workers, ) if plot_animations: diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py index a96f046bb..9342533dc 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py @@ -2,8 +2,6 @@ import matplotlib import matplotlib.pyplot as plt -import numpy as np -import numpy.typing as npt import xarray as xr @@ -55,34 +53,6 @@ def get_preds_tars_per_sample_channel( ) return da_preds_slice, da_tars_slice - def get_valid_times_per_sample_channel( - self, sample: int | str, channel: str - ) -> npt.NDArray[np.datetime64]: - """Get valid times for the given sample/channel from the timeseries data. - Parameters - ---------- - sample: int | str - The sample for which to extract data. - channel: str - The channel for which to extract data. - - Returns - ------- - npt.NDArray[np.datetime64] - The array of valid times for the given sample and channel. - """ - return self.da_tars_ts.sel(sample=sample, channel=channel).valid_time.values - - def get_channels(self) -> list[str]: - """Get the list of channels from the timeseries data.""" - da_tmp = next(iter(self.da_tars.values())) - return da_tmp.channel.values - - def get_samples(self) -> list[str]: - """Get the list of samples from the timeseries data.""" - da_tmp = next(iter(self.da_tars.values())) - return da_tmp.sample.values - def plot_single_timeseries( self, output_dir: str, @@ -96,7 +66,7 @@ def plot_single_timeseries( has_ens = ens is not None and "ens" in da_preds_slice.dims and ens != "mean" if has_ens: da_preds_slice = da_preds_slice.sel(ens=ens) - valid_times = self.get_valid_times_per_sample_channel(sample, channel) + valid_times = self.da_tars_ts.sel(sample=sample, channel=channel).valid_time.values pred_label = "Prediction" if not has_ens else f"Prediction (ens {ens})" @@ -112,7 +82,7 @@ def plot_single_timeseries( ax.set_ylabel(channel) ax.set_xlabel("Valid Time") ax.legend() - max_ticks = max(4, 15) + max_ticks = 20 day_interval = max(1, len(valid_times) // max_ticks) ax.xaxis.set_major_locator(matplotlib.dates.DayLocator(interval=day_interval)) ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d")) From 52d66e852c6b685b134e5fde61c2104e1dffb96d Mon Sep 17 00:00:00 2001 From: melidonis1 Date: Thu, 11 Jun 2026 12:16:32 +0200 Subject: [PATCH 5/6] resolving conflicts rebase to develop --- packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py b/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py index 0cd860137..82036021e 100644 --- a/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py +++ b/packages/evaluate/src/weathergen/evaluate/utils/clim_utils.py @@ -268,5 +268,5 @@ def needs_climatology(metrics_dict: dict) -> bool: True if any metric requires climatology, False otherwise """ metrics = [m for metrics in metrics_dict.values() for m in metrics.keys()] - req_clim = ["acc", "rps", "rpss"] + req_clim = ["acc", "rps", "rpss"] return any(m in req_clim for m in metrics) From 049d23d83ae69df1bafecbbdf65baaa73b98471a Mon Sep 17 00:00:00 2001 From: melidonis1 Date: Thu, 11 Jun 2026 15:08:07 +0200 Subject: [PATCH 6/6] Include also timeseries per region. Change titles in timeseries to regions --- .../evaluate/plotting/plot_orchestration.py | 6 +- .../evaluate/plotting/timeseries.py | 74 +++++++++---------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py index 6a12e2536..d7160fe0d 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/plot_orchestration.py @@ -21,7 +21,6 @@ from joblib import delayed from PIL import Image from tqdm import tqdm -from weathergen.evaluate.plotting.timeseries import Timeseries from weathergen.evaluate.io.data.io_orchestration import dispatch_parallel, get_num_workers from weathergen.evaluate.io.io_reader import Reader, ReaderOutput @@ -44,6 +43,7 @@ from weathergen.evaluate.plotting.plotter import Plotter from weathergen.evaluate.plotting.quantile_plots import QuantilePlots from weathergen.evaluate.plotting.score_cards import ScoreCards +from weathergen.evaluate.plotting.timeseries import Timeseries from weathergen.evaluate.scores.score import VerifiedData, get_score from weathergen.evaluate.utils.array_utils import bias_ranges, common_ranges from weathergen.evaluate.utils.clim_utils import get_climatology, needs_climatology @@ -769,6 +769,7 @@ def _dispatch_timeseries_plots( da_tars: dict, output_dir: str, stream: str, + regions: list[str], run_id: str, samples: dict, channels: dict, @@ -785,11 +786,13 @@ def _dispatch_timeseries_plots( "channel": str(channel), "sample": sample, "stream": stream, + "region": region, "ens": ens, } for channel in channels for sample in samples for ens in ens_members + for region in regions ] calls = [delayed(data_ts.plot_single_timeseries)(**t) for t in ts_tasks] dispatch_parallel( @@ -1131,6 +1134,7 @@ def plot_data( da_tars=da_tars, output_dir=output_dir, stream=stream, + regions=plotter.regions, run_id=run_id, samples=plot_sample_set, channels=plot_channel_set, diff --git a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py index 9342533dc..cf2f1da53 100644 --- a/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py +++ b/packages/evaluate/src/weathergen/evaluate/plotting/timeseries.py @@ -4,6 +4,8 @@ import matplotlib.pyplot as plt import xarray as xr +from weathergen.evaluate.utils.regions import RegionBoundingBox + class Timeseries: """ @@ -21,22 +23,14 @@ def __init__(self, da_preds: dict[str, xr.Dataset], da_tars: dict[str, xr.Datase self.da_preds = da_preds self.da_tars = da_tars - preds_steps, tars_steps = [], [] - for da_p, da_t in zip(self.da_preds.values(), self.da_tars.values(), strict=False): - vt = da_p.valid_time.isel(ipoint=0).drop_vars("ipoint") - preds_steps.append(da_p.mean(dim="ipoint").assign_coords(valid_time=vt)) - tars_steps.append(da_t.mean(dim="ipoint").assign_coords(valid_time=vt)) - self.da_preds_ts, self.da_tars_ts = ( - xr.concat(preds_steps, dim="forecast_step"), - xr.concat(tars_steps, dim="forecast_step"), - ) - - def get_preds_tars_per_sample_channel( - self, sample: int | str, channel: str + def get_preds_tars_per_region_sample_channel( + self, region: str, sample: int | str, channel: str ) -> tuple[xr.Dataset, xr.Dataset]: """Get preds/tars for the given sample/channel from the timeseries data. Parameters ---------- + region: str + The region for which to extract data. sample: int | str The sample for which to extract data. channel: str @@ -47,11 +41,24 @@ def get_preds_tars_per_sample_channel( tuple[xr.Dataset, xr.Dataset] The prediction and target datasets for the given sample and channel. """ - da_preds_slice, da_tars_slice = ( - self.da_preds_ts.sel(sample=sample, channel=channel), - self.da_tars_ts.sel(sample=sample, channel=channel), + + preds_steps, tars_steps = [], [] + for da_p, da_t in zip(self.da_preds.values(), self.da_tars.values(), strict=False): + # Select sample and channel first so lat/lon become 1D (ipoint,) + da_p = da_p.sel(sample=sample, channel=channel) + da_t = da_t.sel(sample=sample, channel=channel) + if region != "global": + bbox = RegionBoundingBox.from_region_name(region) + da_p = bbox.apply_mask(da_p) + da_t = bbox.apply_mask(da_t) + vt = da_p.valid_time.isel(ipoint=0).drop_vars("ipoint") + preds_steps.append(da_p.mean(dim="ipoint").assign_coords(valid_time=vt)) + tars_steps.append(da_t.mean(dim="ipoint").assign_coords(valid_time=vt)) + da_preds_ts, da_tars_ts = ( + xr.concat(preds_steps, dim="forecast_step"), + xr.concat(tars_steps, dim="forecast_step"), ) - return da_preds_slice, da_tars_slice + return da_preds_ts, da_tars_ts def plot_single_timeseries( self, @@ -59,23 +66,27 @@ def plot_single_timeseries( channel: str, sample: int | str, stream: str, + region: str, ens: str | int | None = None, ) -> None: """Plot and save a timeseries figure for one (channel, sample[, ens]) triple.""" - da_preds_slice, da_tars_slice = self.get_preds_tars_per_sample_channel(sample, channel) - has_ens = ens is not None and "ens" in da_preds_slice.dims and ens != "mean" + + da_preds_ts, da_tars_ts = self.get_preds_tars_per_region_sample_channel( + region, sample, channel + ) + has_ens = ens is not None and "ens" in da_preds_ts.dims and ens != "mean" if has_ens: - da_preds_slice = da_preds_slice.sel(ens=ens) - valid_times = self.da_tars_ts.sel(sample=sample, channel=channel).valid_time.values + da_preds_ts = da_preds_ts.sel(ens=ens) + valid_times = da_tars_ts.valid_time.values pred_label = "Prediction" if not has_ens else f"Prediction (ens {ens})" matplotlib.use("Agg") fig, ax = plt.subplots(figsize=(15, 7)) - ax.plot(valid_times, da_preds_slice.values, label=pred_label) - ax.plot(valid_times, da_tars_slice.values, label=stream, linestyle="--") + ax.plot(valid_times, da_preds_ts.values, label=pred_label) + ax.plot(valid_times, da_tars_ts.values, label=stream, linestyle="--") fig.suptitle( - f"Timeseries Average \u2013 {self.region_label()}", + f"Timeseries Average - {region.capitalize()}", fontsize=13, fontweight="bold", ) @@ -90,23 +101,8 @@ def plot_single_timeseries( fig.autofmt_xdate() out_path = Path(output_dir) / "plots" / stream / "timeseries" out_path.mkdir(parents=True, exist_ok=True) - fname = f"timeseries_{channel}_sample_{sample}" + fname = f"timeseries_{region}_{channel}_sample_{sample}" if has_ens: fname += f"_ens_{ens}" fig.savefig(out_path / f"{fname}.png", bbox_inches="tight") plt.close(fig) - - def region_label(self) -> str: - """Get a human-readable label for the region based on the lat/lon bounds.""" - - _first_da = next(iter(self.da_tars.values())) - lat_min = float(_first_da.lat.min()) - lat_max = float(_first_da.lat.max()) - lon_min = float(_first_da.lon.min()) - lon_max = float(_first_da.lon.max()) - region_label = ( - f"({lat_min:.2f}\u00b0 \u2013 {lat_max:.2f}\u00b0 N, " - f"{lon_min:.2f}\u00b0 \u2013 {lon_max:.2f}\u00b0 E)" - ) - - return region_label