Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions packages/common/src/weathergen/common/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,6 @@ def combine(cls, others: list["IOReaderData"]) -> "IOReaderData":
assert other.geoinfos.shape[0] == n_datapoints, "number of datapoints do not match"
assert other.datetimes.shape[0] == n_datapoints, "number of datapoints do not match"

if n_datapoints == 0:
continue

coords = np.concatenate([coords, other.coords])
geoinfos = np.concatenate([geoinfos, other.geoinfos])
data = np.concatenate([data, other.data])
Expand Down Expand Up @@ -361,9 +358,9 @@ def __init__(
def _append_dataset(self, dataset: OutputDataset | None, name: str) -> None:
if dataset:
self.datasets.append(dataset)
else:
msg = f"Missing {name} dataset for item: {self.key.path}"
raise ValueError(msg)
# else:
# msg = f"Missing {name} dataset for item: {self.key.path}"
# raise ValueError(msg)


class ZarrIO:
Expand Down Expand Up @@ -739,11 +736,13 @@ def _extract_coordinates(self, stream_idx, offset_key, datapoints) -> DataCoordi

def _extract_sources(
self, sample: int, stream_idx: int, key: ItemKey, source_interval: TimeRange
) -> OutputDataset:
) -> OutputDataset | None:
channels = self.source_channels[stream_idx]
geoinfo_channels = self.geoinfo_channels[stream_idx]

source: IOReaderData = self.sources[sample][stream_idx]
if source is None:
return None

assert source.data.shape[1] == len(channels), (
f"Number of source channel names {len(channels)} does not align with source data."
Expand Down
20 changes: 7 additions & 13 deletions src/weathergen/utils/validation_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,12 @@ def write_output(
# handle spoof data: do not write since it might corrupt validation (spoofing invisible
# there)
if target_aux_out.physical[t_idx][sname]["is_spoof"][0]:
preds = model_output.get_physical_prediction(t_idx, sname)
# handle forcing streams or if sample is empty
if preds is None:
targets = target_aux_out.physical[t_idx][sname]["target"]
# preds are empty so create copy of target and add ensemble dimension
assert targets[0].shape[0] == 0, "Empty preds but non-empty targets."
preds = [target.clone().unsqueeze(0) for target in targets]
preds_shape = preds[0].shape
targets = target_aux_out.physical[t_idx][sname]["target"]
# for-loop to make sure we have a consistent number of samples
preds_s = [np.zeros((preds_shape[0], 0, preds_shape[2])) for _ in preds]
targets_s = [np.zeros((0, preds_shape[2])) for _ in preds]
t_coords_s = [np.zeros((0, 2)) for _ in preds]
t_times_s = [np.array([]).astype("datetime64[ns]") for _ in preds]
preds_s = [np.zeros((1, 0, t.shape[1])) for t in targets]
targets_s = [np.zeros((0, t.shape[1])) for t in targets]
t_coords_s = [np.zeros((0, 2)) for t in targets]
t_times_s = [np.array([]).astype("datetime64[ns]") for t in targets]

else:
preds = model_output.get_physical_prediction(t_idx, sname)
Expand Down Expand Up @@ -138,7 +131,8 @@ def write_output(
output_stream_names = stream_names

output_streams = {name: stream_names.index(name) for name in output_stream_names}
_logger.debug(f"Using output streams: {output_streams} from streams: {stream_names}")
if batch_idx == 0:
_logger.info(f"Using output streams: {output_streams} from streams: {stream_names}")

target_channels: list[list[str]] = [list(stream.val_target_channels) for stream in cf.streams]
source_channels: list[list[str]] = [list(stream.val_source_channels) for stream in cf.streams]
Expand Down
Loading