def _sparse_decode_unified_attention(
    q_view: torch.Tensor,
    out_view: torch.Tensor,
    k_cache_view: torch.Tensor,
    v_cache_view: torch.Tensor,
    sparse_bt: torch.Tensor,
    sparse_ctx: torch.Tensor,
    sm_scale: float,
    num_seqs: int,
) -> None:
    """Run sparse per-token decode through triton ``unified_attention``.

    Intended as the gfx1250 (MI455) replacement for the gluon
    ``pa_decode_gluon`` launch, which per the original author's notes exists
    only for gfx942 / gfx950. Every query row is handed to the kernel as its
    own length-1 causal sequence, and the result is written into ``out_view``;
    nothing is returned.

    Args:
        q_view: Queries; documented by the caller as
            ``[num_seqs, gqa_group, head_dim]`` with the kv-head axis
            collapsed.
        out_view: Output buffer, documented with the same layout as
            ``q_view``.
        k_cache_view: Key cache in the SHUFFLE layout. Dimension 3 is read as
            the page (block) size, following the documented layout
            ``[num_blocks, num_kv_heads, head_size // x, block_size, x]``.
        v_cache_view: Value cache passed alongside ``k_cache_view``.
        sparse_bt: Compacted block table, documented as
            ``[num_seqs, max_pages]``; dimension 1 is used as the page count.
        sparse_ctx: Per-row effective context length, forwarded as
            ``seqused_k``.
        sm_scale: Softmax scale forwarded as ``softmax_scale``.
        num_seqs: Number of length-1 query sequences.

    Note:
        No descale tensors are forwarded (``q_descale`` / ``k_descale`` /
        ``v_descale`` are all ``None``), so this helper does not cover the fp8
        KV-cache case. The original notes say callers reject fp8 on gfx1250
        before reaching this function.
    """
    # Imported lazily so the triton kernel module is only loaded on the
    # fallback path.
    from aiter.ops.triton.unified_attention import unified_attention

    # Upper bound on any row's KV length: table width times page size. The
    # original notes describe this as a safe bound (>= every sparse_ctx).
    page_size = k_cache_view.shape[3]
    max_pages = sparse_bt.shape[1]
    kv_len_bound = int(max_pages) * int(page_size)

    # One query token per sequence, so the cumulative offsets are simply
    # 0, 1, ..., num_seqs.
    q_offsets = torch.arange(
        num_seqs + 1,
        dtype=torch.int32,
        device=q_view.device,
    )

    unified_attention(
        q_view,
        k_cache_view,
        v_cache_view,
        out_view,
        block_table=sparse_bt,
        seqused_k=sparse_ctx,
        cu_seqlens_q=q_offsets,
        max_seqlen_q=1,
        max_seqlen_k=kv_len_bound,
        softmax_scale=sm_scale,
        causal=True,
        window_size=(-1, -1),
        softcap=0,
        sinks=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        shuffled_kv_cache=True,
    )
