Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ requires-python = ">=3.12,<3.13"
dependencies = [
'numpy~=2.2',
'astropy_healpix~=1.1.2',
'healpy>=1.19,<2',
'zarr~=3.1.3',
'pandas~=2.2',
'tqdm',
Expand Down
60 changes: 39 additions & 21 deletions src/weathergen/model/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from weathergen.model.norms import AdaLayerNorm, RMSNorm
from weathergen.model.positional_encoding import rotary_pos_emb_2d
from weathergen.model.positional_encoding import apply_rope

"""
Attention blocks used by WeatherGenerator.

Some blocks optionally apply 2D RoPE. When enabled, the caller must provide per-token 2D
coordinates aligned with the token order (lat, lon in radians).
Some blocks optionally apply RoPE-like positional modulation. When enabled, the caller must
provide per-token coordinates aligned with the token order (lat, lon in radians).
"""


Expand All @@ -40,7 +40,7 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
rope_mode="none",
):
super(MultiSelfAttentionHeadVarlen, self).__init__()

Expand All @@ -49,7 +49,10 @@ def __init__(
self.with_flash = with_flash
self.softcap = softcap
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope
self.rope_mode = rope_mode
self.rope_post_mod_qk_lnorm = rope_mode == "spherical"
if self.rope_post_mod_qk_lnorm:
assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True"

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
Expand Down Expand Up @@ -79,6 +82,9 @@ def __init__(
lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity
self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity
self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)
self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)

self.dtype = attention_dtype

Expand All @@ -96,10 +102,10 @@ def forward(self, x, x_lens, ada_ln_aux=None, coords=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s)

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1)
qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0
Expand Down Expand Up @@ -225,15 +231,18 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
rope_mode="none",
):
super(MultiSelfAttentionHeadLocal, self).__init__()

self.num_heads = num_heads
self.with_flash = with_flash
self.softcap = softcap
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope
self.rope_mode = rope_mode
self.rope_post_mod_qk_lnorm = rope_mode == "spherical"
if self.rope_post_mod_qk_lnorm:
assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True"

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
Expand Down Expand Up @@ -263,6 +272,9 @@ def __init__(
lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity
self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity
self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)
self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)

self.dtype = attention_dtype
assert with_flash, "Only flash attention supported."
Expand All @@ -288,10 +300,10 @@ def forward(self, x, coords=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype).permute([0, 2, 1, 3])
vs = self.proj_heads_v(x).reshape(s).permute([0, 2, 1, 3])

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=1)
qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 1)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)

outs = self.flex_attention(qs, ks, vs, block_mask=self.block_mask).transpose(1, 2)

Expand Down Expand Up @@ -540,7 +552,7 @@ def __init__(
dim_aux=None,
norm_eps=1e-5,
attention_dtype=torch.bfloat16,
with_2d_rope=False,
rope_mode="none",
):
super(MultiSelfAttentionHead, self).__init__()

Expand All @@ -549,7 +561,10 @@ def __init__(
self.softcap = softcap
self.dropout_rate = dropout_rate
self.with_residual = with_residual
self.with_2d_rope = with_2d_rope
self.rope_mode = rope_mode
self.rope_post_mod_qk_lnorm = rope_mode == "spherical"
if self.rope_post_mod_qk_lnorm:
assert with_qk_lnorm, "rope_post_mod_qk_lnorm=True requires with_qk_lnorm=True"

assert dim_embed % num_heads == 0
self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
Expand Down Expand Up @@ -579,6 +594,9 @@ def __init__(
lnorm = qk_norm if with_qk_lnorm else torch.nn.Identity
self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
post_rope_lnorm = norm if self.rope_post_mod_qk_lnorm else torch.nn.Identity
self.post_rope_lnorm_q = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)
self.post_rope_lnorm_k = post_rope_lnorm(self.dim_head_proj, eps=norm_eps)

self.dtype = attention_dtype
if with_flash:
Expand All @@ -599,10 +617,10 @@ def forward(self, x, coords=None, ada_ln_aux=None):
ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
vs = self.proj_heads_v(x).reshape(s).to(self.dtype)

if self.with_2d_rope:
if coords is None:
raise ValueError("coords must be provided when with_2d_rope=True")
qs, ks = rotary_pos_emb_2d(qs, ks, coords, unsqueeze_dim=2)
qs, ks = apply_rope(qs, ks, coords, self.rope_mode, 2)
if self.rope_post_mod_qk_lnorm:
qs = self.post_rope_lnorm_q(qs).to(self.dtype)
ks = self.post_rope_lnorm_k(ks).to(self.dtype)

# set dropout rate according to training/eval mode as required by flash_attn
dropout_rate = self.dropout_rate if self.training else 0.0
Expand Down
Loading
Loading