Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 133 additions & 10 deletions src/weathergen/model/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch
from astropy_healpix import healpy
from torch.utils.checkpoint import checkpoint
import math
import numpy as np

from weathergen.common.config import Config
from weathergen.datasets.batch import ModelBatch
Expand All @@ -25,6 +27,8 @@
# from weathergen.model.model import ModelParams
from weathergen.model.parametrised_prob_dist import LatentInterpolator
from weathergen.model.positional_encoding import positional_encoding_harmonic
from weathergen.utils.utils import get_dtype
from weathergen.datasets.utils import healpix_verts_rots, r3tos2


class EncoderModule(torch.nn.Module):
Expand All @@ -44,6 +48,56 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord
self.healpix_level = cf.healpix_level
self.num_healpix_cells = 12 * 4**self.healpix_level

self.dtype = get_dtype(cf.attention_dtype)

# Positional embeddings
self.max_tokens_local_per_cell = cf.get("ae_local_max_tokens_per_cell", 64)
self.pe_embed = torch.nn.Parameter(
torch.zeros(self.max_tokens_local_per_cell, cf.ae_local_dim_embed, dtype=self.dtype),
requires_grad=False,
)

self.q_cells_lens = torch.nn.Parameter(
torch.ones(self.num_healpix_cells + 1, dtype=torch.int32), requires_grad=False
)
self.q_cells_lens.data[0] = 0

pe = torch.zeros(
self.num_healpix_cells,
cf.ae_local_num_queries,
cf.ae_global_dim_embed,
dtype=self.dtype,
)
self.pe_global = torch.nn.Parameter(pe, requires_grad=False)

# RoPE coordinates
self.rope_2D = cf.get("rope_2D", False)
if self.rope_2D:
self.num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens
total_tokens = (
self.num_healpix_cells + self.num_extra_tokens
) * cf.ae_local_num_queries
self.register_buffer(
"rope_coords",
torch.zeros(
1,
total_tokens,
2,
dtype=self.dtype,
),
)
self.register_buffer(
"rope_cell_coords",
torch.zeros(
self.num_healpix_cells,
2,
dtype=self.dtype,
),
)
else:
self.rope_coords = None
self.rope_cell_coords = None

self.cf = cf
self.sources_size = sources_size
self.targets_num_channels = targets_num_channels
Expand Down Expand Up @@ -117,29 +171,98 @@ def __init__(self, cf: Config, sources_size, targets_num_channels, targets_coord
# global assimilation engine
self.ae_global_engine = GlobalAssimilationEngine(cf, self.num_healpix_cells)

def forward(self, model_params, batch):
def reset_parameters(self) -> None:
"""Creates positional embedding for each grid point for each stream used after stream
embedding, positional embedding for all stream assimilated cell-level local embedding,
initializing queries for local-to-global adapters, HEALPix neighbourhood based parameter
initializing for target prediction.

Sinusoidal positional encoding: Harmonic positional encoding based upon sine and cosine for
both per stream after stream embedding and per cell level for local assimilation.

Query len based parameter creation: Calculate parameters for the calculated token length at
each cell after local assimilation."""

cf = self.cf

dim_embed = cf.ae_local_dim_embed
token_idx_bias = 16
freq_bias = 8
self.pe_embed.data.fill_(0.0)
position = torch.arange(
token_idx_bias,
token_idx_bias + self.max_tokens_local_per_cell,
device=self.pe_embed.device,
).unsqueeze(1)
div = torch.exp(
torch.arange(freq_bias, freq_bias + dim_embed, 2, device=self.pe_embed.device)
* -(math.log(self.max_tokens_local_per_cell) / dim_embed),
)
self.pe_embed.data[:, 0::2] = torch.sin(position * div[: self.pe_embed[:, 0::2].shape[1]])
self.pe_embed.data[:, 1::2] = torch.cos(position * div[: self.pe_embed[:, 1::2].shape[1]])

dim_embed = cf.ae_global_dim_embed

if self.rope_2D:
verts, _ = healpix_verts_rots(self.healpix_level, 0.5, 0.5)
coords = r3tos2(verts.to(self.rope_coords.device)).to(self.rope_coords.dtype)
self.rope_cell_coords.data.copy_(coords)
coords = coords.unsqueeze(1).repeat(1, cf.ae_local_num_queries, 1)
coords_flat = coords.flatten(0, 1).unsqueeze(0)
num_extra_tokens = cf.num_register_tokens + cf.num_class_tokens
offset = num_extra_tokens * cf.ae_local_num_queries
self.rope_coords.data.fill_(0.0)
self.rope_coords.data[:, offset : offset + coords_flat.shape[1], :].copy_(coords_flat)

self.pe_global.data.fill_(0.0)
xs = 2.0 * np.pi * torch.arange(0, dim_embed, 2, device=self.pe_global.device) / dim_embed
self.pe_global.data[..., 0::2] = 0.5 * torch.sin(
torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs)
)
self.pe_global.data[..., 0::2] += (
torch.sin(
torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs)
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)
self.pe_global.data[..., 1::2] = 0.5 * torch.cos(
torch.outer(8 * torch.arange(cf.ae_local_num_queries, device=self.pe_global.device), xs)
)
self.pe_global.data[..., 1::2] += (
torch.cos(
torch.outer(torch.arange(self.num_healpix_cells, device=self.pe_global.device), xs)
)
.unsqueeze(1)
.repeat((1, cf.ae_local_num_queries, 1))
)