def _interleave_gate_up_rows_(layer: torch.nn.Module) -> None:
    """Reorder w13 gate/up rows from SEPARATED (gguu) to INTERLEAVED (gugu).

    Operates in place on ``layer.w13_weight``, ``layer.w13_weight_scale`` and
    ``layer.w13_bias`` (when present and not ``None``). Dimension 1 of each
    tensor is treated as ``2 * I`` rows laid out as
    ``[gate(0..I-1) | up(0..I-1)]`` and is rewritten to
    ``[gate0, up0, gate1, up1, ...]``, so a consumer that splits even/odd rows
    reads matching gate/up pairs. ``w2`` is never touched.

    Only whole rows move; trailing dimensions are carried along unchanged.
    Tensors whose dtype is ``torch.float4_e2m1fn_x2`` (when this torch build
    defines it) are permuted through a ``uint8`` view of the same storage, so
    the packed bytes are moved verbatim. Every other dtype is permuted
    directly, because reinterpreting those as ``uint8`` would corrupt values.

    Memory: the permutation runs one expert at a time and writes back into the
    existing storage, so the only transient allocation is a single expert's
    rows rather than a second copy of the whole tensor.

    Atomicity: all tensors are resolved and validated *before* any of them is
    modified. Without this, a bad scale or bias shape would be detected only
    after ``w13_weight`` had already been permuted, leaving the flag unset and
    making a retry interleave the weight a second time.

    Idempotency: sets ``layer._w13_gate_up_interleaved`` on success; a second
    call on the same layer is a no-op.

    Args:
        layer: Object exposing ``w13_weight`` and ``w13_weight_scale`` (and
            optionally ``w13_bias``), each with experts on dim 0 and the
            ``2 * I`` gate/up rows on dim 1.

    Raises:
        AssertionError: If dim 1 of any tensor is odd. Raised explicitly
            rather than via ``assert`` so the check survives ``python -O``;
            no tensor has been modified when this is raised.
    """
    if getattr(layer, "_w13_gate_up_interleaved", False):
        return

    fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

    def _reorderable(t: torch.Tensor) -> torch.Tensor:
        """Return a storage-sharing tensor whose rows can be permuted."""
        if fp4_dtype is not None and t.dtype == fp4_dtype:
            # Byte view of the packed FP4 data: same storage, exact row moves.
            return t.view(torch.uint8)
        return t

    def _interleave_inplace(buf: torch.Tensor) -> None:
        """Rewrite each expert's rows from [g..|u..] to [g0,u0,g1,u1,...]."""
        num_experts, two_i = buf.shape[0], buf.shape[1]
        half = two_i // 2
        rest = buf.shape[2:]
        for e in range(num_experts):
            rows = buf[e]  # view into storage, shape (2I, *rest)
            # (2, I, *rest) -> (I, 2, *rest) is the gugu order. Flattening the
            # transposed view materializes one single-expert temporary, which
            # is then copied back over the same storage.
            gugu = rows.view(2, half, *rest).transpose(0, 1).reshape(two_i, *rest)
            rows.copy_(gugu)

    # Phase 1: gather every buffer that will be permuted.
    buffers = [
        _reorderable(layer.w13_weight.data),
        _reorderable(layer.w13_weight_scale.data),
    ]
    bias = getattr(layer, "w13_bias", None)
    if bias is not None:
        buffers.append(_reorderable(bias.data))

    # Phase 2: validate all of them before mutating anything.
    for buf in buffers:
        two_i = buf.shape[1]
        if two_i % 2 != 0:
            raise AssertionError(f"w13 row dim {two_i} not even")

    # Phase 3: permute, then record success.
    for buf in buffers:
        _interleave_inplace(buf)
    layer._w13_gate_up_interleaved = True
