Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions config/config_forecasting.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,17 @@ training_config:
type: LossPhysical,
loss_fcts: { "mse": { }, },
},
"cosine_matching": {
type: LossLatentCosineMatching,
weight: 1.0,
target_and_aux_calc: "Physical",
loss_fcts: {
"params": {
cosine_low: 0.68,
cosine_high: 0.78,
},
},
},
}

model_input: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,5 +268,5 @@ def needs_climatology(metrics_dict: dict) -> bool:
True if any metric requires climatology, False otherwise
"""
metrics = [m for metrics in metrics_dict.values() for m in metrics.keys()]
req_clim = ["acc", "rps", "rpss"]
req_clim = ["acc", "rps", "rpss"]
return any(m in req_clim for m in metrics)
10 changes: 9 additions & 1 deletion src/weathergen/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,10 +697,18 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput:
without_grad = p_fwd and self.training and step != max(batch.get_output_idxs())
if without_grad:
# Pushforward mode: advance tokens without grad; no decoding with torch.no_grad():
prev_tokens = tokens
tokens = self.forecast_engine(tokens, step, model_params.rope_coords)
continue

prev_tokens = tokens
tokens = self.forecast_engine(tokens, step, model_params.rope_coords)

# per-token cosine similarity between current and previous patch tokens
cur = tokens[:, self.num_aux_tokens :].reshape(-1, tokens.shape[-1])
prv = prev_tokens[:, self.num_aux_tokens :].reshape(-1, tokens.shape[-1])
cos_sim_to_prev = torch.nn.functional.cosine_similarity(cur, prv.detach(), dim=-1)
output.add_latent_prediction(step, "cos_sim_to_prev", cos_sim_to_prev)

# decoder predictions
output = self.predict_decoders(model_params, step, tokens, batch, output)
# latent predictions (raw and with SSL heads)
Expand Down
3 changes: 2 additions & 1 deletion src/weathergen/train/loss_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from .loss_module_cosine_matching import LossLatentCosineMatching
from .loss_module_physical import LossPhysical
from .loss_module_ssl import LossLatentSSLStudentTeacher

__all__ = [LossPhysical, LossLatentSSLStudentTeacher]
__all__ = [LossPhysical, LossLatentSSLStudentTeacher, LossLatentCosineMatching]
64 changes: 64 additions & 0 deletions src/weathergen/train/loss_modules/loss_module_cosine_matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# (C) Copyright 2025 WeatherGenerator contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging

import torch
import torch.nn.functional as F
from omegaconf import DictConfig

from weathergen.train.loss_modules.loss_module_base import LossModuleBase, LossValues
from weathergen.utils.train_logger import Stage

_logger = logging.getLogger(__name__)


class LossLatentCosineMatching(LossModuleBase):
"""
Band hinge on per-token cosine similarity between consecutive FE latent steps.

Penalises tokens whose cosine similarity to the previous step falls outside
[cosine_low, cosine_high]. Both bounds are enforced as soft hinge losses so
the FE is free inside the sweet-spot and pays a quadratic penalty outside it.

cos_sim_to_prev is computed in model.forward() per patch token and stored in
output.latent[step]["cos_sim_to_prev"] — no target calculator needed.
"""

def __init__(
self, cf: DictConfig, mode_cfg: DictConfig, stage: Stage, device: str, **loss_fcts
):
LossModuleBase.__init__(self)
self.cf = cf
self.stage = stage
self.device = device
self.name = "LossLatentCosineMatching"

params = next(iter(loss_fcts.values()), {}) if loss_fcts else {}
self.cosine_low = params.get("cosine_low", 0.68)
self.cosine_high = params.get("cosine_high", 0.78)

def compute_loss(self, preds, targets, metadata, **kwargs) -> LossValues:
acc_loss = torch.tensor(0.0, device=self.device, requires_grad=True)
count = 0

for step_pred in preds.latent:
cos_sim = step_pred.get("cos_sim_to_prev", None)
if cos_sim is None:
continue
step_loss = (
F.relu(cos_sim - self.cosine_high) ** 2 + F.relu(self.cosine_low - cos_sim) ** 2
).mean()
acc_loss = acc_loss + step_loss
count += 1

loss = acc_loss / count if count > 0 else acc_loss
return LossValues(
loss=loss, losses_all={"cosine_band": loss.detach().item()}, stddev_all={}
)
Loading