self.q_cells_lens.data.fill_(1)
self.q_cells_lens.data[0] = 0

def forward(self, batch):
"""
Encoder forward
"""

stream_cell_tokens = checkpoint(
self.embed_engine, batch, model_params.pe_embed, use_reentrant=False
self.embed_engine, batch, self.pe_embed, use_reentrant=False
)

tokens_global, posteriors = checkpoint(
self.assimilate_local, model_params, stream_cell_tokens, batch, use_reentrant=False
self.assimilate_local, stream_cell_tokens, batch, use_reentrant=False
)

tokens_global = checkpoint(
self.ae_global_engine,
tokens_global,
coords=model_params.rope_coords,
coords=self.rope_coords,
use_reentrant=False,
)

return tokens_global, posteriors

def interpolate_latents(self, tokens: torch.Tensor) -> (torch.Tensor, torch.Tensor):
def interpolate_latents(self, tokens: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
""" "
TODO
"""
Expand Down Expand Up @@ -273,7 +396,7 @@ def aggregation_engine_unmasked(
return tokens_global_unmasked

def assimilate_local(
self, model_params, tokens: torch.Tensor, batch: ModelBatch
self, tokens: torch.Tensor, batch: ModelBatch
) -> torch.Tensor:
"""
Processes embedded tokens locally and prepares them for the global assimilation
Expand All @@ -299,23 +422,23 @@ def assimilate_local(

# TODO: re-enable or remove ae_local_queries_per_cell
if self.cf.ae_local_queries_per_cell:
tokens_global = (self.q_cells + model_params.pe_global).repeat(rs, 1, 1)
tokens_global = (self.q_cells + self.pe_global).repeat(rs, 1, 1)
else:
num_tokens = self.num_healpix_cells
tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + model_params.pe_global
tokens_global = self.q_cells.repeat(num_tokens, 1, 1) + self.pe_global
tokens_global = tokens_global.repeat(rs, 1, 1)

# apply local assimilation engine and project onto global latent vectors
tokens_global_unmasked, posteriors = self.assimilate_local_project_chunked(
tokens, tokens_global, cell_lens, model_params.q_cells_lens
tokens, tokens_global, cell_lens, self.q_cells_lens
)

# apply aggregation engine on unmasked tokens
tokens_global_unmasked = self.aggregation_engine_unmasked(
tokens_global_unmasked,
tokens_global_register_class,
batch.tokens_lens,
rope_cell_coords=model_params.rope_cell_coords,
rope_cell_coords=self.rope_cell_coords,
)

# final processing
Expand Down
Loading
Loading