n_shared = layer.num_fused_shared_experts if n_shared > 0: + # IMPORTANT: .clone() (not .contiguous()). A uint8 view of a + # contiguous slice is already contiguous, so .contiguous() returns + # a tensor that SHARES storage with w13_weight. The gguu->gugu + # interleave below mutates w13_weight in place, which would then + # corrupt these stashed shared weights (consumed by the gguu + # half-split swiglu_oai_split). Clone to fully detach. layer.shared_w13_weight = ( - layer.w13_weight.data[-n_shared:].view(torch.uint8).contiguous() + layer.w13_weight.data[-n_shared:].view(torch.uint8).clone() ) layer.shared_w13_weight_scale = layer.w13_weight_scale.data[ -n_shared: - ].contiguous() + ].clone() layer.shared_w2_weight = ( - layer.w2_weight.data[-n_shared:].view(torch.uint8).contiguous() + layer.w2_weight.data[-n_shared:].view(torch.uint8).clone() ) layer.shared_w2_weight_scale = layer.w2_weight_scale.data[ -n_shared: - ].contiguous() + ].clone() if layer.w13_bias is not None: - layer.shared_w13_bias = layer.w13_bias.data[-n_shared:].contiguous() + layer.shared_w13_bias = layer.w13_bias.data[-n_shared:].clone() else: layer.shared_w13_bias = None if layer.w2_bias is not None: - layer.shared_w2_bias = layer.w2_bias.data[-n_shared:].contiguous() + layer.shared_w2_bias = layer.w2_bias.data[-n_shared:].clone() else: layer.shared_w2_bias = None + # gguu -> gugu interleave for the routed triton SwiGLU kernel. + # + # The checkpoint stores w13 gate/up rows SEPARATED ("gguu": + # [all-gate | all-up]). The aiter default CK fused_moe path consumes + # that directly. The triton a16w4 SwiGLU kernel, however, fuses the + # activation into the GEMM epilogue and splits each tile as + # INTERLEAVED ("gugu": a[...,::2]=gate, a[...,1::2]=up). Feeding gguu + # weights to that kernel mixes gate with gate and misreads up, + # corrupting the output. 
We interleave the routed w13 rows once here + # so the existing (well-tested) interleaved kernel produces results + # identical to the default path. w2 (down_proj) has no gate/up split. + # + # NOTE: this runs AFTER the shared-expert stash above, which keeps the + # shared experts in gguu for the dense swiglu_oai_split path. Only the + # SwiGLU triton branch needs interleaved rows; the SiLU branch uses + # fused_clamp_act_mul, which half-splits gguu — so gate on activation. + if getattr(layer, "activation", None) == ActivationType.Swiglu: + _interleave_gate_up_rows_(layer) + ( w13_weight, w13_scale, @@ -1297,15 +1386,21 @@ def _apply_shared_experts_dense(self, layer, x, activation): from aiter.ops.triton.fusions.fused_clamp_act_mul import fused_clamp_act_mul from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4 - # The dense shared-expert GEMM only implements the SiLU activation - # path; SwiGLU models have no fused shared experts, so this assert - # documents the supported scope. - assert ( - activation != ActivationType.Swiglu - ), "dense shared-expert GEMM only supports the SiLU activation path" + from atom.model_ops.swiglu_oai import swiglu_oai_split + + # Two activation flavours are supported, matching the routed experts: + # * SiLU (DeepSeek): silu(gate) * clamp(up), split [gate | up] layout. + # * SwiGLU-OAI (MiniMax-M3 / gpt-oss): gate * sigmoid(alpha*gate) * + # (up + beta) with optional clamp. MiniMax-M3 does not interleave + # gate/up, so the dense GEMM output is split [gate | up] - exactly + # what swiglu_oai_split consumes (mirrors MiniMaxM3MLP.forward and + # the routed swiglu_add_residual=True / alpha path). 
+ is_swiglu = activation == ActivationType.Swiglu M = x.shape[0] swiglu_limit = getattr(layer, "swiglu_limit", 0.0) + swiglu_alpha = getattr(layer, "swiglu_alpha", 1.702) + swiglu_beta = getattr(layer, "swiglu_beta", 1.0) use_a4w4 = self.act_quant == MoEActivationQuant.FP4 if use_a4w4: @@ -1331,14 +1426,23 @@ def _shared_expert_gemm(act, weight, weight_scale): if shared_w13_bias is not None: gate_up = gate_up + shared_w13_bias[e] half_n = gate_up.shape[-1] // 2 - intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype) - fused_clamp_act_mul( - gate_up, - out=intermediate, - swiglu_limit=swiglu_limit, - activation="silu", - dtype_quant=None, - ) + if is_swiglu: + intermediate = swiglu_oai_split( + gate_up, + alpha=swiglu_alpha, + beta=swiglu_beta, + limit=swiglu_limit if swiglu_limit > 0 else None, + out_dtype=x.dtype, + ) + else: + intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype) + fused_clamp_act_mul( + gate_up, + out=intermediate, + swiglu_limit=swiglu_limit, + activation="silu", + dtype_quant=None, + ) out_e = _shared_expert_gemm( intermediate, layer.shared_w2_weight[e], diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index fd0ddef08..72c9fdc86 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -281,6 +281,12 @@ def __init__( # padded intermediate avoids backend pad-skip precision issues. self.experts.quant_method.intermediate_pad = 0 self.experts.swiglu_limit = getattr(config, "swiglu_limit", 7.0) + # SwiGLU-OAI params for the standalone dense shared-expert GEMM + # (Mxfp4MoEMethod._apply_shared_experts_dense). The routed experts use + # alpha (default 1.702) and swiglu_add_residual=True (beta == 1.0); the + # dense shared experts must match. + self.experts.swiglu_alpha = getattr(config, "swiglu_alpha", 1.702) + self.experts.swiglu_beta = getattr(config, "swiglu_beta", 1.0) self.fuse_shared_experts = ( getattr(self.experts, "num_fused_shared_experts", 0) > 0 )