Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,23 @@ def paged_attention_triton(

attn_metadata = fwd_ctx.attn_metadata

if envs.ATOM_USE_UNIFIED_ATTN and self.kv_cache_dtype.startswith("fp8"):
# gfx1250 (MI455) has no pa_decode_gluon kernel (gluon supports gfx942 /
# gfx950 only), so route the decode through the triton unified_attention
# path, which is gfx1250-capable and reads the same SHUFFLE KV cache.
use_unified = (
envs.ATOM_USE_UNIFIED_ATTN
or self.use_flash_layout
or get_gfx() == "gfx1250"
)

if use_unified and self.kv_cache_dtype.startswith("fp8"):
o = torch.empty(*q.shape, dtype=torch.bfloat16, device=q.device)
else:
o = torch.empty_like(q)

num_seqs = attn_metadata.context_lens.shape[0]

if envs.ATOM_USE_UNIFIED_ATTN or self.use_flash_layout:
if use_unified:
# print(q.shape, k_cache.shape, v_cache.shape)
sliding_window = (
(self.sliding_window - 1, 0) if self.sliding_window > 0 else (-1, -1)
Expand Down
103 changes: 103 additions & 0 deletions atom/model_ops/minimax_m3/sparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import aiter # noqa: F401 (used by the gluon PA runners for aiter.dtypes.fp8)
import torch
from aiter.jit.utils.chip_info import get_gfx

try:
from vllm.triton_utils import tl, triton
Expand Down Expand Up @@ -125,6 +126,62 @@ def _is_fp8_kv_cache_tensor(kv_cache: torch.Tensor) -> bool:
return kv_cache.dtype in {dtype for dtype in fp8_dtypes if dtype is not None}


def _sparse_decode_unified_attention(
    q_view: torch.Tensor,  # [num_seqs, gqa_group, head_dim] (kv-head collapsed)
    out_view: torch.Tensor,  # [num_seqs, gqa_group, head_dim]
    k_cache_view: torch.Tensor,  # SHUFFLE 5D, num_kv_heads collapsed to 1
    v_cache_view: torch.Tensor,
    sparse_bt: torch.Tensor,  # [num_seqs, max_pages] physical-16 block table
    sparse_ctx: torch.Tensor,  # [num_seqs] per-row effective context length
    sm_scale: float,
    num_seqs: int,
) -> None:
    """Run sparse per-token decode through triton ``unified_attention``.

    This is the gfx1250 (MI455) replacement for the gluon ``pa_decode_gluon``
    kernel, which only exists for gfx942 / gfx950. The sparse runners hand in
    an already compacted block table (``sparse_bt``) and exact per-row context
    lengths (``sparse_ctx``) over the kv-head-collapsed SHUFFLE cache; those
    are forwarded as ``block_table`` / ``seqused_k`` with
    ``shuffled_kv_cache=True``. Every query token is treated as its own
    length-1 causal sequence, matching the gluon ``max_seqlen_q=1`` setup.

    Only bf16 KV caches are handled: no q/k/v descales are passed, so fp8
    caches (which need per-page descales) must be rejected by the caller.

    Args:
        q_view: Query tensor, one row per (token, kv-head) pair.
        out_view: Output tensor written in place by the kernel.
        k_cache_view: Key cache; dim 3 is read as the page (block) size.
        v_cache_view: Value cache, forwarded unchanged.
        sparse_bt: Compacted block table; its width bounds the KV length.
        sparse_ctx: Effective context length per row.
        sm_scale: Softmax scale.
        num_seqs: Number of length-1 query sequences.

    Returns:
        None. The result is written into ``out_view``.
    """
    from aiter.ops.triton.unified_attention import unified_attention

    # Page granularity comes from the SHUFFLE key cache layout:
    #   [num_blocks, num_kv_heads, head_size // x, block_size, x]
    page_size = int(k_cache_view.shape[3])
    table_width = int(sparse_bt.shape[1])

    # One query token per sequence, so the cumulative offsets are 0..num_seqs.
    query_offsets = torch.arange(
        num_seqs + 1, dtype=torch.int32, device=q_view.device
    )

    attn_kwargs = dict(
        cu_seqlens_q=query_offsets,
        max_seqlen_q=1,
        seqused_k=sparse_ctx,
        # Conservative bound: every page of the table fully populated, which
        # is >= any individual entry of sparse_ctx.
        max_seqlen_k=table_width * page_size,
        softmax_scale=sm_scale,
        causal=True,
        window_size=(-1, -1),
        block_table=sparse_bt,
        softcap=0,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        sinks=None,
        shuffled_kv_cache=True,
    )
    unified_attention(q_view, k_cache_view, v_cache_view, out_view, **attn_kwargs)


# ---------------------------------------------------------------------------
# GQA block-sparse attention. BLOCK_SIZE_K == 128, matching one selected block.
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -1177,6 +1234,29 @@ def minimax_m3_sparse_attn_decode_asm(
v_cache_view = v_cache.view(nph16 * _hkv, 1, *v_cache.shape[2:])

num_seqs = T * num_kv_heads

# gfx1250 (MI455): no gluon pa_decode kernel (gluon supports gfx942 / gfx950
# only). Route through the triton unified_attention sparse fallback over the
# same SHUFFLE cache + compacted sparse block table.
if get_gfx() == "gfx1250":
if _is_fp8_kv_cache_tensor(k_cache):
raise NotImplementedError(
"MiniMax-M3 fp8 sparse decode is not yet supported on gfx1250 "
"(MI455): the gluon per-page descale path has no unified_attention "
"equivalent here. Use a bf16 KV cache on gfx1250."
)
_sparse_decode_unified_attention(
q_view,
out_view,
k_cache_view,
v_cache_view,
sparse_bt,
sparse_ctx,
sm_scale,
num_seqs,
)
return

num_kv_heads_view = 1
query_group_size = g
max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads_view)
Expand Down Expand Up @@ -1271,6 +1351,29 @@ def _run_prefill_fp8_gluon(
v_cache_view = v_cache.view(nph16 * _hkv, 1, *v_cache.shape[2:])

num_seqs = T * num_kv_heads

# gfx1250 (MI455): no gluon pa_decode kernel (gluon supports gfx942 / gfx950
# only). Route through the triton unified_attention sparse fallback over the
# same SHUFFLE cache + compacted sparse block table.
if get_gfx() == "gfx1250":
if _is_fp8_kv_cache_tensor(k_cache):
raise NotImplementedError(
"MiniMax-M3 fp8 sparse decode is not yet supported on gfx1250 "
"(MI455): the gluon per-page descale path has no unified_attention "
"equivalent here. Use a bf16 KV cache on gfx1250."
)
Comment on lines +1360 to +1364
_sparse_decode_unified_attention(
q_view,
out_view,
k_cache_view,
v_cache_view,
sparse_bt,
sparse_ctx,
sm_scale,
num_seqs,
)
return

num_kv_heads_view = 1
query_group_size = g
max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads_view)
Expand Down
144 changes: 124 additions & 20 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,70 @@ def rocm_aiter_fused_moe_fake(
)


def _interleave_gate_up_rows_(layer: torch.nn.Module) -> None:
    """Reorder w13 gate/up rows from SEPARATED (gguu) to INTERLEAVED (gugu).

    Operates in place on ``layer.w13_weight``, ``layer.w13_weight_scale`` and
    ``layer.w13_bias`` (if present) along their row axis (dim=1), which is
    ``2 * intermediate_size`` laid out as ``[gate(0..I-1) | up(0..I-1)]``. The
    new order is ``[gate0, up0, gate1, up1, ...]`` so that a downstream
    consumer splitting even/odd rows reads matching gate/up pairs. ``w2``
    (down_proj) has no gate/up split and is untouched.

    The reorder moves whole rows only, so bytes on the trailing axes are never
    split. FP4 tensors (``torch.float4_e2m1fn_x2``, when this torch build has
    it) are permuted through a same-storage ``uint8`` view; every other dtype
    is permuted directly, because reinterpreting e.g. a bf16 bias as uint8
    would change its shape and corrupt values.

    Memory: the permutation is applied per expert, writing back into the
    existing storage, so the only transient is one expert's rows rather than
    a second copy of the whole tensor.

    Robustness: all tensors are validated *before* any of them is modified.
    A malformed scale/bias therefore cannot leave ``w13_weight`` already
    interleaved while the idempotency flag is still unset (which would make a
    retry interleave the weight a second time).

    Idempotency: sets ``layer._w13_gate_up_interleaved`` on success so that a
    second call is a no-op.

    Args:
        layer: Module (or any object) exposing ``w13_weight`` and
            ``w13_weight_scale`` tensors of shape ``[E, 2I, ...]`` and
            optionally ``w13_bias``.

    Raises:
        ValueError: If a target tensor has fewer than 2 dims or an odd row
            count. Nothing is modified in that case.
    """
    if getattr(layer, "_w13_gate_up_interleaved", False):
        return

    # Resolved once; None on torch builds that do not define the FP4 dtype.
    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

    def _row_buffer(name: str, t: torch.Tensor) -> torch.Tensor:
        """Return the tensor to permute for ``t`` after validating its shape."""
        if fp4_dtype is not None and t.dtype == fp4_dtype:
            # Same element size, same storage: a bit-exact byte view.
            buf = t.view(torch.uint8)
        else:
            buf = t
        if buf.dim() < 2:
            raise ValueError(
                f"{name} must have shape [E, 2I, ...], got {tuple(buf.shape)}"
            )
        if buf.shape[1] % 2 != 0:
            raise ValueError(f"{name} row dim {buf.shape[1]} not even")
        return buf

    def _interleave_inplace(buf: torch.Tensor) -> None:
        """Permute rows [g..|u..] -> [g0,u0,g1,u1,...] one expert at a time."""
        two_i = buf.shape[1]
        half = two_i // 2
        rest = buf.shape[2:]
        for e in range(buf.shape[0]):
            rows = buf[e]  # view into storage, shape (2I, *rest)
            # (2, I, *rest) -> (I, 2, *rest) is the gugu order; reshape
            # materialises it as one single-expert temporary.
            gugu = rows.view(2, half, *rest).transpose(0, 1).reshape(two_i, *rest)
            rows.copy_(gugu)  # write back into the same storage

    targets = [
        ("w13_weight", layer.w13_weight.data),
        ("w13_weight_scale", layer.w13_weight_scale.data),
    ]
    bias = getattr(layer, "w13_bias", None)
    if bias is not None:
        targets.append(("w13_bias", bias.data))

    # Validate everything first, mutate second (all-or-nothing on bad shapes).
    buffers = [_row_buffer(name, t) for name, t in targets]
    for buf in buffers:
        _interleave_inplace(buf)

    layer._w13_gate_up_interleaved = True


class Mxfp4MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig):
super().__init__(moe)
Expand Down Expand Up @@ -963,27 +1027,52 @@ def process_weights_after_loading(self, layer):
# the MoE-kernel-only CDNA4 layout.
n_shared = layer.num_fused_shared_experts
if n_shared > 0:
# IMPORTANT: .clone() (not .contiguous()). A uint8 view of a
# contiguous slice is already contiguous, so .contiguous() returns
# a tensor that SHARES storage with w13_weight. The gguu->gugu
# interleave below mutates w13_weight in place, which would then
# corrupt these stashed shared weights (consumed by the gguu
# half-split swiglu_oai_split). Clone to fully detach.
layer.shared_w13_weight = (
layer.w13_weight.data[-n_shared:].view(torch.uint8).contiguous()
layer.w13_weight.data[-n_shared:].view(torch.uint8).clone()
)
layer.shared_w13_weight_scale = layer.w13_weight_scale.data[
-n_shared:
].contiguous()
].clone()
layer.shared_w2_weight = (
layer.w2_weight.data[-n_shared:].view(torch.uint8).contiguous()
layer.w2_weight.data[-n_shared:].view(torch.uint8).clone()
)
layer.shared_w2_weight_scale = layer.w2_weight_scale.data[
-n_shared:
].contiguous()
].clone()
if layer.w13_bias is not None:
layer.shared_w13_bias = layer.w13_bias.data[-n_shared:].contiguous()
layer.shared_w13_bias = layer.w13_bias.data[-n_shared:].clone()
else:
layer.shared_w13_bias = None
if layer.w2_bias is not None:
layer.shared_w2_bias = layer.w2_bias.data[-n_shared:].contiguous()
layer.shared_w2_bias = layer.w2_bias.data[-n_shared:].clone()
else:
layer.shared_w2_bias = None

