diff --git a/pyproject.toml b/pyproject.toml index 00103cb8c..540d7cf25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ requires-python = ">=3.12,<3.13" dependencies = [ 'numpy~=2.2', 'astropy_healpix~=1.1.2', + 'healpy>=1.19,<2', 'zarr~=3.1.3', 'pandas~=2.2', 'tqdm', diff --git a/src/weathergen/model/attention.py b/src/weathergen/model/attention.py index bf97479e6..fb4250d98 100644 --- a/src/weathergen/model/attention.py +++ b/src/weathergen/model/attention.py @@ -14,13 +14,13 @@ from torch.nn.attention.flex_attention import create_block_mask, flex_attention from weathergen.model.norms import AdaLayerNorm, RMSNorm -from weathergen.model.positional_encoding import rotary_pos_emb_2d +from weathergen.model.positional_encoding import apply_rope """ Attention blocks used by WeatherGenerator. -Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D -coordinates aligned with the token order (lat, lon in radians). +Some blocks optionally apply RoPE-like positional modulation. When enabled, the caller must +provide per-token coordinates aligned with the token order (lat, lon in radians). """ @@ -40,7 +40,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - with_2d_rope=False, + rope_mode="none", ): super(MultiSelfAttentionHeadVarlen, self).__init__() @@ -49,7 +49,10 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual - self.with_2d_rope = with_2d_rope + self.rope_mode = rope_mode + self.rope_post_mod_qk_lnorm = rope_mode == "spherical" + if self.rope_post_mod_qk_lnorm: + assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True" assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -79,6 +82,9 @@ def __init__( lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps) self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) + post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity + self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) + self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype @@ -96,10 +102,10 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s) - if self.with_2d_rope: - if coords is None: - raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1) + if self.rope_post_mod_qk_lnorm: + qs = self.post_rope_lnorm_q(qs).to(self.dtype) + ks = self.post_rope_lnorm_k(ks).to(self.dtype) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 @@ -225,7 +231,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - with_2d_rope=False, + rope_mode="none", ): super(MultiSelfAttentionHeadLocal, self).__init__() @@ -233,7 +239,10 @@ def __init__( self.with_flash = with_flash self.softcap = softcap self.with_residual = with_residual - self.with_2d_rope = with_2d_rope + self.rope_mode = rope_mode + self.rope_post_mod_qk_lnorm = rope_mode == "spherical" + if self.rope_post_mod_qk_lnorm: + assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True" assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -263,6 +272,9 @@ def __init__( lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps) self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) + post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity + self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) + self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype assert with_flash, "Only flash attention supported." @@ -288,10 +300,10 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3]) vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3]) - if self.with_2d_rope: - if coords is None: - raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1) + if self.rope_post_mod_qk_lnorm: + qs = self.post_rope_lnorm_q(qs).to(self.dtype) + ks = self.post_rope_lnorm_k(ks).to(self.dtype) outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2) @@ -540,7 +552,7 @@ def __init__( dim_aux=None, norm_eps=1e-5, attention_dtype=torch.bfloat16, - with_2d_rope=False, + rope_mode="none", ): super(MultiSelfAttentionHead, self).__init__() @@ -549,7 +561,10 @@ def __init__( self.softcap = softcap self.dropout_rate = dropout_rate self.with_residual = with_residual - self.with_2d_rope = with_2d_rope + self.rope_mode = rope_mode + self.rope_post_mod_qk_lnorm = rope_mode == "spherical" + if self.rope_post_mod_qk_lnorm: + assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True" assert dim_embed % num_heads == 0 self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj @@ -579,6 +594,9 @@ def __init__( lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps) self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps) + post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity + self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) + self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps) self.dtype = attention_dtype if with_flash: @@ -599,10 +617,10 @@ def forward(self, x, coords=None, ada_ln_aux=None): ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype) vs = self.proj_heads_v(x).reshape(s).to(self.dtype) - if self.with_2d_rope: - if coords is None: - raise ValueError("coords must be provided when with_2d_rope=True") - qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2) + qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 2) + if self.rope_post_mod_qk_lnorm: + qs = self.post_rope_lnorm_q(qs).to(self.dtype) + ks = self.post_rope_lnorm_k(ks).to(self.dtype) # set dropout rate according to training/eval mode as required by flash_attn dropout_rate = self.dropout_rate if self.training else 0.0 diff --git a/src/weathergen/model/encoder.py b/src/weathergen/model/encoder.py index 5dea1bdae..36c5ee105 100644 --- a/src/weathergen/model/encoder.py +++ b/src/weathergen/model/encoder.py @@ -6,13 +6,17 @@ # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import logging +import math +import numpy as np import torch from astropy_healpix import healpy from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.engines import ( EmbeddingEngine, GlobalAssimilationEngine, @@ -24,7 +28,15 @@ # from weathergen.model.model import ModelParams from weathergen.model.parametrised_prob_dist import LatentInterpolator -from weathergen.model.positional_encoding import positional_encoding_harmonic +from weathergen.model.positional_encoding import ( + build_spherical_rope_coeff_tensors, + get_rope_mode, + get_rope_spherical_band, + positional_encoding_harmonic, +) +from weathergen.utils.utils import get_dtype + +logger = logging.getLogger(__name__) class EncoderModule(torch.nn.Module): @@ -44,7 +56,80 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**self.healpix_level - self.cf = cf + self.dtype = get_dtype(cf.attention_dtype) + + # Positional embeddings + self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) + self.register_buffer( + "pe_embed", + torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), + ) + + self.register_buffer( + "q_cells_lens", torch.ones(self.num_healpix_cells + 1, dtype=torch.int32) + ) + self.q_cells_lens[0] = 0 + + self.register_buffer( + "pe_global", + torch.zeros( + self.num_healpix_cells, + cf.ae_local_num_queries, + cf.ae_global_dim_embed, + dtype=self.dtype, + ), + ) + + # RoPE coordinates + self.rope_mode = get_rope_mode(cf, logger) + if self.rope_mode != "none": + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + 1, + total_tokens, + 2, + dtype=self.dtype, + ), + ) + self.register_buffer( + "rope_cell_coords", + torch.zeros( + self.num_healpix_cells, + 2, + dtype=self.dtype, + ), + ) + if self.rope_mode == "spherical": + rope_spherical_band = get_rope_spherical_band(cf) + num_modes = 2 * int(rope_spherical_band) + 1 + self.register_buffer( + "rope_spherical_coeffs", + torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_cell_coeffs", + torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_extra_coeffs", + torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype), + ) + else: + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None + else: + self.rope_coords = None + self.rope_cell_coords = None + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None + self.sources_size = sources_size self.targets_num_channels = targets_num_channels self.targets_coords_size = targets_coords_size @@ -117,29 +202,131 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord # global assimilation engine self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells) - def forward(self, model_params, batch): + def reset_parameters(self) -> None: + """Creates positional embedding for each grid point for each stream used after stream + embedding, positional embedding for all stream assimilated cell-level local embedding, + initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter + initializing for target prediction. + + Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for + both per stream after stream embedding and per cell level for local assimilation. + + Query len based parameter creation: Calculate parameters for the calculated token length at + each cell after local assimilation.""" + + cf = self.cf + + dim_embed = cf.ae_local_dim_embed + token_idx_bias = 16 + freq_bias = 8 + self.pe_embed.data.fill_(0.0) + position = torch.arange( + token_idx_bias, + token_idx_bias + self.max_tokens_local_per_cell, + device=self.pe_embed.device, + ).unsqueeze(1) + div = torch.exp( + torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device) + * -(math.log(self.max_tokens_local_per_cell) / dim_embed), + ) + self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) + self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) + + dim_embed = cf.ae_global_dim_embed + + if self.rope_mode != "none": + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + self.rope_cell_coords.data.copy_(coords) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0) + offset = self.num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + if self.rope_mode == "spherical": + band = int(get_rope_spherical_band(cf)) + ( + (cell_real, cell_imag), + (extra_real, extra_imag), + (packed_extra_real, packed_extra_imag), + (packed_real, packed_imag), + ) = build_spherical_rope_coeff_tensors( + nside=2**self.healpix_level, + band=band, + num_local_queries=cf.ae_local_num_queries, + num_extra_tokens=self.num_extra_tokens, + device=self.rope_spherical_coeffs.device, + dtype=self.rope_spherical_coeffs.dtype, + ) + self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real) + self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag) + self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real) + self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag) + + self.rope_spherical_coeffs.data.fill_(0.0) + self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real) + self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_real.shape[1], :, 0 + ].copy_(packed_real) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_imag.shape[1], :, 1 + ].copy_(packed_imag) + + self.pe_global.data.fill_(0.0) + xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed + self.pe_global.data[..., 0::2] = 0.5 * torch.sin( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 0::2] += ( + torch.sin( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + self.pe_global.data[..., 1::2] = 0.5 * torch.cos( + torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) + ) + self.pe_global.data[..., 1::2] += ( + torch.cos( + torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) + ) + .unsqueeze(1) + .repeat((1, cf.ae_local_num_queries, 1)) + ) + + self.q_cells_lens.data.fill_(1) + self.q_cells_lens.data[0] = 0 + + def forward(self, batch): """ Encoder forward """ stream_cell_tokens = checkpoint( - self.embed_engine, batch, model_params.pe_embed, use_reentrant=False + self.embed_engine, batch, self.pe_embed, use_reentrant=False ) tokens_global, posteriors = checkpoint( - self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False + self.assimilate_local, stream_cell_tokens, batch, use_reentrant=False ) tokens_global = checkpoint( self.ae_global_engine, tokens_global, - coords=model_params.rope_coords, + coords=( + self.rope_spherical_coeffs.unbind(dim=-1) + if self.rope_spherical_coeffs is not None + else self.rope_coords + ), use_reentrant=False, ) return tokens_global, posteriors - def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor): + def interpolate_latents(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ " TODO """ @@ -221,6 +408,8 @@ def aggregation_engine_unmasked( tokens_global_register_class, tokens_lens, rope_cell_coords=None, + rope_cell_coeffs=None, + rope_extra_coeffs=None, ): """ Aggregation engine on the global latents of unmasked cells @@ -251,8 +440,19 @@ def aggregation_engine_unmasked( ) # Build packed coords matching the interleaved token order - if rope_cell_coords is not None: - num_extra = self.num_class_tokens + self.num_register_tokens + num_extra = self.num_class_tokens + self.num_register_tokens + if rope_cell_coeffs is not None: + extra_real, extra_imag = rope_extra_coeffs.unbind(dim=-1) + cell_real, cell_imag = rope_cell_coeffs.unbind(dim=-1) + packed_real = [] + packed_imag = [] + for mask_b in cell_mask.flatten(0, 1): + packed_real.append(extra_real) + packed_imag.append(extra_imag) + packed_real.append(cell_real[mask_b]) + packed_imag.append(cell_imag[mask_b]) + packed_coords = (torch.cat(packed_real, dim=0), torch.cat(packed_imag, dim=0)) + elif rope_cell_coords is not None: zero_coords = torch.zeros( num_extra, 2, device=rope_cell_coords.device, dtype=rope_cell_coords.dtype ) @@ -272,9 +472,7 @@ def aggregation_engine_unmasked( return tokens_global_unmasked - def assimilate_local( - self, model_params, tokens: torch.Tensor, batch: ModelBatch - ) -> torch.Tensor: + def assimilate_local(self, tokens: torch.Tensor, batch: ModelBatch) -> torch.Tensor: """ Processes embedded tokens locally and prepares them for the global assimilation @@ -299,15 +497,15 @@ def assimilate_local( # TODO: re-enable or remove ae_local_queries_per_cell if self.cf.ae_local_queries_per_cell: - tokens_global = (self.q_cells + model_params.pe_global).repeat(rs, 1, 1) + tokens_global = (self.q_cells + self.pe_global).repeat(rs, 1, 1) else: num_tokens = self.num_healpix_cells - tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + model_params.pe_global + tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + self.pe_global tokens_global = tokens_global.repeat(rs, 1, 1) # apply local assimilation engine and project onto global latent vectors tokens_global_unmasked, posteriors = self.assimilate_local_project_chunked( - tokens, tokens_global, cell_lens, model_params.q_cells_lens + tokens, tokens_global, cell_lens, self.q_cells_lens ) # apply aggregation engine on unmasked tokens @@ -315,7 +513,9 @@ def assimilate_local( tokens_global_unmasked, tokens_global_register_class, batch.tokens_lens, - rope_cell_coords=model_params.rope_cell_coords, + rope_cell_coords=self.rope_cell_coords, + rope_cell_coeffs=self.rope_spherical_cell_coeffs, + rope_extra_coeffs=self.rope_spherical_extra_coeffs, ) # final processing diff --git a/src/weathergen/model/engines.py b/src/weathergen/model/engines.py index 9c1a1e3a9..45d7cf817 100644 --- a/src/weathergen/model/engines.py +++ b/src/weathergen/model/engines.py @@ -16,6 +16,7 @@ from torch.utils.checkpoint import checkpoint from weathergen.common.config import Config +from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.attention import ( MultiCrossAttentionHeadVarlen, MultiCrossAttentionHeadVarlenSlicedQ, @@ -29,6 +30,11 @@ StreamEmbedTransformer, ) from weathergen.model.layers import MLP +from weathergen.model.positional_encoding import ( + build_spherical_rope_coeff_tensors, + get_rope_mode, + get_rope_spherical_band, +) from weathergen.model.utils import ActivationFactory from weathergen.utils.utils import get_dtype @@ -390,6 +396,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: super(QueryAggregationEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells + rope_mode = get_rope_mode(self.cf) self.ae_aggregation_blocks = torch.nn.ModuleList() @@ -410,7 +417,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.get("rope_2D", False), + rope_mode=rope_mode, ) ) else: @@ -466,6 +473,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: super(GlobalAssimilationEngine, self).__init__() self.cf = cf self.num_healpix_cells = num_healpix_cells + rope_mode = get_rope_mode(self.cf) self.ae_global_blocks = torch.nn.ModuleList() @@ -486,7 +494,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.get("rope_2D", False), + rope_mode=rope_mode, ) ) else: @@ -503,7 +511,7 @@ def __init__(self, cf: Config, num_healpix_cells: int) -> None: qk_norm_type=self.cf.get("qk_norm_type", self.cf.norm_type), norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.get("rope_2D", False), + rope_mode=rope_mode, ) ) # MLP block @@ -555,6 +563,63 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = self.cf = cf self.num_healpix_cells = num_healpix_cells self.fe_blocks = torch.nn.ModuleList() + self.rope_2D = cf.get("rope_2D", False) + self.healpix_level = ( + cf.get("fe_healpix_level") + if cf.get("fe_healpix_level") is not None + else cf.get("healpix_level") + ) + self.dtype = get_dtype(cf.attention_dtype) + + # RoPE coordinates + self.rope_mode = get_rope_mode(cf) + if self.rope_mode != "none": + self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens + total_tokens = ( + self.num_healpix_cells + self.num_extra_tokens + ) * cf.ae_local_num_queries + self.register_buffer( + "rope_coords", + torch.zeros( + 1, + total_tokens, + 2, + dtype=self.dtype, + ), + ) + self.register_buffer( + "rope_cell_coords", + torch.zeros( + self.num_healpix_cells, + 2, + dtype=self.dtype, + ), + ) + if self.rope_mode == "spherical": + rope_spherical_band = get_rope_spherical_band(cf) + num_modes = 2 * int(rope_spherical_band) + 1 + self.register_buffer( + "rope_spherical_coeffs", + torch.zeros(1, total_tokens, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_cell_coeffs", + torch.zeros(self.num_healpix_cells, num_modes, 2, dtype=self.dtype), + ) + self.register_buffer( + "rope_spherical_extra_coeffs", + torch.zeros(self.num_extra_tokens, num_modes, 2, dtype=self.dtype), + ) + else: + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None + else: + self.rope_coords = None + self.rope_cell_coords = None + self.rope_spherical_coeffs = None + self.rope_spherical_cell_coeffs = None + self.rope_spherical_extra_coeffs = None global_rate = int(1 / self.cf.forecast_att_dense_rate) if mode_cfg.get("forecast", {}).get("policy") is not None: @@ -573,7 +638,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.get("rope_2D", False), + rope_mode=self.rope_mode, ) ) else: @@ -591,7 +656,7 @@ def __init__(self, cf: Config, mode_cfg, num_healpix_cells: int, dim_aux: int = dim_aux=dim_aux, norm_eps=self.cf.norm_eps, attention_dtype=get_dtype(self.cf.attention_dtype), - with_2d_rope=self.cf.get("rope_2D", False), + rope_mode=self.rope_mode, ) ) # Add MLP block @@ -621,7 +686,52 @@ def init_weights_final(m): for block in self.fe_blocks: block.apply(init_weights_final) - def forward(self, tokens, fstep, coords=None): + def reset_parameters(self) -> None: + """HEALPix neighbourhood based parameter initializing for target prediction.""" + + cf = self.cf + + if self.rope_mode != "none": + verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) + coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) + self.rope_cell_coords.data.copy_(coords) + coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) + coords_flat = coords.flatten(0, 1).unsqueeze(0) + offset = self.num_extra_tokens * cf.ae_local_num_queries + self.rope_coords.data.fill_(0.0) + self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) + + if self.rope_mode == "spherical": + band = int(get_rope_spherical_band(cf)) + ( + (cell_real, cell_imag), + (extra_real, extra_imag), + (packed_extra_real, packed_extra_imag), + (packed_real, packed_imag), + ) = build_spherical_rope_coeff_tensors( + nside=2**self.healpix_level, + band=band, + num_local_queries=cf.ae_local_num_queries, + num_extra_tokens=self.num_extra_tokens, + device=self.rope_spherical_coeffs.device, + dtype=self.rope_spherical_coeffs.dtype, + ) + self.rope_spherical_cell_coeffs.data[..., 0].copy_(cell_real) + self.rope_spherical_cell_coeffs.data[..., 1].copy_(cell_imag) + self.rope_spherical_extra_coeffs.data[..., 0].copy_(extra_real) + self.rope_spherical_extra_coeffs.data[..., 1].copy_(extra_imag) + + self.rope_spherical_coeffs.data.fill_(0.0) + self.rope_spherical_coeffs.data[:, :offset, :, 0].copy_(packed_extra_real) + self.rope_spherical_coeffs.data[:, :offset, :, 1].copy_(packed_extra_imag) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_real.shape[1], :, 0 + ].copy_(packed_real) + self.rope_spherical_coeffs.data[ + :, offset : offset + packed_imag.shape[1], :, 1 + ].copy_(packed_imag) + + def forward(self, tokens, fstep): if self.training: # Impute noise to the latent state noise_std = self.cf.get("fe_impute_latent_noise_std", 0.0) @@ -633,7 +743,15 @@ def forward(self, tokens, fstep, coords=None): if isinstance(block, torch.nn.modules.normalization.LayerNorm): tokens = checkpoint(block, tokens, use_reentrant=False) else: - tokens = checkpoint(block, tokens, coords, aux_info, use_reentrant=False) + tokens = checkpoint( + block, + tokens, + self.rope_spherical_coeffs.unbind(dim=-1) + if self.rope_spherical_coeffs is not None + else self.rope_coords, + aux_info, + use_reentrant=False, + ) return tokens diff --git a/src/weathergen/model/model.py b/src/weathergen/model/model.py index d8a30a722..0e3dde055 100644 --- a/src/weathergen/model/model.py +++ b/src/weathergen/model/model.py @@ -10,12 +10,10 @@ # nor does it submit to any jurisdiction. import logging -import math import typing import warnings import astropy_healpix as hp -import astropy_healpix.healpy import numpy as np import torch import torch.nn as nn @@ -23,7 +21,6 @@ from weathergen.common.config import Config from weathergen.datasets.batch import ModelBatch -from weathergen.datasets.utils import healpix_verts_rots, r3tos2 from weathergen.model.encoder import EncoderModule from weathergen.model.engines import ( BilinearDecoder, @@ -92,50 +89,6 @@ def __init__(self, cf) -> None: self.healpix_level = cf.healpix_level self.num_healpix_cells = 12 * 4**cf.healpix_level - self.dtype = get_dtype(cf.attention_dtype) - - # Positional embeddings - self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64) - self.pe_embed = torch.nn.Parameter( - torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype), - requires_grad=False, - ) - - pe = torch.zeros( - self.num_healpix_cells, - cf.ae_local_num_queries, - cf.ae_global_dim_embed, - dtype=self.dtype, - ) - self.pe_global = torch.nn.Parameter(pe, requires_grad=False) - - # RoPE coordinates - self.rope_2D = cf.get("rope_2D", False) - if self.rope_2D: - self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens - total_tokens = ( - self.num_healpix_cells + self.num_extra_tokens - ) * cf.ae_local_num_queries - self.register_buffer( - "rope_coords", - torch.zeros( - 1, - total_tokens, - 2, - dtype=self.dtype, - ), - ) - self.register_buffer( - "rope_cell_coords", - torch.zeros( - self.num_healpix_cells, - 2, - dtype=self.dtype, - ), - ) - else: - self.rope_coords = None - self.rope_cell_coords = None # HEALPix neighbours hlc = self.healpix_level @@ -151,97 +104,15 @@ def __init__(self, cf) -> None: requires_grad=False, ) - self.q_cells_lens = torch.nn.Parameter( - torch.ones(self.num_healpix_cells + 1, dtype=torch.int32), requires_grad=False - ) - self.q_cells_lens.data[0] = 0 - def create(self, cf: Config) -> "ModelParams": - self.reset_parameters(cf) + self.reset_parameters() return self - def reset_parameters(self, cf: Config) -> "ModelParams": - """Creates positional embedding for each grid point for each stream used after stream - embedding, positional embedding for all stream assimilated cell-level local embedding, - initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter - initializing for target prediction. - - Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for - both per stream after stream embedding and per cell level for local assimilation. - - HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize each - with its own cell number as well as the cell numbers of its neighbors. If a cell has - fewer than eight neighbors, use its own cell number to fill the remaining slots. - - Query len based parameter creation: Calculate parameters for the calculated token length at - each cell after local assimilation. - - Args: - cf : Configuration + def reset_parameters(self) -> "ModelParams": + """HEALPix neighbourhood structure: Determine the neighbors for each cell and initialize + each with its own cell number as well as the cell numbers of its neighbors. If a cell has + fewer than eight neighbors, use its own cell number to fill the remaining slots. """ - - # positional encodings - - dim_embed = cf.ae_local_dim_embed - token_idx_bias = 16 - freq_bias = 8 - self.pe_embed.data.fill_(0.0) - position = torch.arange( - token_idx_bias, - token_idx_bias + self.max_tokens_local_per_cell, - device=self.pe_embed.device, - ).unsqueeze(1) - div = torch.exp( - torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device) - * -(math.log(self.max_tokens_local_per_cell) / dim_embed), - ) - self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]]) - self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]]) - - dim_embed = cf.ae_global_dim_embed - - if self.rope_2D: - # Precompute per-cell center coordinates (lat, lon in radians) for 2D RoPE. - # Shape: (num_healpix_cells, ae_local_num_queries, 2) - verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5) - coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype) - # Per-cell coords for QueryAggregationEngine (no query expansion) - self.rope_cell_coords.data.copy_(coords) - coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1) - coords_flat = coords.flatten(0, 1).unsqueeze(0) - offset = self.num_extra_tokens * cf.ae_local_num_queries - self.rope_coords.data.fill_(0.0) - self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat) - - # pe_global: always initialized. RoPE handles relative position in Q/K, but pe_global - # provides per-cell token identity which is critical for masked cells that have no - # content from local assimilation. Without it, masked cells are identical and the - # teacher representation (evaluated without dropout) collapses to low rank. - self.pe_global.data.fill_(0.0) - xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed - self.pe_global.data[..., 0::2] = 0.5 * torch.sin( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 0::2] += ( - torch.sin( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - self.pe_global.data[..., 1::2] = 0.5 * torch.cos( - torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs) - ) - self.pe_global.data[..., 1::2] += ( - torch.cos( - torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs) - ) - .unsqueeze(1) - .repeat((1, cf.ae_local_num_queries, 1)) - ) - - # healpix neighborhood structure - hlc = self.healpix_level num_healpix_cells = self.num_healpix_cells with warnings.catch_warnings(action="ignore"): @@ -253,12 +124,6 @@ def reset_parameters(self, cf: Config) -> "ModelParams": self.hp_nbours.data[:, 0] = torch.arange(temp.shape[0], device=self.hp_nbours.device) self.hp_nbours.data[:, 1:] = torch.from_numpy(temp).to(self.hp_nbours.device) - # precompute for varlen attention - self.q_cells_lens.data.fill_(1) - self.q_cells_lens.data[0] = 0 - - # ensure all params have grad set to False - return @@ -588,6 +453,10 @@ def _reset_params(module): pass self.apply(_reset_params) + if self.encoder is not None: + self.encoder.reset_parameters() + if self.forecast_engine is not None: + self.forecast_engine.reset_parameters() def print_num_parameters(self) -> None: """Print number of parameters for entire model and each module used to build the model""" @@ -682,7 +551,7 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: output = ModelOutput(batch.get_output_len()) - tokens, posteriors = self.encoder(model_params, batch) + tokens, posteriors = self.encoder(batch) output.add_latent_prediction(0, "posteriors", posteriors) # recover batch dimension and separate input_steps @@ -697,20 +566,19 @@ def forward(self, model_params: ModelParams, batch: ModelBatch) -> ModelOutput: without_grad = p_fwd and self.training and step != max(batch.get_output_idxs()) if without_grad: # Pushforward mode: advance tokens without grad; no decoding with torch.no_grad(): - tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + tokens = self.forecast_engine(tokens, step) continue - tokens = self.forecast_engine(tokens, step, model_params.rope_coords) + tokens = self.forecast_engine(tokens, step) # decoder predictions output = self.predict_decoders(model_params, step, tokens, batch, output) # latent predictions (raw and with SSL heads) - output = self.predict_latent(model_params, step, tokens, batch, output) + output = self.predict_latent(step, tokens, batch, output) return output def predict_latent( self, - model_params: ModelParams, step: int, tokens: torch.Tensor, batch: ModelBatch, diff --git a/src/weathergen/model/model_interface.py b/src/weathergen/model/model_interface.py index 1423652f7..ce469f2a4 100644 --- a/src/weathergen/model/model_interface.py +++ b/src/weathergen/model/model_interface.py @@ -17,7 +17,7 @@ MixedPrecisionPolicy, fully_shard, ) -from torch.distributed.tensor import distribute_tensor +from torch.distributed.tensor import DTensor, distribute_tensor from weathergen.common.config import Config, get_path_model, merge_configs from weathergen.model.attention import ( @@ -164,7 +164,7 @@ def init_model_and_shard( # model params model_params = ModelParams(cf).create(cf) - model_params.reset_parameters(cf) + model_params.reset_parameters() model_params = model_params.to(f"cuda:{cf.local_rank}") return model, model_params @@ -196,13 +196,15 @@ def load_model(cf, model, device, run_id: str, mini_epoch=-1): if sharded_meta_param is None: logger.warning(f"Parameter {param_name} from checkpoint not found in model.") continue - sharded_tensor = distribute_tensor( - full_tensor, - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - # maybe_sharded_sd[param_name.replace("module.", "")] = nn.Parameter(sharded_tensor) - maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + if isinstance(sharded_meta_param, DTensor): + sharded_tensor = distribute_tensor( + full_tensor, + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + maybe_sharded_sd[param_name] = torch.nn.Parameter(sharded_tensor) + else: + maybe_sharded_sd[param_name] = full_tensor.to(device) # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor mkeys, ukeys = model.load_state_dict(maybe_sharded_sd, strict=False, assign=True) diff --git a/src/weathergen/model/positional_encoding.py b/src/weathergen/model/positional_encoding.py index 411942a9f..2341f75ba 100644 --- a/src/weathergen/model/positional_encoding.py +++ b/src/weathergen/model/positional_encoding.py @@ -8,8 +8,11 @@ # nor does it submit to any jurisdiction. import math +from functools import lru_cache +import healpy as hp import numpy as np +import numpy.typing as npt import torch @@ -173,3 +176,202 @@ def rotary_pos_emb_2d(q, k, coords, base=10000.0, unsqueeze_dim=1): cos, sin = rotary_embedding_2d(coords, q.shape[-1], base=base) return apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=unsqueeze_dim) + + +# Spherical RoPE +def _max_supported_spherical_band(dim_embed: int, num_heads: int) -> int: + head_dim = dim_embed // num_heads + max_complex = (head_dim - (head_dim % 2)) // 2 + return max(0, (max_complex - 1) // 2) + + +def get_rope_mode(cf, logger=None) -> str: + """Resolve RoPE mode, including temporary backwards compatibility for rope_2D.""" + + rope_mode = cf.get("rope_mode", "none") or "none" + rope_2d = cf.get("rope_2D", None) + if rope_2d is not None: + if logger is not None: + logger.warning( + "Config key 'rope_2D' is deprecated and will be removed. Use 'rope_mode' " + "with one of: none, 2d, spherical." + ) + if rope_mode == "none": + rope_mode = "2d" if rope_2d else "none" + return rope_mode + + +def get_rope_spherical_band(cf) -> int: + """Resolve spherical band index, supporting explicit config or automatic selection.""" + + rope_spherical_band = cf.get("rope_spherical_band", None) + if rope_spherical_band is not None: + return int(rope_spherical_band) + + candidates = [ + _max_supported_spherical_band(cf.ae_global_dim_embed, cf.ae_aggregation_num_heads), + _max_supported_spherical_band(cf.ae_global_dim_embed, cf.ae_global_num_heads), + ] + if cf.get("fe_num_blocks", 0) > 0: + candidates.append(_max_supported_spherical_band(cf.ae_global_dim_embed, cf.fe_num_heads)) + return min(candidates) + + +def apply_rope(qs, ks, coords, rope_mode, unsqueeze_dim): + rope_mode = rope_mode or "none" + if rope_mode == "none": + return qs, ks + if coords is None: + raise ValueError(f"coords must be provided when rope_mode={rope_mode}") + if rope_mode == "2d": + return rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=unsqueeze_dim) + if rope_mode == "spherical": + return rotary_pos_emb_spherical(qs, ks, coords, unsqueeze_dim=unsqueeze_dim) + raise ValueError(f"Unsupported rope_mode={rope_mode}") + + +def rotary_pos_emb_spherical( + q: torch.Tensor, + k: torch.Tensor, + coeffs: tuple[torch.Tensor, torch.Tensor], + unsqueeze_dim: int = 1, +): + """Apply spherical-harmonic RoPE-style modulation to q/k using precomputed coefficients. + + Both q and k are multiplied by Y_lm(omega) at their respective positions. Under the real-pair + representation of complex modes, the attention dot product is equivalent to + Re[sum_m Y_lm(omega_r) Y_lm*(omega_s) q_m k_m*]. + """ + + coeff_real, coeff_imag = coeffs + return ( + _apply_complex_modulation(q, coeff_real, coeff_imag, unsqueeze_dim), + _apply_complex_modulation(k, coeff_real, coeff_imag, unsqueeze_dim), + ) + + +def _apply_complex_modulation( + x: torch.Tensor, + coeff_real: torch.Tensor, + coeff_imag: torch.Tensor, + unsqueeze_dim: int, +) -> torch.Tensor: + coeff_real = coeff_real.unsqueeze(unsqueeze_dim).to(dtype=x.dtype) + coeff_imag = coeff_imag.unsqueeze(unsqueeze_dim).to(dtype=x.dtype) + num_complex = coeff_real.shape[-1] + max_complex = (x.shape[-1] - (x.shape[-1] % 2)) // 2 + if num_complex > max_complex: + raise ValueError( + f"Spherical RoPE requires {num_complex} complex modes but the head only supports " + f"{max_complex}. Reduce rope_spherical_band or increase the head dimension." + ) + num_rotary_dims = 2 * num_complex + if num_rotary_dims == 0: + return x + + x_rot = x[..., :num_rotary_dims].reshape(*x.shape[:-1], num_complex, 2) + x_real = x_rot[..., 0] + x_imag = x_rot[..., 1] + out_real = (x_real * coeff_real) - (x_imag * coeff_imag) + out_imag = (x_real * coeff_imag) + (x_imag * coeff_real) + out = torch.stack((out_real, out_imag), dim=-1).flatten(-2, -1) + if num_rotary_dims < x.shape[-1]: + out = torch.cat((out, x[..., num_rotary_dims:]), dim=-1) + return out + + +def build_spherical_rope_coeff_tensors( + nside: int, + band: int, + num_local_queries: int, + num_extra_tokens: int, + device=None, + dtype=torch.float32, +) -> tuple[ + tuple[torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor], + tuple[torch.Tensor, torch.Tensor], +]: + """Build spherical-RoPE coefficient tensors for cell-level, extra tokens, and packed tokens.""" + + real_maps, imag_maps = _healpy_band_maps(nside, band) + cell_real = torch.as_tensor(real_maps, device=device, dtype=dtype) + cell_imag = torch.as_tensor(imag_maps, device=device, dtype=dtype) + + extra_real = torch.ones( + num_extra_tokens, cell_real.shape[-1], device=cell_real.device, dtype=cell_real.dtype + ) + extra_imag = torch.zeros_like(extra_real) + packed_extra_real = ( + extra_real.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) + ) + packed_extra_imag = ( + extra_imag.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) + ) + + packed_real = cell_real.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) + packed_imag = cell_imag.unsqueeze(1).repeat(1, num_local_queries, 1).flatten(0, 1).unsqueeze(0) + + return ( + (cell_real, cell_imag), + (extra_real, extra_imag), + (packed_extra_real, packed_extra_imag), + (packed_real, packed_imag), + ) + + +@lru_cache(maxsize=32) +def _healpy_band_maps( + nside: int, band: int +) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]: + """Precompute one spherical-harmonic band on the HEALPix grid using healpy. + + The returned columns store the complex coefficients Y_lm(omega) for fixed l=band and + m=-l,...,+l. These are the position factors used in spherical RoPE: + + q_m^omega = Y_lm(omega) q_m, k_m^omega = Y_lm(omega) k_m. + + The following attention dot product then implicitly forms + Y_lm(omega_r) Y_lm*(omega_s), matching the spherical harmonics addition-theorem + structure. + """ + + num_pixels = hp.nside2npix(nside) + real_maps = np.zeros((num_pixels, 2 * band + 1), dtype=np.float64) + imag_maps = np.zeros((num_pixels, 2 * band + 1), dtype=np.float64) + alm_size = hp.sphtfunc.Alm.getsize(band, band) + + for m in range(0, band + 1): + # healpy stores alm only for m >= 0 and alm2map reconstructs a real field. Setting + # a_lm=1 gives 2 Re[Y_lm] for m>0, while a_lm=i gives -2 Im[Y_lm]. We combine these + # two real maps below to recover the complex coefficient Y_lm itself. + alm_real = np.zeros(alm_size, dtype=np.complex128) + alm_real[hp.sphtfunc.Alm.getidx(band, band, m)] = 1.0 + real_map = hp.alm2map(alm_real, nside=nside, lmax=band, mmax=band, pol=False) + real_map = hp.reorder(real_map, r2n=True) + + if m == 0: + # Y_l0 is real, and healpy returns it directly because there is no -m counterpart + # to merge into the real map. + real_maps[:, band] = real_map + continue + + alm_imag = np.zeros(alm_size, dtype=np.complex128) + alm_imag[hp.sphtfunc.Alm.getidx(band, band, m)] = 1.0j + imag_map = hp.alm2map(alm_imag, nside=nside, lmax=band, mmax=band, pol=False) + imag_map = hp.reorder(imag_map, r2n=True) + + pos_idx = band + m + neg_idx = band - m + sign = -1.0 if m % 2 else 1.0 + + # Columns are ordered as m=-l,...,+l, hence band+m for +m and band-m for -m. + # The negative-order mode follows the standard convention + # Y_l,-m = (-1)^m Y_lm*. + real_maps[:, pos_idx] = real_map / 2.0 + imag_maps[:, pos_idx] = -imag_map / 2.0 + real_maps[:, neg_idx] = sign * real_map / 2.0 + imag_maps[:, neg_idx] = sign * imag_map / 2.0 + + return real_maps, imag_maps diff --git a/src/weathergen/train/trainer.py b/src/weathergen/train/trainer.py index 47bc74214..163613ef3 100644 --- a/src/weathergen/train/trainer.py +++ b/src/weathergen/train/trainer.py @@ -652,7 +652,11 @@ def _get_full_model_state_dict(self): if self.cf.with_ddp and self.cf.with_fsdp: cpu_state_dict = {} for param_name, sharded_param in maybe_sharded_sd.items(): - full_param = sharded_param.full_tensor() + full_param = ( + sharded_param.full_tensor() + if isinstance(sharded_param, DTensor) + else sharded_param + ) if is_root(): cpu_state_dict[param_name] = full_param.cpu() else: