Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from aiter.dist.communication_op import (
tensor_model_parallel_fused_allreduce_rmsnorm,
tensor_model_parallel_fused_allreduce_rmsnorm_quant,
)
from aiter.dist.parallel_state import get_tensor_model_parallel_world_size
from aiter.jit.utils.torch_guard import torch_compile_guard
Expand Down Expand Up @@ -773,6 +774,36 @@ def fused_allreduce_gemma_rms_norm(
return norm(hidden_states, residual)


def fused_allreduce_gemma_rms_norm_quant(
hidden_states: torch.Tensor,
residual: torch.Tensor,
norm: GemmaRMSNorm,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""MiniMax-M3 helper for AR + Gemma RMSNorm + per-token FP8 quant."""
if get_tensor_model_parallel_world_size() > 1:
out_fp8, residual_out, scale_out = (
tensor_model_parallel_fused_allreduce_rmsnorm_quant(
hidden_states.contiguous(),
residual,
norm.weight,
norm.variance_epsilon,
quant_type="per_token",
gemma_norm=True,
)
)
return out_fp8, scale_out, residual_out

from aiter import get_hip_quant
from aiter.utility.dtypes import fp8

normed, residual_out = norm(hidden_states, residual)
out_fp8, scale_out = get_hip_quant(QuantType.per_Token)(
normed,
quant_dtype=fp8,
)
return out_fp8, scale_out, residual_out


# ---------------------------------------------------------------------------
# Fused Q/K RMSNorm Triton kernel
# ---------------------------------------------------------------------------
Expand Down
58 changes: 49 additions & 9 deletions atom/models/minimax_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional, Union

import torch
from aiter import ActivationType
from aiter import ActivationType, QuantType, dtypes
from aiter.dist.parallel_state import (
get_pp_group,
get_tensor_model_parallel_world_size,
Expand All @@ -19,6 +19,7 @@
from atom.model_ops.layernorm import (
GemmaRMSNorm,
fused_allreduce_gemma_rms_norm,
fused_allreduce_gemma_rms_norm_quant,
)
from atom.model_ops import module_dispatch_ops as _module_dispatch_ops # noqa: F401
from atom.model_ops.linear import (
Expand Down Expand Up @@ -105,6 +106,15 @@ def _rope_theta(config: PretrainedConfig) -> float:
return getattr(config, "rope_theta", 1000000.0)


def _linear_consumes_per_token_fp8(linear: nn.Module) -> bool:
quant_type = getattr(linear, "quant_type", None)
quant_type_value = getattr(quant_type, "value", quant_type)
return (
quant_type_value == QuantType.per_Token.value
and getattr(linear, "params_dtype", None) == dtypes.fp8
)


def _minimax_m3_cos_sin_cache(
rotary_emb: nn.Module,
query: torch.Tensor,
Expand Down Expand Up @@ -209,8 +219,10 @@ def __init__(
self.swiglu_beta = getattr(config, "swiglu_beta", 1.0)
self.swiglu_limit = getattr(config, "swiglu_limit", 7.0)

def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up = self.gate_up_proj(x)
def forward(
self, x: torch.Tensor, x_scale: torch.Tensor | None = None
) -> torch.Tensor:
gate_up = self.gate_up_proj(x, x_scale=x_scale)
x = swiglu_oai_split(
gate_up,
alpha=self.swiglu_alpha,
Expand Down Expand Up @@ -379,8 +391,9 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor | None = None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
qkv = self.qkv_proj(hidden_states, x_scale=hidden_states_scale)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, positions=positions, qkv=qkv)
return self.o_proj(attn_output)
Expand Down Expand Up @@ -515,10 +528,11 @@ def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
hidden_states_scale: torch.Tensor | None = None,
) -> torch.Tensor:
# Keep index Q/K packed with main QKV. Layers that reuse cached top-k skip
# the indexer norm/rope/top-k path, but still compute the packed GEMM.
qkv = self.qkv_proj(hidden_states)
qkv = self.qkv_proj(hidden_states, x_scale=hidden_states_scale)
q, k, v, _, _ = qkv.split(
[
self.q_size,
Expand Down Expand Up @@ -587,9 +601,23 @@ def forward(
residual: torch.Tensor | None,
aux_out: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states_scale = None
fuse_input_ar_rmsnorm_quant = _linear_consumes_per_token_fp8(
self.self_attn.qkv_proj
)
fuse_post_attention_ar_rmsnorm_quant = (
not self.is_moe_layer
and _linear_consumes_per_token_fp8(self.mlp.gate_up_proj)
)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
elif fuse_input_ar_rmsnorm_quant:
hidden_states, hidden_states_scale, residual = (
fused_allreduce_gemma_rms_norm_quant(
hidden_states, residual, self.input_layernorm
)
)
else:
hidden_states, residual = fused_allreduce_gemma_rms_norm(
hidden_states, residual, self.input_layernorm
Expand All @@ -602,12 +630,24 @@ def forward(
if aux_out is not None:
aux_out.append(residual.clone())

hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states)
hidden_states, residual = fused_allreduce_gemma_rms_norm(
hidden_states, residual, self.post_attention_layernorm
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
hidden_states_scale=hidden_states_scale,
)
ffn = self.block_sparse_moe if self.is_moe_layer else self.mlp
hidden_states = ffn(hidden_states)
if fuse_post_attention_ar_rmsnorm_quant:
hidden_states, hidden_states_scale, residual = (
fused_allreduce_gemma_rms_norm_quant(
hidden_states, residual, self.post_attention_layernorm
)
)
hidden_states = ffn(hidden_states, x_scale=hidden_states_scale)
else:
hidden_states, residual = fused_allreduce_gemma_rms_norm(
hidden_states, residual, self.post_attention_layernorm
)
hidden_states = ffn(hidden_states)
return hidden_states, residual


Expand Down
Loading