# gguu -> gugu interleave for the routed triton SwiGLU kernel.
#
# The checkpoint stores w13 gate/up rows SEPARATED ("gguu":
# [all-gate | all-up]). The aiter default CK fused_moe path consumes
# that directly. The triton a16w4 SwiGLU kernel, however, fuses the
# activation into the GEMM epilogue and splits each tile as
# INTERLEAVED ("gugu": a[...,::2]=gate, a[...,1::2]=up). Feeding gguu
# weights to that kernel mixes gate with gate and misreads up,
# corrupting the output. We interleave the routed w13 rows once here
# so the existing (well-tested) interleaved kernel produces results
# identical to the default path. w2 (down_proj) has no gate/up split.
#
# NOTE: this runs AFTER the shared-expert stash above, which keeps the
# shared experts in gguu for the dense swiglu_oai_split path. Only the
# SwiGLU triton branch needs interleaved rows; the SiLU branch uses
# fused_clamp_act_mul, which half-splits gguu — so gate on activation.
if getattr(layer, "activation", None) == ActivationType.Swiglu:
_interleave_gate_up_rows_(layer)

(
w13_weight,
w13_scale,
Expand Down Expand Up @@ -1297,15 +1386,21 @@ def _apply_shared_experts_dense(self, layer, x, activation):
from aiter.ops.triton.fusions.fused_clamp_act_mul import fused_clamp_act_mul
from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4

# The dense shared-expert GEMM only implements the SiLU activation
# path; SwiGLU models have no fused shared experts, so this assert
# documents the supported scope.
assert (
activation != ActivationType.Swiglu
), "dense shared-expert GEMM only supports the SiLU activation path"
from atom.model_ops.swiglu_oai import swiglu_oai_split

# Two activation flavours are supported, matching the routed experts:
# * SiLU (DeepSeek): silu(gate) * clamp(up), split [gate | up] layout.
# * SwiGLU-OAI (MiniMax-M3 / gpt-oss): gate * sigmoid(alpha*gate) *
# (up + beta) with optional clamp. MiniMax-M3 does not interleave
# gate/up, so the dense GEMM output is split [gate | up] - exactly
# what swiglu_oai_split consumes (mirrors MiniMaxM3MLP.forward and
# the routed swiglu_add_residual=True / alpha path).
is_swiglu = activation == ActivationType.Swiglu

M = x.shape[0]
swiglu_limit = getattr(layer, "swiglu_limit", 0.0)
swiglu_alpha = getattr(layer, "swiglu_alpha", 1.702)
swiglu_beta = getattr(layer, "swiglu_beta", 1.0)

use_a4w4 = self.act_quant == MoEActivationQuant.FP4
if use_a4w4:
Expand All @@ -1331,14 +1426,23 @@ def _shared_expert_gemm(act, weight, weight_scale):
if shared_w13_bias is not None:
gate_up = gate_up + shared_w13_bias[e]
half_n = gate_up.shape[-1] // 2
intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype)
fused_clamp_act_mul(
gate_up,
out=intermediate,
swiglu_limit=swiglu_limit,
activation="silu",
dtype_quant=None,
)
if is_swiglu:
intermediate = swiglu_oai_split(
gate_up,
alpha=swiglu_alpha,
beta=swiglu_beta,
limit=swiglu_limit if swiglu_limit > 0 else None,
out_dtype=x.dtype,
)
else:
intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype)
fused_clamp_act_mul(
gate_up,
out=intermediate,
swiglu_limit=swiglu_limit,
activation="silu",
dtype_quant=None,
)
out_e = _shared_expert_gemm(
intermediate,
layer.shared_w2_weight[e],
Expand Down
6 changes: 6 additions & 0 deletions atom/models/minimax_m3.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ def __init__(
# padded intermediate avoids backend pad-skip precision issues.
self.experts.quant_method.intermediate_pad = 0
self.experts.swiglu_limit = getattr(config, "swiglu_limit", 7.0)
# SwiGLU-OAI params for the standalone dense shared-expert GEMM
# (Mxfp4MoEMethod._apply_shared_experts_dense). The routed experts use
# alpha (default 1.702) and swiglu_add_residual=True (beta == 1.0); the
# dense shared experts must match.
self.experts.swiglu_alpha = getattr(config, "swiglu_alpha", 1.702)
self.experts.swiglu_beta = getattr(config, "swiglu_beta", 1.0)
self.fuse_shared_experts = (
getattr(self.experts, "num_fused_shared_experts", 0) > 0
)
Expand Down
Loading