Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions atom/plugin/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from atom.models.glm4_moe import Glm4MoeForCausalLM
from atom.models.deepseek_v2 import DeepseekV3ForCausalLM, GlmMoeDsaForCausalLM
from atom.models.minimax_m2 import MiniMaxM2ForCausalLM
from atom.models.minimax_m3 import (
MiniMaxM3SparseForCausalLM,
MiniMaxM3SparseForConditionalGeneration,
)
from atom.models.qwen3_5 import (
Qwen3_5MoeForConditionalGenerationTextOnly,
Qwen3_5ForConditionalGenerationTextOnly,
Expand All @@ -22,6 +26,8 @@
"DeepseekV32ForCausalLM": DeepseekV3ForCausalLM,
"GlmMoeDsaForCausalLM": GlmMoeDsaForCausalLM,
"MiniMaxM2ForCausalLM": MiniMaxM2ForCausalLM,
"MiniMaxM3SparseForCausalLM": MiniMaxM3SparseForCausalLM,
"MiniMaxM3SparseForConditionalGeneration": MiniMaxM3SparseForConditionalGeneration,
"Qwen3_5MoeForConditionalGeneration": Qwen3_5MoeForConditionalGenerationTextOnly,
"Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationTextOnly,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,17 @@ def _update_decode_page_table(

def _should_use_native_dense_mha(self, layer) -> bool:
sliding_window_size = getattr(layer, "sliding_window_size", None)
is_minimax_m3 = bool(getattr(layer, "_atom_minimax_m3_dense_mha", False))
if (
is_minimax_m3
and not self.use_mla
and not layer.is_cross_attention
and layer.head_dim == 128
and layer.qk_head_dim == 128
and layer.v_head_dim == 128
and (sliding_window_size is None or sliding_window_size <= -1)
):
return True
return (
not self.use_mla
and not layer.is_cross_attention
Expand Down Expand Up @@ -1748,6 +1759,10 @@ def forward_extend(

if self.use_mla:
return self._forward_extend_mla(q, k, v, layer, forward_batch)
if bool(getattr(layer, "_atom_minimax_m3_dense_mha", False)):
# M3 dense decode benefits from the native ragged path, but batched
# SGLang prefill is safer through the standard varlen extend path.
return self._forward_extend_mha(q, k, v, layer, forward_batch)
if use_native_dense_mha:
return self._forward_extend_native_dense_mha(q, layer, forward_batch)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from atom.model_ops.attention_mla import MLAModules
from atom.model_ops.base_attention import BaseAttention
from atom.model_ops.layernorm import GemmaRMSNorm, fused_qk_norm
from atom.model_ops.utils import atom_parameter
from atom.plugin.prepare import is_plugin_mode, is_sglang
from atom.models.utils import maybe_prefix
Expand All @@ -40,6 +41,8 @@ def __init__(
per_layer_sliding_window: Optional[int] = None,
rotary_emb: Optional[torch.nn.Module] = None,
prefix: Optional[str] = None,
q_norm: Optional[torch.nn.Module] = None,
k_norm: Optional[torch.nn.Module] = None,
**kwargs,
):
super().__init__(
Expand All @@ -55,10 +58,17 @@ def __init__(
per_layer_sliding_window=per_layer_sliding_window,
rotary_emb=rotary_emb,
prefix=prefix,
q_norm=q_norm,
k_norm=k_norm,
**kwargs,
)

self.rotary_emb = rotary_emb
self.q_norm = q_norm
self.k_norm = k_norm
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim

if is_sglang():
from sglang.srt.layers.radix_attention import RadixAttention
Expand Down Expand Up @@ -141,9 +151,24 @@ def forward_impl_plugin_mode(
save_kv_cache = kwargs.get("save_kv_cache", True)
assert forward_batch is not None, "forward_batch is required for sglang"

# sglang's RadixAttention does not apply rope internally.
# Apply it here when the model passes rotary_emb at construction
# and hasn't already applied rope (e.g. fused qknorm path).
# sglang's RadixAttention does not apply q/k norm or rope internally.
# Apply them here to match ATOM native Attention semantics.
if self.q_norm is not None and self.k_norm is not None:
eps = getattr(self.q_norm, "variance_epsilon", None) or getattr(
self.q_norm, "eps", None
)
add_unit_offset = isinstance(self.q_norm, GemmaRMSNorm)
query, key = fused_qk_norm(
query.view(-1, self.num_heads, self.head_dim),
key.view(-1, self.num_kv_heads, self.head_dim),
self.q_norm.weight,
self.k_norm.weight,
eps,
add_unit_offset=add_unit_offset,
)
query = query.view(-1, self.num_heads * self.head_dim)
key = key.view(-1, self.num_kv_heads * self.head_dim)

if self.rotary_emb is not None and positions is not None:
query, key = self.rotary_emb(positions, query, key)

Expand Down
Loading
Loading