From 8ea7d9e470181a0af0704c51ab84dccf4b30b986 Mon Sep 17 00:00:00 2001 From: Leon Ling Date: Thu, 25 Jun 2026 02:24:23 +0000 Subject: [PATCH 1/8] Support SwiGLU-OAI in dense shared-expert GEMM (MiniMax-M3) Mxfp4MoEMethod._apply_shared_experts_dense hard-asserted the SiLU activation path, so MiniMax-M3 (ActivationType.Swiglu with fused shared experts) crashed with "dense shared-expert GEMM only supports the SiLU activation path". MiniMax-M3 uses SwiGLU-OAI (gate*sigmoid(alpha*gate)*(up+beta)) and does not interleave gate/up, so the dense GEMM output is split [gate|up] -- exactly what swiglu_oai_split consumes. The dense shared expert now mirrors MiniMaxM3MLP.forward and the routed experts' alpha / swiglu_add_residual=True path. - moe.py: drop the assert; branch the activation step. SiLU keeps fused_clamp_act_mul (DeepSeek, unchanged); SwiGLU uses swiglu_oai_split with alpha/beta/limit read from the layer. - minimax_m3.py: stash swiglu_alpha/swiglu_beta on self.experts (from config, defaults 1.702/1.0) next to swiglu_limit. - tests: numerical test that reuses the same kernel GEMM and varies only the activation, isolating the fix from mxfp4/bf16 GEMM precision. Verified on MI350X: fixed path matches the SwiGLU-OAI reference exactly, old SiLU behaviour diverged ~20%. Co-Authored-By: Claude Opus 4.8 (1M context) --- atom/model_ops/moe.py | 45 +++-- atom/models/minimax_m3.py | 6 + tests/test_mxfp4_shared_experts_swiglu.py | 222 ++++++++++++++++++++++ 3 files changed, 259 insertions(+), 14 deletions(-) create mode 100644 tests/test_mxfp4_shared_experts_swiglu.py diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 12b848e51..d4f9d2925 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1297,15 +1297,21 @@ def _apply_shared_experts_dense(self, layer, x, activation): from aiter.ops.triton.fusions.fused_clamp_act_mul import fused_clamp_act_mul from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4 - # The dense shared-expert GEMM only implements the SiLU activation - # path; SwiGLU models have no fused shared experts, so this assert - # documents the supported scope. - assert ( - activation != ActivationType.Swiglu - ), "dense shared-expert GEMM only supports the SiLU activation path" + from atom.model_ops.swiglu_oai import swiglu_oai_split + + # Two activation flavours are supported, matching the routed experts: + # * SiLU (DeepSeek): silu(gate) * clamp(up), split [gate | up] layout. + # * SwiGLU-OAI (MiniMax-M3 / gpt-oss): gate * sigmoid(alpha*gate) * + # (up + beta) with optional clamp. MiniMax-M3 does not interleave + # gate/up, so the dense GEMM output is split [gate | up] - exactly + # what swiglu_oai_split consumes (mirrors MiniMaxM3MLP.forward and + # the routed swiglu_add_residual=True / alpha path). + is_swiglu = activation == ActivationType.Swiglu M = x.shape[0] swiglu_limit = getattr(layer, "swiglu_limit", 0.0) + swiglu_alpha = getattr(layer, "swiglu_alpha", 1.702) + swiglu_beta = getattr(layer, "swiglu_beta", 1.0) use_a4w4 = self.act_quant == MoEActivationQuant.FP4 if use_a4w4: @@ -1331,14 +1337,25 @@ def _shared_expert_gemm(act, weight, weight_scale): if shared_w13_bias is not None: gate_up = gate_up + shared_w13_bias[e] half_n = gate_up.shape[-1] // 2 - intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype) - fused_clamp_act_mul( - gate_up, - out=intermediate, - swiglu_limit=swiglu_limit, - activation="silu", - dtype_quant=None, - ) + if is_swiglu: + intermediate = swiglu_oai_split( + gate_up, + alpha=swiglu_alpha, + beta=swiglu_beta, + limit=swiglu_limit if swiglu_limit > 0 else None, + out_dtype=x.dtype, + ) + else: + intermediate = torch.empty( + (M, half_n), device=x.device, dtype=x.dtype + ) + fused_clamp_act_mul( + gate_up, + out=intermediate, + swiglu_limit=swiglu_limit, + activation="silu", + dtype_quant=None, + ) out_e = _shared_expert_gemm( intermediate, layer.shared_w2_weight[e], diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index fd0ddef08..72c9fdc86 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -281,6 +281,12 @@ def __init__( # padded intermediate avoids backend pad-skip precision issues. self.experts.quant_method.intermediate_pad = 0 self.experts.swiglu_limit = getattr(config, "swiglu_limit", 7.0) + # SwiGLU-OAI params for the standalone dense shared-expert GEMM + # (Mxfp4MoEMethod._apply_shared_experts_dense). The routed experts use + # alpha (default 1.702) and swiglu_add_residual=True (beta == 1.0); the + # dense shared experts must match. + self.experts.swiglu_alpha = getattr(config, "swiglu_alpha", 1.702) + self.experts.swiglu_beta = getattr(config, "swiglu_beta", 1.0) self.fuse_shared_experts = ( getattr(self.experts, "num_fused_shared_experts", 0) > 0 ) diff --git a/tests/test_mxfp4_shared_experts_swiglu.py b/tests/test_mxfp4_shared_experts_swiglu.py new file mode 100644 index 000000000..dd8c8f592 --- /dev/null +++ b/tests/test_mxfp4_shared_experts_swiglu.py @@ -0,0 +1,222 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Numerical test for the dense shared-expert path with SwiGLU-OAI activation. + +Regression: ``Mxfp4MoEMethod._apply_shared_experts_dense`` hard-asserted the +SiLU activation path, so MiniMax-M3 (``ActivationType.Swiglu`` *with* fused +shared experts) crashed with:: + + AssertionError: dense shared-expert GEMM only supports the SiLU activation path + +MiniMax-M3 does not interleave gate/up weights, so the dense GEMM output is +split ``[gate | up]`` — exactly what ``swiglu_oai_split`` consumes. The dense +shared expert must therefore replicate ``MiniMaxM3MLP.forward``: +gate_up GEMM -> swiglu_oai_split -> down GEMM, with the SwiGLU-OAI math +``gate * sigmoid(alpha*gate) * (up + beta)`` (alpha=1.702, beta=1.0), not SiLU. + +These tests run the *real* fixed code path (gemm_a16wfp4 + the activation +branch). The fix only changes the *activation*, so the reference reuses the +*same* ``gemm_a16wfp4`` for both matmuls and differs only in the activation +math (computed independently in plain torch). This isolates the activation and +avoids conflating it with the kernel's mxfp4/bf16 GEMM precision. The tests +prove: + * the SwiGLU branch matches the SwiGLU-OAI reference, and + * it is genuinely different from the SiLU reference (i.e. the fix changed + behaviour, it is not silently equivalent), and + * the SiLU branch (DeepSeek) is unchanged. +""" + +import types + +import pytest +import torch + +cuda_only = pytest.mark.skipif( + not torch.cuda.is_available(), reason="requires an AMD GPU" +) + +SCALE_GROUP_SIZE = 32 + +# e2m1 (fp4) decode table: sign | 2-bit exp | 1-bit mantissa. +_MXFP4_TABLE = [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +] + + +def _mxfp4_to_f32(packed: torch.Tensor) -> torch.Tensor: + """Decode a uint8 tensor of two packed e2m1 nibbles to f32 (last dim x2).""" + x = packed.repeat_interleave(2, dim=-1) + x[..., ::2] = x[..., ::2] & 0xF + x[..., 1::2] = x[..., 1::2] >> 4 + table = torch.tensor(_MXFP4_TABLE, dtype=torch.float32, device=packed.device) + return table[x.long()] + + +def _e8m0_to_f32(scale: torch.Tensor) -> torch.Tensor: + return 2.0 ** (scale.to(torch.float32) - 127) + + +def _make_weight(n: int, k: int, *, seed: int): + """Random fp4-packed weight (n, k//2) + e8m0 scales (n, k//32). + + The bf16 dequantization is not returned: the reference reuses the kernel's + own GEMM, so we never need a from-scratch dequant matmul. + """ + g = torch.Generator(device="cuda").manual_seed(seed) + low = torch.randint(0, 16, (n, k // 2), dtype=torch.uint8, device="cuda", generator=g) + high = torch.randint(0, 16, (n, k // 2), dtype=torch.uint8, device="cuda", generator=g) + packed = low | (high << 4) + # e8m0 scales near 1.0 (bias 127) keep the dequant range sane. + scales = torch.randint( + 125, 130, (n, k // SCALE_GROUP_SIZE), dtype=torch.uint8, device="cuda", generator=g + ) + return packed, scales + + +def _ref_swiglu_oai(gate_up, alpha, beta, limit): + n = gate_up.shape[-1] // 2 + gate = gate_up[:, :n].to(torch.float32) + up = gate_up[:, n:].to(torch.float32) + if limit is not None: + gate = torch.clamp(gate, max=limit) + up = torch.clamp(up, min=-limit, max=limit) + return (gate * torch.sigmoid(alpha * gate) * (up + beta)).to(gate_up.dtype) + + +def _ref_silu(gate_up, limit): + n = gate_up.shape[-1] // 2 + gate = gate_up[:, :n].to(torch.float32) + up = gate_up[:, n:].to(torch.float32) + if limit > 0: + gate = torch.clamp(gate, max=limit) + up = torch.clamp(up, min=-limit, max=limit) + return (gate * torch.sigmoid(gate) * up).to(gate_up.dtype) + + +def _build_method_and_layer(hidden, inter, *, alpha, beta, limit): + from atom.config import LayerQuantConfig + from atom.model_ops.moe import Mxfp4MoEMethod, MoEActivationQuant + from aiter import QuantType + from unittest.mock import MagicMock + + qc = LayerQuantConfig( + quant_type=QuantType.per_1x32, + quant_dtype=torch.float4_e2m1fn_x2, + quant_method="quark", + ) + method = Mxfp4MoEMethod(qc, MagicMock()) + # Exercise the a16w4 (bf16 activation) path so we feed plain bf16 inputs. + method.act_quant = MoEActivationQuant.BF16 + + w13, s13 = _make_weight(2 * inter, hidden, seed=1) # (2I, H) + w2, s2 = _make_weight(hidden, inter, seed=2) # (H, I) + + layer = types.SimpleNamespace( + num_fused_shared_experts=1, + shared_w13_weight=w13.unsqueeze(0), + shared_w13_weight_scale=s13.unsqueeze(0), + shared_w2_weight=w2.unsqueeze(0), + shared_w2_weight_scale=s2.unsqueeze(0), + shared_w13_bias=None, + shared_w2_bias=None, + swiglu_limit=limit, + swiglu_alpha=alpha, + swiglu_beta=beta, + ) + return method, layer, (w13, s13), (w2, s2) + + +def _kernel_gemm(act, packed_scale): + """Reuse the exact same GEMM the dense path uses, so the only difference + between the dense path and the reference is the activation.""" + from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4 + + weight, scale = packed_scale + return gemm_a16wfp4(act, weight, scale, dtype=torch.bfloat16) + + +@cuda_only +def test_swiglu_shared_expert_matches_reference(): + from aiter import ActivationType + from atom.model_ops.moe import Mxfp4MoEMethod + + if not _fp4_available(): + pytest.skip("MXFP4 not supported on this architecture") + + hidden, inter, M = 256, 256, 64 + alpha, beta, limit = 1.702, 1.0, 7.0 + method, layer, w13, w2 = _build_method_and_layer( + hidden, inter, alpha=alpha, beta=beta, limit=limit + ) + + torch.manual_seed(0) + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") * 0.5 + + out = Mxfp4MoEMethod._apply_shared_experts_dense( + method, layer, x, ActivationType.Swiglu + ) + + # Reference reuses the SAME kernel GEMM; only the activation is computed + # independently (plain torch), isolating the fix from GEMM precision. + gate_up = _kernel_gemm(x, w13) + inter_ref = _ref_swiglu_oai(gate_up, alpha, beta, limit) + out_ref = _kernel_gemm(inter_ref, w2) + + # SiLU on the same gate_up must be clearly different (proves the branch + # matters and the fix is not silently equivalent to the old code). + inter_silu = _ref_silu(gate_up, limit) + out_silu = _kernel_gemm(inter_silu, w2) + + err_swiglu = (out.float() - out_ref.float()).abs().mean().item() + err_vs_silu = (out_ref.float() - out_silu.float()).abs().mean().item() + + torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-2, atol=1e-2) + assert err_vs_silu > 10 * max(err_swiglu, 1e-6), ( + f"swiglu vs silu too close to distinguish " + f"(err_swiglu={err_swiglu}, err_vs_silu={err_vs_silu})" + ) + + +@cuda_only +def test_silu_shared_expert_unchanged(): + from aiter import ActivationType + from atom.model_ops.moe import Mxfp4MoEMethod + + if not _fp4_available(): + pytest.skip("MXFP4 not supported on this architecture") + + hidden, inter, M = 256, 256, 64 + limit = 7.0 + method, layer, w13, w2 = _build_method_and_layer( + hidden, inter, alpha=1.702, beta=1.0, limit=limit + ) + + torch.manual_seed(0) + x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") * 0.5 + + out = Mxfp4MoEMethod._apply_shared_experts_dense( + method, layer, x, ActivationType.Silu + ) + + gate_up = _kernel_gemm(x, w13) + inter_ref = _ref_silu(gate_up, limit) + out_ref = _kernel_gemm(inter_ref, w2) + + torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-2, atol=1e-2) + + +def _fp4_available(): + try: + import aiter.ops.triton.utils._triton.arch_info as arch_info + + return arch_info.is_fp4_avail() + except Exception: + return torch.cuda.is_available() + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"])) From 5b89f2ebb1d9ac6d916220636610fbde2aa0de56 Mon Sep 17 00:00:00 2001 From: Leon Ling Date: Thu, 25 Jun 2026 06:40:00 +0000 Subject: [PATCH 2/8] Route MiniMax-M3 decode to unified_attention on gfx1250 (MI455) gfx1250 has no pa_decode_gluon kernel (gluon supports gfx942 / gfx950 only), so MiniMax-M3 decode crashed with "pa_decode_gluon only supports gfx942 (CDNA3) and gfx950 (CDNA4)". The compacted sparse block table / context lengths the runners already build are exactly the (block_table, seqused_k) contract unified_attention consumes over the same SHUFFLE KV cache, so route both the full-attn and sparse decode paths through the triton unified_attention on gfx1250. Full-attn (attention_mha.py): - paged_attention_triton: add a use_unified flag that includes get_gfx() == "gfx1250" so decode takes the unified_attention branch instead of run_pa_decode_gluon. Gluon retained for CDNA3/CDNA4. Sparse (minimax_m3/sparse_attn.py) -- note ATOM_USE_UNIFIED_ATTN does NOT gate this path; the sparse runners call run_pa_decode_gluon directly: - add _sparse_decode_unified_attention helper feeding the kv-head collapsed SHUFFLE cache + sparse_bt + sparse_ctx into unified_attention(shuffled_kv_cache=True), each token a length-1 causal sequence (mirrors gluon max_seqlen_q=1 per-token-as-decode). - gfx1250 branch in minimax_m3_sparse_attn_decode_asm and _run_prefill_fp8_gluon: bf16 -> helper; fp8 -> NotImplementedError (gluon per-page descale has no unified_attention equivalent yet). Caveat: validated to compile/import on gfx950; the sparse path's GQA/block_table semantics still need MI455 numerical validation against the gfx950 gluon reference, and fp8 KV cache on gfx1250 is unsupported. Co-Authored-By: Claude Opus 4.8 (1M context) --- atom/model_ops/attention_mha.py | 13 ++- atom/model_ops/minimax_m3/sparse_attn.py | 105 +++++++++++++++++++++++ 2 files changed, 116 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/attention_mha.py b/atom/model_ops/attention_mha.py index 26d687a31..59334ef63 100644 --- a/atom/model_ops/attention_mha.py +++ b/atom/model_ops/attention_mha.py @@ -487,14 +487,23 @@ def paged_attention_triton( attn_metadata = fwd_ctx.attn_metadata - if envs.ATOM_USE_UNIFIED_ATTN and self.kv_cache_dtype.startswith("fp8"): + # gfx1250 (MI455) has no pa_decode_gluon kernel (gluon supports gfx942 / + # gfx950 only), so route the decode through the triton unified_attention + # path, which is gfx1250-capable and reads the same SHUFFLE KV cache. + use_unified = ( + envs.ATOM_USE_UNIFIED_ATTN + or self.use_flash_layout + or get_gfx() == "gfx1250" + ) + + if use_unified and self.kv_cache_dtype.startswith("fp8"): o = torch.empty(*q.shape, dtype=torch.bfloat16, device=q.device) else: o = torch.empty_like(q) num_seqs = attn_metadata.context_lens.shape[0] - if envs.ATOM_USE_UNIFIED_ATTN or self.use_flash_layout: + if use_unified: # print(q.shape, k_cache.shape, v_cache.shape) sliding_window = ( (self.sliding_window - 1, 0) if self.sliding_window > 0 else (-1, -1) diff --git a/atom/model_ops/minimax_m3/sparse_attn.py b/atom/model_ops/minimax_m3/sparse_attn.py index 2fcf03a9a..61f36cc73 100644 --- a/atom/model_ops/minimax_m3/sparse_attn.py +++ b/atom/model_ops/minimax_m3/sparse_attn.py @@ -14,6 +14,7 @@ import aiter # noqa: F401 (used by the gluon PA runners for aiter.dtypes.fp8) import torch +from aiter.jit.utils.chip_info import get_gfx try: from vllm.triton_utils import tl, triton @@ -125,6 +126,64 @@ def _is_fp8_kv_cache_tensor(kv_cache: torch.Tensor) -> bool: return kv_cache.dtype in {dtype for dtype in fp8_dtypes if dtype is not None} +def _sparse_decode_unified_attention( + q_view: torch.Tensor, # [num_seqs, gqa_group, head_dim] (kv-head collapsed) + out_view: torch.Tensor, # [num_seqs, gqa_group, head_dim] + k_cache_view: torch.Tensor, # SHUFFLE 5D, num_kv_heads collapsed to 1 + v_cache_view: torch.Tensor, + sparse_bt: torch.Tensor, # [num_seqs, max_pages] physical-16 block table + sparse_ctx: torch.Tensor, # [num_seqs] per-row effective context length + sm_scale: float, + num_seqs: int, +) -> None: + """gfx1250 fallback for the sparse per-token-as-decode gluon kernel. + + gfx1250 (MI455) has no ``pa_decode_gluon`` kernel (gluon supports gfx942 / + gfx950 only). The sparse runners have already compacted the indexer's + selected blocks into a dense physical-16 ``sparse_bt`` + exact ``sparse_ctx`` + over the (kv-head collapsed) SHUFFLE cache — which is exactly the + ``(block_table, seqused_k)`` contract ``unified_attention`` consumes with + ``shuffled_kv_cache=True``. Each token is a length-1 causal "sequence", + mirroring the gluon ``max_seqlen_q=1`` per-token-as-decode setup. + + bf16 KV cache only: fp8 sparse decode plumbs per-token (per-page) descales + into the gluon kernel, which does not map onto ``unified_attention``'s descale + contract here; the caller raises NotImplementedError for fp8 on gfx1250. + """ + from aiter.ops.triton.unified_attention import unified_attention + + # block_size (page granularity) from the SHUFFLE cache: + # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x] + block_size = k_cache_view.shape[3] + # Each token is its own length-1 sequence (decode); cu_seqlens_q = 0..num_seqs. + cu_seqlens_q = torch.arange( + num_seqs + 1, dtype=torch.int32, device=q_view.device + ) + # Safe upper bound: full block table width * page size (>= every sparse_ctx). + max_seqlen_k = int(sparse_bt.shape[1]) * int(block_size) + + unified_attention( + q_view, + k_cache_view, + v_cache_view, + out_view, + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=1, + seqused_k=sparse_ctx, + max_seqlen_k=max_seqlen_k, + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + block_table=sparse_bt, + softcap=0, + q_descale=None, + k_descale=None, + v_descale=None, + sinks=None, + shuffled_kv_cache=True, + ) + + # --------------------------------------------------------------------------- # GQA block-sparse attention. BLOCK_SIZE_K == 128, matching one selected block. # --------------------------------------------------------------------------- @@ -1177,6 +1236,29 @@ def minimax_m3_sparse_attn_decode_asm( v_cache_view = v_cache.view(nph16 * _hkv, 1, *v_cache.shape[2:]) num_seqs = T * num_kv_heads + + # gfx1250 (MI455): no gluon pa_decode kernel (gluon supports gfx942 / gfx950 + # only). Route through the triton unified_attention sparse fallback over the + # same SHUFFLE cache + compacted sparse block table. + if get_gfx() == "gfx1250": + if _is_fp8_kv_cache_tensor(k_cache): + raise NotImplementedError( + "MiniMax-M3 fp8 sparse decode is not yet supported on gfx1250 " + "(MI455): the gluon per-page descale path has no unified_attention " + "equivalent here. Use a bf16 KV cache on gfx1250." + ) + _sparse_decode_unified_attention( + q_view, + out_view, + k_cache_view, + v_cache_view, + sparse_bt, + sparse_ctx, + sm_scale, + num_seqs, + ) + return + num_kv_heads_view = 1 query_group_size = g max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads_view) @@ -1271,6 +1353,29 @@ def _run_prefill_fp8_gluon( v_cache_view = v_cache.view(nph16 * _hkv, 1, *v_cache.shape[2:]) num_seqs = T * num_kv_heads + + # gfx1250 (MI455): no gluon pa_decode kernel (gluon supports gfx942 / gfx950 + # only). Route through the triton unified_attention sparse fallback over the + # same SHUFFLE cache + compacted sparse block table. + if get_gfx() == "gfx1250": + if _is_fp8_kv_cache_tensor(k_cache): + raise NotImplementedError( + "MiniMax-M3 fp8 sparse decode is not yet supported on gfx1250 " + "(MI455): the gluon per-page descale path has no unified_attention " + "equivalent here. Use a bf16 KV cache on gfx1250." + ) + _sparse_decode_unified_attention( + q_view, + out_view, + k_cache_view, + v_cache_view, + sparse_bt, + sparse_ctx, + sm_scale, + num_seqs, + ) + return + num_kv_heads_view = 1 query_group_size = g max_context_partition_num = get_recommended_splits(num_seqs, num_kv_heads_view) From 92f2201399f7c5615515ea9f94c98fce09b83538 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 29 Jun 2026 11:59:12 +0000 Subject: [PATCH 3/8] add dump for minimax Signed-off-by: ganyi --- atom/models/minimax_m3.py | 11 +++ atom/utils/debug_helper/__init__.py | 2 + atom/utils/debug_helper/dump.py | 123 ++++++++++++++++++++++++ atom/utils/envs.py | 6 ++ curl_minimax.sh | 4 + serve_minimax.sh | 143 ++++++++++++++++++++++++++++ 6 files changed, 289 insertions(+) create mode 100644 curl_minimax.sh create mode 100644 serve_minimax.sh diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 72c9fdc86..41a548ac1 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -43,6 +43,7 @@ make_layers, maybe_prefix, ) +from atom.utils.debug_helper import maybe_dump_minimax_m3_layer from atom.utils.decorators import support_torch_compile from torch import nn from transformers import PretrainedConfig @@ -588,6 +589,10 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + # Debug dump bookkeeping (env-gated; see maybe_dump_minimax_m3_layer). + self.layer_num = layer_num + self._last_layer_idx = config.num_hidden_layers - 1 + def forward( self, positions: torch.Tensor, @@ -611,11 +616,17 @@ def forward( aux_out.append(residual.clone()) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + maybe_dump_minimax_m3_layer( + hidden_states, self.layer_num, "attn", self._last_layer_idx + ) hidden_states, residual = fused_allreduce_gemma_rms_norm( hidden_states, residual, self.post_attention_layernorm ) ffn = self.block_sparse_moe if self.is_moe_layer else self.mlp hidden_states = ffn(hidden_states) + maybe_dump_minimax_m3_layer( + hidden_states, self.layer_num, "moe", self._last_layer_idx + ) return hidden_states, residual diff --git a/atom/utils/debug_helper/__init__.py b/atom/utils/debug_helper/__init__.py index d23de5e14..a3835c9c5 100644 --- a/atom/utils/debug_helper/__init__.py +++ b/atom/utils/debug_helper/__init__.py @@ -31,6 +31,7 @@ ) from atom.utils.debug_helper.dump import ( install_block_forward_hooks, + maybe_dump_minimax_m3_layer, maybe_dump_weights_and_exit, maybe_log_topk, ) @@ -43,6 +44,7 @@ __all__ = [ # dump "install_block_forward_hooks", + "maybe_dump_minimax_m3_layer", "maybe_dump_weights_and_exit", "maybe_log_topk", # compare primitives diff --git a/atom/utils/debug_helper/dump.py b/atom/utils/debug_helper/dump.py index 0122efbb7..3c97c4939 100644 --- a/atom/utils/debug_helper/dump.py +++ b/atom/utils/debug_helper/dump.py @@ -10,6 +10,8 @@ ATOM_FWD_DUMP_LAYER_ATTR / ATOM_FWD_DUMP_ONE_SHOT ATOM_WEIGHT_DUMP_DIR / ATOM_WEIGHT_DUMP_LAYERS / ATOM_WEIGHT_DUMP_EXIT ATOM_DEBUG_TOPK / ATOM_DEBUG_TOPK_PATH + ATOM_M3_DUMP_DIR / ATOM_M3_DUMP_EXIT + ATOM_MINIMAX_DUMP_DIR / ATOM_MINIMAX_DUMP_ABORT Output file naming ------------------ @@ -33,6 +35,7 @@ from __future__ import annotations +import logging import os import sys from typing import Optional @@ -158,6 +161,126 @@ def _find_layer_id(mod_name: str) -> Optional[int]: return n +# === MiniMax-M3 per-layer attn/moe dump (call-site helper) =========== + +# Tracks which (layer, stage) pairs are already saved per step kind, so each is +# dumped exactly once for the first prefill step and once for the first decode step. +_M3_DONE: dict[str, set[tuple[int, str]]] = {"prefill": set(), "decode": set()} + + +def _is_dummy_run() -> bool: + """True if the current forward is a warmup / CUDAGraph-capture dummy run. + + Reads is_dummy_run off the forward context's Context. Returns False when no + context is available (e.g. unit tests calling the dump helper directly), so + standalone usage still dumps. + """ + try: + from atom.utils.forward_context import get_forward_context + + ctx = get_forward_context().context + return bool(getattr(ctx, "is_dummy_run", False)) + except Exception: + return False + + +def maybe_dump_minimax_m3_layer( + hidden_states: torch.Tensor, + layer_idx: int, + stage: str, + last_layer_idx: int, +) -> None: + """Print mean/var and dump a MiniMax-M3 layer hidden state. Env-gated no-op. + + No-op unless ATOM_M3_DUMP_DIR is set, so it is safe to leave wired into the + model forward in production (zero overhead on the unset path). Intended to be + called from MiniMaxM3DecoderLayer.forward right after the attention block + output (stage="attn") and after the MoE/MLP block output (stage="moe"). + + Warmup / CUDAGraph-capture dummy forwards are skipped (gated on the forward + context's is_dummy_run flag), so only real request data is dumped. + + Captures the first prefill step and the first decode step only. The step kind + is inferred from the token count: decode == 1 token (per the num_tokens==1 + rule), otherwise prefill. Each (layer, stage) is saved once per kind, so later + prefill/decode steps are ignored. + + Files: {ATOM_M3_DUMP_DIR}/{kind}_layer{LL}_{stage}_rank{R}.pt + + When ATOM_M3_DUMP_EXIT is set (default), the process exits right after the last + layer's MoE output of the first decode step. + + Requires eager execution (--enforce-eager / --level 0): with compilation on, + the model forward is traced by Dynamo and the .item()/torch.save calls here + are not traceable. Leaving the env unset keeps this a pure no-op, so compiled + runs are unaffected. + """ + dump_dir = envs.ATOM_M3_DUMP_DIR + if not dump_dir: + return + if not isinstance(hidden_states, torch.Tensor): + return + + # Skip warmup / capture dummy forwards — only dump real request data. The + # warmup pass sets is_dummy_run=True on the forward context (see + # model_runner.warmup_model). Absent a context (e.g. unit tests), treat as + # real so standalone calls still dump. + if _is_dummy_run(): + return + + kind = "decode" if hidden_states.shape[0] == 1 else "prefill" + key = (layer_idx, stage) + if key in _M3_DONE[kind]: + return + _M3_DONE[kind].add(key) + + os.makedirs(dump_dir, exist_ok=True) + rank = _get_rank() + logger = logging.getLogger("atom") + + tf = hidden_states.detach().float() + mean = tf.mean().item() + var = tf.var().item() + logger.info( + f"[M3_DUMP] {kind} layer{layer_idx:03d} {stage}: " + f"mean={mean:.6e} var={var:.6e} shape={tuple(hidden_states.shape)} " + f"rank={rank}" + ) + fname = os.path.join( + dump_dir, f"{kind}_layer{layer_idx:03d}_{stage}_rank{rank}.pt" + ) + torch.save( + { + "hidden": hidden_states.detach().cpu(), + "shape": tuple(hidden_states.shape), + "layer": layer_idx, + "stage": stage, + "kind": kind, + "mean": mean, + "var": var, + }, + fname, + ) + + # Exit after the last layer's MoE output of the first decode step. + if ( + envs.ATOM_M3_DUMP_EXIT + and kind == "decode" + and stage == "moe" + and layer_idx == last_layer_idx + ): + logger.info( + "[M3_DUMP] captured first decode step through last layer " + f"{layer_idx}; exiting." + ) + import torch.distributed as dist + + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + sys.exit(0) + + # === Weight dump ===================================================== diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 3bd2a53ab..0b9ab3ddd 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -183,6 +183,12 @@ "ATOM_FWD_DUMP_LAYER_ATTR", "layer_id" ), "ATOM_FWD_DUMP_ONE_SHOT": lambda: os.getenv("ATOM_FWD_DUMP_ONE_SHOT", "1") == "1", + # MiniMax-M3 per-layer attn/moe hidden dump + mean/var print. Captures one + # prefill step and one decode step (decode = num_tokens == 1), then exits + # after the last layer's MoE of the first decode step. Requires eager mode + # (--enforce-eager) so the submodule forward hooks fire during decode. + "ATOM_M3_DUMP_DIR": lambda: os.getenv("ATOM_M3_DUMP_DIR", ""), + "ATOM_M3_DUMP_EXIT": lambda: os.getenv("ATOM_M3_DUMP_EXIT", "1") == "1", # Per-rank weight dump + sys.exit(0) — for byte-equal weight comparison. "ATOM_WEIGHT_DUMP_DIR": lambda: os.getenv("ATOM_WEIGHT_DUMP_DIR", ""), "ATOM_WEIGHT_DUMP_LAYERS": lambda: os.getenv("ATOM_WEIGHT_DUMP_LAYERS", "0"), diff --git a/curl_minimax.sh b/curl_minimax.sh new file mode 100644 index 000000000..c6f0f4345 --- /dev/null +++ b/curl_minimax.sh @@ -0,0 +1,4 @@ +curl -X POST "http://localhost:8014/v1/completions" \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "The capital of China is", "temperature": 0, "top_p": 1, "top_k": 1, "repetition_penalty": 1.0, "presence_penalty": 0, "frequency_penalty": 0, "stream": false, "ignore_eos": false, "n": 1, "seed": 123, "max_tokens": 20}' \ No newline at end of file diff --git a/serve_minimax.sh b/serve_minimax.sh new file mode 100644 index 000000000..536431f5c --- /dev/null +++ b/serve_minimax.sh @@ -0,0 +1,143 @@ +# export AITER_QUICK_REDUCE_QUANTIZATION=INT4 +# export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 +# export ATOM_USE_GLUON_PA_DECODE=1 +# export HIP_VISIBLE_DEVICES=6,7 + +# vllm serve /workspace/shared/data/amd_int/models/MiniMax-M2.5 \ +# --host localhost \ +# --port 8100 \ +# --async-scheduling \ +# --load-format fastsafetensors \ +# --tensor-parallel-size 2 \ +# --trust-remote-code \ +# --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ +# --kv-cache-dtype fp8 \ +# --max-num-batched-tokens 16384 \ +# --max-model-len 16384 \ +# --gpu-memory-utilization 0.9 \ +# --no-enable-prefix-caching \ +# --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile", "torch_profiler_with_stack": "False"}' \ + # --enforce-eager + +# model_path=/workspace/shared/data/amd_int/models/MiniMax-M3-MXFP4/ + +# model_path=/workspace/shared/data/amd_int/models/MiniMax-M3 +export HIP_VISIBLE_DEVICES=0,1,2,3 +export ATOM_FORCE_ATTN_TRITON=1 +export ATOM_M3_DUMP_DIR="./minimax_dump/" +# export HSA_ENABLE_SDMA=1 +# export HSA_USE_SVM=1 +# export HSA_XNACK=1 +# export ATOM_USE_TRITON_MOE=1 +# export ATOM_USE_TRITON_GEMM=1 +# export ENABLE_CK=0 +# export AITER_USE_OPUS_MOE_SORTING=1 +# export ATOM_USE_UNIFIED_ATTN=0 + + +# python -m atom.entrypoints.openai_server \ +# --model $model_path \ +# -tp 4 --server-port 8013 --trust-remote-code --gpu-memory-utilization 0.7 \ +# --block-size 128 \ +# --no-enable_prefix_caching \ + # --torch-profiler-dir ./trace --mark-trace +main_model=/workspace/shared/data/amd_int/models/MiniMax-M3-MXFP4 +# draft_model=/workspace/shared/data/amd_int/models/MiniMax-M3-EAGLE3 +# export HIP_VISIBLE_DEVICES=0,1,2,3 +python -m atom.entrypoints.openai_server --model $main_model \ + -tp 4 --server-port 8014 --trust-remote-code --gpu-memory-utilization 0.8 --block-size 128 --no-enable_prefix_caching \ + --max-num-batched-tokens 32768 --max-model-len 32768 --max-num-seqs 128 --enforce-eager --level 0 + # --kv_cache_dtype fp8 \ + # --torch-profiler-dir ./trace + # --method eagle3 --draft-model $draft_model --num-speculative-tokens 3 \ + + # --enforce-eager + + +#!/usr/bin/env bash +# set -euo pipefail + +# rm -f "${VLLM_CACHE_ROOT:-$HOME/.cache/vllm}"/modelinfos/*minimax_m3* 2>/dev/null || true +# ok +# SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# LOG_FILE="$SCRIPT_DIR/minimax-m3-preview-2.log" + +# unset VL MM_BATCHED MM_ENCODER_ATTN + +# export HIP_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" +# # export HIP_VISIBLE_DEVICES="4,5,6,7" +# export ATOM_USE_TRITON_MOE="${ATOM_USE_TRITON_MOE:-1}" + +# MODEL="/shared/data/amd_int/models/MiniMax-M3" +# SERVED_NAME="${SERVED_NAME:-MiniMax-M3}" +# PORT="${PORT:-8000}" +# TP="${TP:-8}" +# GPU_MEM_UTIL="${GPU_MEM_UTIL:-0.90}" +# # GPU_MEM_UTIL="${GPU_MEM_UTIL:-0.8}" +# MAX_LEN="${MAX_LEN:-16384}" +# MAX_BATCHED_TOKENS="${MAX_BATCHED_TOKENS:-16384}" +# # MAX_SEQS="${MAX_SEQS:-8}" +# ATTN="TRITON_ATTN" +# LOAD_FORMAT="${LOAD_FORMAT:-auto}" +# SKIP_TOKENIZER_INIT="${SKIP_TOKENIZER_INIT:-0}" +# ENFORCE_EAGER="${ENFORCE_EAGER:-0}" +# EXTRA_ARGS=() +# if [[ "$SKIP_TOKENIZER_INIT" == "1" ]]; then +# EXTRA_ARGS+=(--skip-tokenizer-init) +# fi +# if [[ "$ENFORCE_EAGER" == "1" ]]; then +# EXTRA_ARGS+=(--enforce-eager) +# fi + +# echo "### serve: model=$MODEL" +# echo "### serve: devices=$HIP_VISIBLE_DEVICES tp=$TP port=$PORT max_len=$MAX_LEN" +# echo "### serve: log=$LOG_FILE" +# export VLLM_CUSTOM_SCOPES_FOR_PROFILING=1 +# export VLLM_BATCH_INVARIANT="${VLLM_BATCH_INVARIANT:-1}" +# vllm serve "$MODEL" \ +# --dtype bfloat16 \ +# --load-format "$LOAD_FORMAT" \ +# --host localhost \ +# --port "$PORT" \ +# --tensor-parallel-size "$TP" \ +# --gpu-memory-utilization "$GPU_MEM_UTIL" \ +# --max-model-len "$MAX_LEN" \ +# --max-num-batched-tokens "$MAX_BATCHED_TOKENS" \ +# --block-size 128 \ +# --no-enable-prefix-caching \ +# --language-model-only \ +# --no-trust-remote-code \ +# --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ +# "${EXTRA_ARGS[@]}" \ +# 2>&1 | tee "$LOG_FILE" + +# export HIP_VISIBLE_DEVICES=4,5,6,7 +# AITER_QUICK_REDUCE_QUANTIZATION=INT4 \ +# HSA_ENABLE_SDMA=1 \ +# HSA_USE_SVM=1 \ +# HSA_XNACK=1 \ +# AITER_DISABLE_KERNARG_PRELOAD=1 \ +# ATOM_USE_TRITON_MOE=1 \ +# ATOM_FORCE_ATTN_TRITON=1 \ +# ATOM_LOADER_USE_THREADPOOL=0 \ +# AITER_ROPE_TRITON_BACKEND=1 \ +# ATOM_ENABLE_DS_QKNORM_FUSION=0 \ +# ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION=0 \ +# ATOM_ENABLE_DS_QKNORM_QUANT_FUSION=0 \ +# ATOM_USE_TRITON_GEMM=1 \ +# ENABLE_DS_QKNORM_FUSION=0 \ +# ENABLE_CK=0 \ +# AITER_USE_OPUS_MOE_SORTING=1 \ +# ATOM_USE_UNIFIED_ATTN=0 \ +# python -m atom.entrypoints.openai_server \ +# --model /workspace/shared/data/amd_int/models/MiniMax-M3-MXFP4 \ +# --server-port 8013 \ +# --trust-remote-code \ +# -tp 4 \ +# --gpu-memory-utilization 0.8 \ +# --block-size 128 \ +# --max-model-len 32768 \ +# --max-num-seqs 128 \ +# --max-num-batched-tokens 32768 \ +# --torch-profiler-dir /app/trace \ +# --no-enable_prefix_caching \ No newline at end of file From e3f906d3f7e22924e39e9a6fd40a4f712ae5d854 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 29 Jun 2026 13:45:40 +0000 Subject: [PATCH 4/8] maybe acc right Signed-off-by: ganyi --- atom/model_ops/moe.py | 89 +++++++++++++- atom/models/minimax_m3.py | 12 +- atom/utils/debug_helper/__init__.py | 4 + atom/utils/debug_helper/dump.py | 176 ++++++++++++++++++++++++++++ atom/utils/envs.py | 12 ++ 5 files changed, 281 insertions(+), 12 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index d4f9d2925..077c9b113 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -779,6 +779,39 @@ def rocm_aiter_fused_moe_fake( ) +def _interleave_gate_up_rows_(layer: torch.nn.Module) -> None: + """Reorder w13 gate/up rows from SEPARATED (gguu) to INTERLEAVED (gugu). + + Operates in place on ``layer.w13_weight``, ``layer.w13_weight_scale`` and + ``layer.w13_bias`` (if present) along their row axis (dim=1), which is + ``2 * intermediate_size`` laid out as ``[gate(0..I-1) | up(0..I-1)]``. The + new order is ``[gate0, up0, gate1, up1, ...]`` so that a downstream consumer + splitting even/odd rows (the triton a16w4 SwiGLU kernel) reads matching + gate/up pairs. ``w2`` (down_proj) has no gate/up split and is untouched. + + Idempotency guard: sets ``layer._w13_gate_up_interleaved`` so a double call + (e.g. process_weights_after_loading invoked twice) is a no-op. + """ + if getattr(layer, "_w13_gate_up_interleaved", False): + return + + def _interleave_rows(t: torch.Tensor) -> torch.Tensor: + # t: (E, 2I, ...) with rows [gate(I) | up(I)] -> [g0,u0,g1,u1,...] + two_i = t.shape[1] + assert two_i % 2 == 0, f"w13 row dim {two_i} not even" + i = two_i // 2 + idx = torch.empty(two_i, dtype=torch.long, device=t.device) + idx[0::2] = torch.arange(i, device=t.device) # gate -> even + idx[1::2] = torch.arange(i, device=t.device) + i # up -> odd + return t.index_select(1, idx).contiguous() + + layer.w13_weight.data = _interleave_rows(layer.w13_weight.data) + layer.w13_weight_scale.data = _interleave_rows(layer.w13_weight_scale.data) + if getattr(layer, "w13_bias", None) is not None: + layer.w13_bias.data = _interleave_rows(layer.w13_bias.data) + layer._w13_gate_up_interleaved = True + + class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig): super().__init__(moe) @@ -938,6 +971,16 @@ def process_weights_after_loading(self, layer): if layer.w2_bias is not None: layer.w2_bias.data = layer.w2_bias.data.to(torch.float32) + # Env-gated raw (pre-transform) MXFP4 expert weight stash, for the + # default-vs-triton fused_moe isolation test. The default (CK shuffle) + # and triton (_swizzle_mxfp4) paths transform these SAME raw tensors + # differently; saving them here lets the offline test apply both + # transforms and compare kernels on identical weights. No-op unless + # ATOM_MOE_RAWW_DUMP_DIR is set. + from atom.utils.debug_helper import maybe_dump_mxfp4_raw_weights + + # maybe_dump_mxfp4_raw_weights(layer) + if self.static_input_scales: layer.w13_input_scale = atom_parameter( layer.w13_input_scale.max().to(torch.float32) @@ -984,6 +1027,25 @@ def process_weights_after_loading(self, layer): else: layer.shared_w2_bias = None + # gguu -> gugu interleave for the routed triton SwiGLU kernel. + # + # The checkpoint stores w13 gate/up rows SEPARATED ("gguu": + # [all-gate | all-up]). The aiter default CK fused_moe path consumes + # that directly. The triton a16w4 SwiGLU kernel, however, fuses the + # activation into the GEMM epilogue and splits each tile as + # INTERLEAVED ("gugu": a[...,::2]=gate, a[...,1::2]=up). Feeding gguu + # weights to that kernel mixes gate with gate and misreads up, + # corrupting the output. We interleave the routed w13 rows once here + # so the existing (well-tested) interleaved kernel produces results + # identical to the default path. w2 (down_proj) has no gate/up split. + # + # NOTE: this runs AFTER the shared-expert stash above, which keeps the + # shared experts in gguu for the dense swiglu_oai_split path. Only the + # SwiGLU triton branch needs interleaved rows; the SiLU branch uses + # fused_clamp_act_mul, which half-splits gguu — so gate on activation. + if getattr(layer, "activation", None) == ActivationType.Swiglu: + _interleave_gate_up_rows_(layer) + ( w13_weight, w13_scale, @@ -1231,12 +1293,11 @@ def apply( "swiglu_limit": getattr(layer, "swiglu_limit", 0.0), } if self.fused_experts is None: - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, + # Positional kernel args, then the keyword kernel args (incl. + # moe_extra_args: gate_mode / swiglu_limit). Built once so the same + # values feed both the kernel call and the env-gated I/O dump. + fused_moe_pos = (x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids) + fused_moe_kw = dict( expert_mask=layer.expert_mask, activation=activation, quant_type=self.quant_type, @@ -1251,6 +1312,22 @@ def apply( bias2=layer.w2_bias, **moe_extra_args, ) + # Env-gated kernel I/O dump for offline correctness debugging. + from atom.utils.debug_helper import maybe_dump_fused_moe_io + + _moe_dump_args = { + "x": x, + "w1": layer.w13_weight, + "w2": layer.w2_weight, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + **fused_moe_kw, + } + layer_name = getattr(layer, "layer_name", "") + # maybe_dump_fused_moe_io(layer_name, _moe_dump_args) + moe_out = fused_moe(*fused_moe_pos, **fused_moe_kw) + # maybe_dump_fused_moe_io(layer_name, _moe_dump_args, output=moe_out) + return moe_out return self.fused_experts( hidden_states=x, w1=layer.w13_weight, diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 41a548ac1..5a8061faa 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -616,17 +616,17 @@ def forward( aux_out.append(residual.clone()) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) - maybe_dump_minimax_m3_layer( - hidden_states, self.layer_num, "attn", self._last_layer_idx - ) + # maybe_dump_minimax_m3_layer( + # hidden_states, self.layer_num, "attn", self._last_layer_idx + # ) hidden_states, residual = fused_allreduce_gemma_rms_norm( hidden_states, residual, self.post_attention_layernorm ) ffn = self.block_sparse_moe if self.is_moe_layer else self.mlp hidden_states = ffn(hidden_states) - maybe_dump_minimax_m3_layer( - hidden_states, self.layer_num, "moe", self._last_layer_idx - ) + # maybe_dump_minimax_m3_layer( + # hidden_states, self.layer_num, "moe", self._last_layer_idx + # ) return hidden_states, residual diff --git a/atom/utils/debug_helper/__init__.py b/atom/utils/debug_helper/__init__.py index a3835c9c5..9708db405 100644 --- a/atom/utils/debug_helper/__init__.py +++ b/atom/utils/debug_helper/__init__.py @@ -31,7 +31,9 @@ ) from atom.utils.debug_helper.dump import ( install_block_forward_hooks, + maybe_dump_fused_moe_io, maybe_dump_minimax_m3_layer, + maybe_dump_mxfp4_raw_weights, maybe_dump_weights_and_exit, maybe_log_topk, ) @@ -44,7 +46,9 @@ __all__ = [ # dump "install_block_forward_hooks", + "maybe_dump_fused_moe_io", "maybe_dump_minimax_m3_layer", + "maybe_dump_mxfp4_raw_weights", "maybe_dump_weights_and_exit", "maybe_log_topk", # compare primitives diff --git a/atom/utils/debug_helper/dump.py b/atom/utils/debug_helper/dump.py index 3c97c4939..b378d5e6c 100644 --- a/atom/utils/debug_helper/dump.py +++ b/atom/utils/debug_helper/dump.py @@ -281,6 +281,182 @@ def maybe_dump_minimax_m3_layer( sys.exit(0) +# === fused_moe kernel I/O dump ======================================= + +# Tracks which (layer, kind) pairs already dumped, so each layer is captured once +# per first prefill step and once per first decode step. +_MOE_DONE: set[tuple[str, str]] = set() + + +def _to_cpu(obj): + """Detach + move tensors to CPU recursively; pass scalars/enums through.""" + if isinstance(obj, torch.Tensor): + return obj.detach().cpu() + if isinstance(obj, (list, tuple)): + return type(obj)(_to_cpu(o) for o in obj) + if isinstance(obj, dict): + return {k: _to_cpu(v) for k, v in obj.items()} + return obj + + +def maybe_dump_fused_moe_io( + layer_name: str, + kernel_inputs: dict, + output: Optional[torch.Tensor] = None, +) -> None: + """Dump fused_moe kernel inputs (and optionally output). Env-gated no-op. + + No-op unless ATOM_MOE_DUMP_DIR is set. Skips warmup / CUDAGraph-capture dummy + forwards (is_dummy_run). Captures the first prefill step and the first decode + step per layer (kind inferred from the activation token count: 1 == decode, + else prefill). Each (layer, kind) is saved once. + + Intended to bracket the fused_moe(...) call in Mxfp4MoEMethod.apply: call once + before the kernel with the full kwargs dict (output=None) to record inputs, + then again after with the same layer_name + the realized output tensor to + append the output to the same file. + + kernel_inputs should map arg name -> value (tensors, scalars, enums). Tensors + are detached + moved to CPU; non-tensors are stored verbatim as metadata, so + the saved file is self-contained for offline kernel replay. + + Files: {ATOM_MOE_DUMP_DIR}/moe_{kind}_{safe_layer_name}_rank{R}.pt + Requires eager mode (--enforce-eager): the .cpu()/torch.save calls are not + Dynamo-traceable; the env-unset no-op path keeps compiled runs unaffected. + """ + dump_dir = envs.ATOM_MOE_DUMP_DIR + if not dump_dir: + return + if _is_dummy_run(): + return + + # Determine token count from the activation input ("x"/"hidden_states"). + x = kernel_inputs.get("x") + if x is None: + x = kernel_inputs.get("hidden_states") + if not isinstance(x, torch.Tensor): + return + kind = "decode" if x.shape[0] == 1 else "prefill" + + # Optional per-layer-index filter (parse "model.layers.." from the name). + wanted = _parse_layer_set(envs.ATOM_MOE_DUMP_LAYERS) + if wanted is not None: + import re + + m = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", layer_name) + if m is None or int(m.group(1)) not in wanted: + return + + key = (layer_name, kind) + is_input_call = output is None + if is_input_call and key in _MOE_DONE: + return # already captured this (layer, kind) + + os.makedirs(dump_dir, exist_ok=True) + rank = _get_rank() + logger = logging.getLogger("atom") + safe = layer_name.replace("/", "_").replace(".", "_") or "moe" + fname = os.path.join(dump_dir, f"moe_{kind}_{safe}_rank{rank}.pt") + + if is_input_call: + _MOE_DONE.add(key) + pkt = { + "_layer_name": layer_name, + "_kind": kind, + "_rank": rank, + "inputs": {k: _to_cpu(v) for k, v in kernel_inputs.items()}, + } + torch.save(pkt, fname) + tensor_args = [ + k for k, v in kernel_inputs.items() if isinstance(v, torch.Tensor) + ] + logger.info( + f"[MOE_DUMP] {kind} {layer_name}: saved inputs " + f"({len(tensor_args)} tensors) shape_x={tuple(x.shape)} rank={rank}" + ) + return + + # Output call: append output to the existing input file (must follow input). + if not isinstance(output, torch.Tensor): + return + pkt = {} + if os.path.exists(fname): + # weights_only=False: the input pkt stores enums (ActivationType, + # QuantType) that are not in torch's safe-globals allowlist. + pkt = torch.load(fname, weights_only=False) + pkt["output"] = output.detach().cpu() + pkt["_output_shape"] = tuple(output.shape) + torch.save(pkt, fname) + of = output.detach().float() + logger.info( + f"[MOE_DUMP] {kind} {layer_name}: saved output " + f"mean={of.mean().item():.6e} var={of.var().item():.6e} " + f"shape={tuple(output.shape)} rank={rank}" + ) + + +# === MXFP4 raw expert weight stash (isolation test) ================== + +_MOE_RAWW_DONE: set[str] = set() + + +def maybe_dump_mxfp4_raw_weights(layer: "torch.nn.Module") -> None: + """Stash raw (pre-transform) MXFP4 expert weights for one MoE layer. + + No-op unless ATOM_MOE_RAWW_DUMP_DIR is set. Called at the top of + Mxfp4MoEMethod.process_weights_after_loading, BEFORE the CK shuffle / triton + swizzle, so the saved tensors are the common raw input both paths consume. + The offline isolation test applies both transforms to these and compares + kernels. Saves once per layer_name. + + File: {ATOM_MOE_RAWW_DUMP_DIR}/raww_{safe_layer_name}_rank{R}.pt + keys: w13_weight, w13_weight_scale, w2_weight, w2_weight_scale, + w13_bias, w2_bias, w13_input_scale, w2_input_scale, layer_name + """ + dump_dir = envs.ATOM_MOE_RAWW_DUMP_DIR + if not dump_dir: + return + layer_name = getattr(layer, "layer_name", "") or getattr(layer, "prefix", "") + + wanted = _parse_layer_set(envs.ATOM_MOE_RAWW_DUMP_LAYERS) + if wanted is not None: + import re + + m = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", layer_name) + if m is None or int(m.group(1)) not in wanted: + return + if layer_name in _MOE_RAWW_DONE: + return + _MOE_RAWW_DONE.add(layer_name) + + os.makedirs(dump_dir, exist_ok=True) + rank = _get_rank() + + def _g(name): + v = getattr(layer, name, None) + if isinstance(v, torch.Tensor): + return v.detach().cpu().clone() + return v + + pkt = { + "layer_name": layer_name, + "_rank": rank, + "w13_weight": _g("w13_weight"), + "w13_weight_scale": _g("w13_weight_scale"), + "w2_weight": _g("w2_weight"), + "w2_weight_scale": _g("w2_weight_scale"), + "w13_bias": _g("w13_bias"), + "w2_bias": _g("w2_bias"), + "w13_input_scale": _g("w13_input_scale"), + "w2_input_scale": _g("w2_input_scale"), + } + safe = layer_name.replace("/", "_").replace(".", "_") or "moe" + torch.save(pkt, os.path.join(dump_dir, f"raww_{safe}_rank{rank}.pt")) + logging.getLogger("atom").info( + f"[MOE_RAWW] saved raw MXFP4 weights for {layer_name} rank={rank}" + ) + + # === Weight dump ===================================================== diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 0b9ab3ddd..229dbfe18 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -189,6 +189,18 @@ # (--enforce-eager) so the submodule forward hooks fire during decode. "ATOM_M3_DUMP_DIR": lambda: os.getenv("ATOM_M3_DUMP_DIR", ""), "ATOM_M3_DUMP_EXIT": lambda: os.getenv("ATOM_M3_DUMP_EXIT", "1") == "1", + # fused_moe kernel I/O dump — saves all kernel tensor args (activations, + # expert weights, scales, routing, biases) + the kernel output for one + # prefill and one decode step per layer, for offline kernel-correctness + # replay. Requires eager mode. ATOM_MOE_DUMP_LAYERS optionally restricts to a + # comma-separated set of layer indices (default: all). + "ATOM_MOE_DUMP_DIR": lambda: os.getenv("ATOM_MOE_DUMP_DIR", ""), + "ATOM_MOE_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_DUMP_LAYERS", ""), + # Raw (pre-transform) MXFP4 expert weight stash for the default-vs-triton + # fused_moe isolation test. Saves w13/w2 weight+scale+bias per MoE layer + # before process_weights_after_loading shuffles/swizzles them. + "ATOM_MOE_RAWW_DUMP_DIR": lambda: os.getenv("ATOM_MOE_RAWW_DUMP_DIR", ""), + "ATOM_MOE_RAWW_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_RAWW_DUMP_LAYERS", ""), # Per-rank weight dump + sys.exit(0) — for byte-equal weight comparison. "ATOM_WEIGHT_DUMP_DIR": lambda: os.getenv("ATOM_WEIGHT_DUMP_DIR", ""), "ATOM_WEIGHT_DUMP_LAYERS": lambda: os.getenv("ATOM_WEIGHT_DUMP_LAYERS", "0"), From cbd2abd9f7e6cbe8815545c3dececbeeca93a016 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 29 Jun 2026 14:07:00 +0000 Subject: [PATCH 5/8] uint8 to view Signed-off-by: ganyi --- atom/model_ops/moe.py | 41 +++++++++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 077c9b113..258d56f65 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -789,26 +789,47 @@ def _interleave_gate_up_rows_(layer: torch.nn.Module) -> None: splitting even/odd rows (the triton a16w4 SwiGLU kernel) reads matching gate/up pairs. ``w2`` (down_proj) has no gate/up split and is untouched. + The reorder is on whole rows (the ``2I`` axis); the MXFP4-packed pairs live + on the LAST axis (``H//2`` bytes/row), so reordering never splits a packed + byte. We therefore reorder a ``uint8`` view of the FP4 weight — bit-exact, + no dequant/requant, no scale recompute (torch has no ``index_select`` for the + ``float4_e2m1fn_x2`` dtype itself). The e8m0 scale (already uint8) and the + bf16 bias are reordered the same way so they stay matched to the weight rows. + Idempotency guard: sets ``layer._w13_gate_up_interleaved`` so a double call (e.g. process_weights_after_loading invoked twice) is a no-op. """ if getattr(layer, "_w13_gate_up_interleaved", False): return - def _interleave_rows(t: torch.Tensor) -> torch.Tensor: - # t: (E, 2I, ...) with rows [gate(I) | up(I)] -> [g0,u0,g1,u1,...] - two_i = t.shape[1] + def _gugu_idx(two_i: int, device) -> torch.Tensor: + # rows [gate(0..I-1) | up(0..I-1)] -> [g0,u0,g1,u1,...] assert two_i % 2 == 0, f"w13 row dim {two_i} not even" i = two_i // 2 - idx = torch.empty(two_i, dtype=torch.long, device=t.device) - idx[0::2] = torch.arange(i, device=t.device) # gate -> even - idx[1::2] = torch.arange(i, device=t.device) + i # up -> odd - return t.index_select(1, idx).contiguous() + idx = torch.empty(two_i, dtype=torch.long, device=device) + idx[0::2] = torch.arange(i, device=device) # gate -> even + idx[1::2] = torch.arange(i, device=device) + i # up -> odd + return idx + + def _reorder_rows(t: torch.Tensor) -> torch.Tensor: + """index_select on dim=1, via a uint8 view if the dtype is unsupported. + + FP4 (float4_e2m1fn_x2) has no index_select; view it as uint8 (bytes are + row-contiguous, so a row reorder is exact), select, then view back. + """ + idx = _gugu_idx(t.shape[1], t.device) + orig_dtype = t.dtype + try: + return t.index_select(1, idx).contiguous() + except (RuntimeError, NotImplementedError): + return ( + t.view(torch.uint8).index_select(1, idx).contiguous().view(orig_dtype) + ) - layer.w13_weight.data = _interleave_rows(layer.w13_weight.data) - layer.w13_weight_scale.data = _interleave_rows(layer.w13_weight_scale.data) + layer.w13_weight.data = _reorder_rows(layer.w13_weight.data) + layer.w13_weight_scale.data = _reorder_rows(layer.w13_weight_scale.data) if getattr(layer, "w13_bias", None) is not None: - layer.w13_bias.data = _interleave_rows(layer.w13_bias.data) + layer.w13_bias.data = _reorder_rows(layer.w13_bias.data) layer._w13_gate_up_interleaved = True From 0b460f78a730d96f8f850e454fa91f33dbc47a25 Mon Sep 17 00:00:00 2001 From: ganyi Date: Mon, 29 Jun 2026 14:30:11 +0000 Subject: [PATCH 6/8] reduce memory consumption Signed-off-by: ganyi --- atom/model_ops/moe.py | 64 +++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 258d56f65..5befde68a 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -791,10 +791,14 @@ def _interleave_gate_up_rows_(layer: torch.nn.Module) -> None: The reorder is on whole rows (the ``2I`` axis); the MXFP4-packed pairs live on the LAST axis (``H//2`` bytes/row), so reordering never splits a packed - byte. We therefore reorder a ``uint8`` view of the FP4 weight — bit-exact, - no dequant/requant, no scale recompute (torch has no ``index_select`` for the - ``float4_e2m1fn_x2`` dtype itself). The e8m0 scale (already uint8) and the - bf16 bias are reordered the same way so they stay matched to the weight rows. + byte. For FP4 we reorder a ``uint8`` view (bit-exact, no dequant/requant, no + scale recompute — torch has no ``index_select`` for ``float4_e2m1fn_x2``). + + Memory: the permutation is applied **in place per expert**, reusing the + existing storage. A plain ``index_select(...).contiguous()`` over the whole + tensor would double peak memory (a full second copy of the ~GB-sized w13), + which OOMs across many layers at load time. Here the only transient is one + expert's rows (a few MB), so peak overhead is negligible. Idempotency guard: sets ``layer._w13_gate_up_interleaved`` so a double call (e.g. process_weights_after_loading invoked twice) is a no-op. @@ -802,34 +806,40 @@ def _interleave_gate_up_rows_(layer: torch.nn.Module) -> None: if getattr(layer, "_w13_gate_up_interleaved", False): return - def _gugu_idx(two_i: int, device) -> torch.Tensor: - # rows [gate(0..I-1) | up(0..I-1)] -> [g0,u0,g1,u1,...] - assert two_i % 2 == 0, f"w13 row dim {two_i} not even" - i = two_i // 2 - idx = torch.empty(two_i, dtype=torch.long, device=device) - idx[0::2] = torch.arange(i, device=device) # gate -> even - idx[1::2] = torch.arange(i, device=device) + i # up -> odd - return idx + def _interleave_inplace(t: torch.Tensor) -> None: + """In-place row reorder [g..|u..] -> [g0,u0,g1,u1,...], per expert. - def _reorder_rows(t: torch.Tensor) -> torch.Tensor: - """index_select on dim=1, via a uint8 view if the dtype is unsupported. + Only FP4 (float4_e2m1fn_x2) needs the uint8 view, because torch has no + index_select/reshape for it AND its bytes are row-contiguous so the row + reorder is exact. Other dtypes (uint8 e8m0 scale, bf16/float bias) are + reordered directly — viewing them as uint8 would reinterpret bytes and + corrupt the values. - FP4 (float4_e2m1fn_x2) has no index_select; view it as uint8 (bytes are - row-contiguous, so a row reorder is exact), select, then view back. + Per expert e: view its 2I rows as (2, I, *rest), transpose to + (I, 2, *rest) (the gugu order); reshape forces one small (single-expert) + temp, then copy_ it back into the same storage. No full-tensor duplicate. """ - idx = _gugu_idx(t.shape[1], t.device) - orig_dtype = t.dtype - try: - return t.index_select(1, idx).contiguous() - except (RuntimeError, NotImplementedError): - return ( - t.view(torch.uint8).index_select(1, idx).contiguous().view(orig_dtype) - ) + _fp4 = getattr(torch, "float4_e2m1fn_x2", None) + if _fp4 is not None and t.dtype == _fp4: + buf = t.view(torch.uint8) + else: + buf = t - layer.w13_weight.data = _reorder_rows(layer.w13_weight.data) - layer.w13_weight_scale.data = _reorder_rows(layer.w13_weight_scale.data) + E, two_i = buf.shape[0], buf.shape[1] + assert two_i % 2 == 0, f"w13 row dim {two_i} not even" + i = two_i // 2 + rest = buf.shape[2:] + for e in range(E): + rows = buf[e] # view into storage, shape (2I, *rest) + gugu = ( + rows.view(2, i, *rest).transpose(0, 1).reshape(two_i, *rest) + ) # one (single-expert) temp + rows.copy_(gugu) # write back into the same storage + + _interleave_inplace(layer.w13_weight.data) + _interleave_inplace(layer.w13_weight_scale.data) if getattr(layer, "w13_bias", None) is not None: - layer.w13_bias.data = _reorder_rows(layer.w13_bias.data) + _interleave_inplace(layer.w13_bias.data) layer._w13_gate_up_interleaved = True From a319eff4cac207af6a1b759e8aaa4012213bb9ca Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 30 Jun 2026 02:19:29 +0000 Subject: [PATCH 7/8] prefill correct Signed-off-by: ganyi --- atom/model_ops/moe.py | 32 +++++++--- atom/models/minimax_m3.py | 20 +++++++ atom/utils/debug_helper/__init__.py | 2 + atom/utils/debug_helper/compare.py | 40 +++++++++++++ atom/utils/debug_helper/dump.py | 91 +++++++++++++++++++++++++++++ atom/utils/envs.py | 4 ++ 6 files changed, 180 insertions(+), 9 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 5befde68a..e14fc665b 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1010,7 +1010,7 @@ def process_weights_after_loading(self, layer): # ATOM_MOE_RAWW_DUMP_DIR is set. from atom.utils.debug_helper import maybe_dump_mxfp4_raw_weights - # maybe_dump_mxfp4_raw_weights(layer) + maybe_dump_mxfp4_raw_weights(layer) if self.static_input_scales: layer.w13_input_scale = atom_parameter( @@ -1037,24 +1037,30 @@ def process_weights_after_loading(self, layer): # the MoE-kernel-only CDNA4 layout. n_shared = layer.num_fused_shared_experts if n_shared > 0: + # IMPORTANT: .clone() (not .contiguous()). A uint8 view of a + # contiguous slice is already contiguous, so .contiguous() returns + # a tensor that SHARES storage with w13_weight. The gguu->gugu + # interleave below mutates w13_weight in place, which would then + # corrupt these stashed shared weights (consumed by the gguu + # half-split swiglu_oai_split). Clone to fully detach. layer.shared_w13_weight = ( - layer.w13_weight.data[-n_shared:].view(torch.uint8).contiguous() + layer.w13_weight.data[-n_shared:].view(torch.uint8).clone() ) layer.shared_w13_weight_scale = layer.w13_weight_scale.data[ -n_shared: - ].contiguous() + ].clone() layer.shared_w2_weight = ( - layer.w2_weight.data[-n_shared:].view(torch.uint8).contiguous() + layer.w2_weight.data[-n_shared:].view(torch.uint8).clone() ) layer.shared_w2_weight_scale = layer.w2_weight_scale.data[ -n_shared: - ].contiguous() + ].clone() if layer.w13_bias is not None: - layer.shared_w13_bias = layer.w13_bias.data[-n_shared:].contiguous() + layer.shared_w13_bias = layer.w13_bias.data[-n_shared:].clone() else: layer.shared_w13_bias = None if layer.w2_bias is not None: - layer.shared_w2_bias = layer.w2_bias.data[-n_shared:].contiguous() + layer.shared_w2_bias = layer.w2_bias.data[-n_shared:].clone() else: layer.shared_w2_bias = None @@ -1267,6 +1273,11 @@ def apply( _moe_result = _moe_result + self._apply_shared_experts_dense( layer, x, activation ) + from atom.utils.debug_helper import maybe_dump_moe_apply_io + + maybe_dump_moe_apply_io( + getattr(layer, "layer_name", ""), x, router_logits, _moe_result + ) return _moe_result assert ( @@ -1355,9 +1366,12 @@ def apply( **fused_moe_kw, } layer_name = getattr(layer, "layer_name", "") - # maybe_dump_fused_moe_io(layer_name, _moe_dump_args) + maybe_dump_fused_moe_io(layer_name, _moe_dump_args) moe_out = fused_moe(*fused_moe_pos, **fused_moe_kw) - # maybe_dump_fused_moe_io(layer_name, _moe_dump_args, output=moe_out) + maybe_dump_fused_moe_io(layer_name, _moe_dump_args, output=moe_out) + from atom.utils.debug_helper import maybe_dump_moe_apply_io + + maybe_dump_moe_apply_io(layer_name, x, router_logits, moe_out) return moe_out return self.fused_experts( hidden_states=x, diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 5a8061faa..12ee273ff 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -317,6 +317,26 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_experts is not None: routed_output = routed_output + self.shared_experts(hidden_states) + import os as _os_m3 + + _md = _os_m3.getenv("ATOM_M3_MOEOUT_DIR") + if _md: + _ln = getattr(getattr(self, "experts", None), "layer_name", "") + import re as _re_m3 + + _m = _re_m3.search(r"\.layers\.(\d+)\.", _ln) + if _m and int(_m.group(1)) in (3, 30, 59): + import torch as _t_m3 + import torch.distributed as _dist_m3 + + _os_m3.makedirs(_md, exist_ok=True) + _rk = _dist_m3.get_rank() if _dist_m3.is_initialized() else 0 + _kind = "decode" if hidden_states.shape[0] == 1 else "prefill" + _t_m3.save( + {"out": routed_output.detach().cpu(), "in": hidden_states.detach().cpu()}, + _os_m3.path.join(_md, f"moeout_l{_m.group(1)}_{_kind}_rank{_rk}.pt"), + ) + return routed_output.view(orig_shape) diff --git a/atom/utils/debug_helper/__init__.py b/atom/utils/debug_helper/__init__.py index 9708db405..5b37894cc 100644 --- a/atom/utils/debug_helper/__init__.py +++ b/atom/utils/debug_helper/__init__.py @@ -33,6 +33,7 @@ install_block_forward_hooks, maybe_dump_fused_moe_io, maybe_dump_minimax_m3_layer, + maybe_dump_moe_apply_io, maybe_dump_mxfp4_raw_weights, maybe_dump_weights_and_exit, maybe_log_topk, @@ -48,6 +49,7 @@ "install_block_forward_hooks", "maybe_dump_fused_moe_io", "maybe_dump_minimax_m3_layer", + "maybe_dump_moe_apply_io", "maybe_dump_mxfp4_raw_weights", "maybe_dump_weights_and_exit", "maybe_log_topk", diff --git a/atom/utils/debug_helper/compare.py b/atom/utils/debug_helper/compare.py index 086b04a0a..a00b4bef6 100644 --- a/atom/utils/debug_helper/compare.py +++ b/atom/utils/debug_helper/compare.py @@ -371,6 +371,39 @@ def cmd_layer_bisect(args: argparse.Namespace) -> int: return rc +def cmd_moe_apply(args: argparse.Namespace) -> int: + """Compare MoE apply() boundary dumps from two runs (default vs triton). + + Each dir holds apply_{kind}_{layer}_rank{R}.pt with keys x / router_logits / + output (written by maybe_dump_moe_apply_io). Matches files by basename across + --a (e.g. default / correct) and --b (e.g. triton / broken) and prints, per + file, cos for x, router_logits, output. The first stage where cos drops tells + you the divergence: + - router_logits drops -> routing bug + - x matches, output drops -> kernel / weight-layout / scale bug + """ + a_files = {os.path.basename(p): p for p in glob.glob(os.path.join(args.a, "*.pt"))} + b_files = {os.path.basename(p): p for p in glob.glob(os.path.join(args.b, "*.pt"))} + common = sorted(set(a_files) & set(b_files)) + if not common: + print(f"no matching files between {args.a} and {args.b}") + return 1 + print(f"{'file':45s} {'stage':14s} {'cos':>9s} {'max_abs':>10s} {'rel':>9s}") + for name in common: + a = torch.load(a_files[name], map_location="cpu", weights_only=False) + b = torch.load(b_files[name], map_location="cpu", weights_only=False) + for stage in ("x", "router_logits", "output"): + ta, tb = a.get(stage), b.get(stage) + if ta is None or tb is None: + continue + cos, ma, rel = cos_max(ta, tb) + print( + f"{name[:45]:45s} {stage:14s} {flag_for(cos, rel)} " + f"{cos:9.6f} {ma:10.3e} {rel:9.3e}" + ) + return 0 + + def cmd_schema(args: argparse.Namespace) -> int: """Show schema diff between two dump files.""" a = torch.load(args.a, map_location="cpu", weights_only=False) @@ -405,6 +438,13 @@ def main(argv: Optional[list[str]] = None) -> int: lb.add_argument("--threshold", type=float, default=COS_NUM_DRIFT) lb.set_defaults(func=cmd_layer_bisect) + ma = sub.add_parser( + "moe-apply", help="Compare MoE apply() dumps (default dir vs triton dir)" + ) + ma.add_argument("--a", required=True, help="dir A (e.g. default / correct run)") + ma.add_argument("--b", required=True, help="dir B (e.g. triton / broken run)") + ma.set_defaults(func=cmd_moe_apply) + sc = sub.add_parser("schema", help="Schema diff between two dump files") sc.add_argument("--a", required=True) sc.add_argument("--b", required=True) diff --git a/atom/utils/debug_helper/dump.py b/atom/utils/debug_helper/dump.py index b378d5e6c..2f672724e 100644 --- a/atom/utils/debug_helper/dump.py +++ b/atom/utils/debug_helper/dump.py @@ -395,6 +395,97 @@ def maybe_dump_fused_moe_io( ) +# === MoE apply() boundary dump (default-vs-triton compare) =========== + +# (layer, kind) already dumped — capture first prefill + first decode per layer. +_MOE_APPLY_DONE: set[tuple[str, str]] = set() + + +def maybe_dump_moe_apply_io( + layer_name: str, + x: torch.Tensor, + router_logits: torch.Tensor, + output: torch.Tensor, +) -> None: + """Dump the MoE apply() boundary: x, router_logits, output. Env-gated no-op. + + No-op unless ATOM_MOE_APPLY_DUMP_DIR is set. These three tensors exist on + BOTH the aiter default and triton paths, so dumping from one run with + ATOM_USE_TRITON_MOE=0 (correct) and another with =1 (broken) lets the compare + CLI diff them per layer to localize divergence: + - x identical, router_logits identical, output differs -> kernel/layout bug + - router_logits differ -> routing bug + + Skips warmup/dummy forwards. Captures first prefill + first decode per layer. + Optional layer filter via ATOM_MOE_APPLY_DUMP_LAYERS (comma-separated). + File: {dir}/apply_{kind}_{safe_layer}_rank{R}.pt + Requires eager mode. + """ + dump_dir = envs.ATOM_MOE_APPLY_DUMP_DIR + if not dump_dir: + return + _dbg = os.getenv("ATOM_MOE_APPLY_DUMP_DEBUG") == "1" + # ATOM_MOE_APPLY_DUMP_ALL=1 bypasses the dummy/warmup skip — keep only the + # LAST write per (layer,kind) so the final real forward wins over warmups. + _skip_dummy = os.getenv("ATOM_MOE_APPLY_DUMP_ALL") != "1" + if _skip_dummy and _is_dummy_run(): + if _dbg: + print(f"[MOE_APPLY_DBG] skip dummy_run layer={layer_name}", flush=True) + return + if not isinstance(x, torch.Tensor): + return + kind = "decode" if x.shape[0] == 1 else "prefill" + + wanted = _parse_layer_set(envs.ATOM_MOE_APPLY_DUMP_LAYERS) + if wanted is not None: + import re + + m = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", layer_name) + if m is None or int(m.group(1)) not in wanted: + if _dbg: + print( + f"[MOE_APPLY_DBG] skip filter layer={layer_name!r} " + f"match={None if m is None else m.group(1)}", + flush=True, + ) + return + if _dbg: + print( + f"[MOE_APPLY_DBG] WRITING layer={layer_name!r} kind={kind} " + f"x={tuple(x.shape)}", + flush=True, + ) + + key = (layer_name, kind) + if _skip_dummy: + # Normal mode: first real call wins, dedup to avoid file churn. + if key in _MOE_APPLY_DONE: + return + _MOE_APPLY_DONE.add(key) + # ALL mode: overwrite each call so the LAST (real) forward wins. + + os.makedirs(dump_dir, exist_ok=True) + rank = _get_rank() + safe = layer_name.replace("/", "_").replace(".", "_") or "moe" + fname = os.path.join(dump_dir, f"apply_{kind}_{safe}_rank{rank}.pt") + torch.save( + { + "_layer_name": layer_name, + "_kind": kind, + "_rank": rank, + "x": x.detach().cpu(), + "router_logits": router_logits.detach().cpu(), + "output": output.detach().cpu(), + }, + fname, + ) + of = output.detach().float() + logging.getLogger("atom").info( + f"[MOE_APPLY] {kind} {layer_name}: x={tuple(x.shape)} " + f"out mean={of.mean().item():.6e} var={of.var().item():.6e} rank={rank}" + ) + + # === MXFP4 raw expert weight stash (isolation test) ================== _MOE_RAWW_DONE: set[str] = set() diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 229dbfe18..ba31a4c7a 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -201,6 +201,10 @@ # before process_weights_after_loading shuffles/swizzles them. "ATOM_MOE_RAWW_DUMP_DIR": lambda: os.getenv("ATOM_MOE_RAWW_DUMP_DIR", ""), "ATOM_MOE_RAWW_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_RAWW_DUMP_LAYERS", ""), + # MoE apply() boundary dump (x / router_logits / output) for comparing the + # default (ATOM_USE_TRITON_MOE=0) vs triton (=1) paths across two runs. + "ATOM_MOE_APPLY_DUMP_DIR": lambda: os.getenv("ATOM_MOE_APPLY_DUMP_DIR", ""), + "ATOM_MOE_APPLY_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_APPLY_DUMP_LAYERS", ""), # Per-rank weight dump + sys.exit(0) — for byte-equal weight comparison. "ATOM_WEIGHT_DUMP_DIR": lambda: os.getenv("ATOM_WEIGHT_DUMP_DIR", ""), "ATOM_WEIGHT_DUMP_LAYERS": lambda: os.getenv("ATOM_WEIGHT_DUMP_LAYERS", "0"), From c2efefbac650b0ab809268baa1014c53c6626d33 Mon Sep 17 00:00:00 2001 From: Leon Ling Date: Tue, 30 Jun 2026 06:10:01 +0000 Subject: [PATCH 8/8] Cleanup --- atom/model_ops/minimax_m3/sparse_attn.py | 4 +- atom/model_ops/moe.py | 49 +-- atom/models/minimax_m3.py | 31 -- atom/utils/debug_helper/__init__.py | 8 - atom/utils/debug_helper/compare.py | 40 --- atom/utils/debug_helper/dump.py | 390 ---------------------- atom/utils/envs.py | 22 -- curl_minimax.sh | 4 - serve_minimax.sh | 143 -------- tests/test_mxfp4_shared_experts_swiglu.py | 222 ------------ 10 files changed, 8 insertions(+), 905 deletions(-) delete mode 100644 curl_minimax.sh delete mode 100644 serve_minimax.sh delete mode 100644 tests/test_mxfp4_shared_experts_swiglu.py diff --git a/atom/model_ops/minimax_m3/sparse_attn.py b/atom/model_ops/minimax_m3/sparse_attn.py index 61f36cc73..f0938f81c 100644 --- a/atom/model_ops/minimax_m3/sparse_attn.py +++ b/atom/model_ops/minimax_m3/sparse_attn.py @@ -156,9 +156,7 @@ def _sparse_decode_unified_attention( # key_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x] block_size = k_cache_view.shape[3] # Each token is its own length-1 sequence (decode); cu_seqlens_q = 0..num_seqs. - cu_seqlens_q = torch.arange( - num_seqs + 1, dtype=torch.int32, device=q_view.device - ) + cu_seqlens_q = torch.arange(num_seqs + 1, dtype=torch.int32, device=q_view.device) # Safe upper bound: full block table width * page size (>= every sparse_ctx). max_seqlen_k = int(sparse_bt.shape[1]) * int(block_size) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index e14fc665b..21b38683e 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1002,16 +1002,6 @@ def process_weights_after_loading(self, layer): if layer.w2_bias is not None: layer.w2_bias.data = layer.w2_bias.data.to(torch.float32) - # Env-gated raw (pre-transform) MXFP4 expert weight stash, for the - # default-vs-triton fused_moe isolation test. The default (CK shuffle) - # and triton (_swizzle_mxfp4) paths transform these SAME raw tensors - # differently; saving them here lets the offline test apply both - # transforms and compare kernels on identical weights. No-op unless - # ATOM_MOE_RAWW_DUMP_DIR is set. - from atom.utils.debug_helper import maybe_dump_mxfp4_raw_weights - - maybe_dump_mxfp4_raw_weights(layer) - if self.static_input_scales: layer.w13_input_scale = atom_parameter( layer.w13_input_scale.max().to(torch.float32) @@ -1273,11 +1263,6 @@ def apply( _moe_result = _moe_result + self._apply_shared_experts_dense( layer, x, activation ) - from atom.utils.debug_helper import maybe_dump_moe_apply_io - - maybe_dump_moe_apply_io( - getattr(layer, "layer_name", ""), x, router_logits, _moe_result - ) return _moe_result assert ( @@ -1335,11 +1320,12 @@ def apply( "swiglu_limit": getattr(layer, "swiglu_limit", 0.0), } if self.fused_experts is None: - # Positional kernel args, then the keyword kernel args (incl. - # moe_extra_args: gate_mode / swiglu_limit). Built once so the same - # values feed both the kernel call and the env-gated I/O dump. - fused_moe_pos = (x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids) - fused_moe_kw = dict( + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, expert_mask=layer.expert_mask, activation=activation, quant_type=self.quant_type, @@ -1354,25 +1340,6 @@ def apply( bias2=layer.w2_bias, **moe_extra_args, ) - # Env-gated kernel I/O dump for offline correctness debugging. - from atom.utils.debug_helper import maybe_dump_fused_moe_io - - _moe_dump_args = { - "x": x, - "w1": layer.w13_weight, - "w2": layer.w2_weight, - "topk_weights": topk_weights, - "topk_ids": topk_ids, - **fused_moe_kw, - } - layer_name = getattr(layer, "layer_name", "") - maybe_dump_fused_moe_io(layer_name, _moe_dump_args) - moe_out = fused_moe(*fused_moe_pos, **fused_moe_kw) - maybe_dump_fused_moe_io(layer_name, _moe_dump_args, output=moe_out) - from atom.utils.debug_helper import maybe_dump_moe_apply_io - - maybe_dump_moe_apply_io(layer_name, x, router_logits, moe_out) - return moe_out return self.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1468,9 +1435,7 @@ def _shared_expert_gemm(act, weight, weight_scale): out_dtype=x.dtype, ) else: - intermediate = torch.empty( - (M, half_n), device=x.device, dtype=x.dtype - ) + intermediate = torch.empty((M, half_n), device=x.device, dtype=x.dtype) fused_clamp_act_mul( gate_up, out=intermediate, diff --git a/atom/models/minimax_m3.py b/atom/models/minimax_m3.py index 12ee273ff..72c9fdc86 100644 --- a/atom/models/minimax_m3.py +++ b/atom/models/minimax_m3.py @@ -43,7 +43,6 @@ make_layers, maybe_prefix, ) -from atom.utils.debug_helper import maybe_dump_minimax_m3_layer from atom.utils.decorators import support_torch_compile from torch import nn from transformers import PretrainedConfig @@ -317,26 +316,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.shared_experts is not None: routed_output = routed_output + self.shared_experts(hidden_states) - import os as _os_m3 - - _md = _os_m3.getenv("ATOM_M3_MOEOUT_DIR") - if _md: - _ln = getattr(getattr(self, "experts", None), "layer_name", "") - import re as _re_m3 - - _m = _re_m3.search(r"\.layers\.(\d+)\.", _ln) - if _m and int(_m.group(1)) in (3, 30, 59): - import torch as _t_m3 - import torch.distributed as _dist_m3 - - _os_m3.makedirs(_md, exist_ok=True) - _rk = _dist_m3.get_rank() if _dist_m3.is_initialized() else 0 - _kind = "decode" if hidden_states.shape[0] == 1 else "prefill" - _t_m3.save( - {"out": routed_output.detach().cpu(), "in": hidden_states.detach().cpu()}, - _os_m3.path.join(_md, f"moeout_l{_m.group(1)}_{_kind}_rank{_rk}.pt"), - ) - return routed_output.view(orig_shape) @@ -609,10 +588,6 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) - # Debug dump bookkeeping (env-gated; see maybe_dump_minimax_m3_layer). - self.layer_num = layer_num - self._last_layer_idx = config.num_hidden_layers - 1 - def forward( self, positions: torch.Tensor, @@ -636,17 +611,11 @@ def forward( aux_out.append(residual.clone()) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) - # maybe_dump_minimax_m3_layer( - # hidden_states, self.layer_num, "attn", self._last_layer_idx - # ) hidden_states, residual = fused_allreduce_gemma_rms_norm( hidden_states, residual, self.post_attention_layernorm ) ffn = self.block_sparse_moe if self.is_moe_layer else self.mlp hidden_states = ffn(hidden_states) - # maybe_dump_minimax_m3_layer( - # hidden_states, self.layer_num, "moe", self._last_layer_idx - # ) return hidden_states, residual diff --git a/atom/utils/debug_helper/__init__.py b/atom/utils/debug_helper/__init__.py index 5b37894cc..d23de5e14 100644 --- a/atom/utils/debug_helper/__init__.py +++ b/atom/utils/debug_helper/__init__.py @@ -31,10 +31,6 @@ ) from atom.utils.debug_helper.dump import ( install_block_forward_hooks, - maybe_dump_fused_moe_io, - maybe_dump_minimax_m3_layer, - maybe_dump_moe_apply_io, - maybe_dump_mxfp4_raw_weights, maybe_dump_weights_and_exit, maybe_log_topk, ) @@ -47,10 +43,6 @@ __all__ = [ # dump "install_block_forward_hooks", - "maybe_dump_fused_moe_io", - "maybe_dump_minimax_m3_layer", - "maybe_dump_moe_apply_io", - "maybe_dump_mxfp4_raw_weights", "maybe_dump_weights_and_exit", "maybe_log_topk", # compare primitives diff --git a/atom/utils/debug_helper/compare.py b/atom/utils/debug_helper/compare.py index a00b4bef6..086b04a0a 100644 --- a/atom/utils/debug_helper/compare.py +++ b/atom/utils/debug_helper/compare.py @@ -371,39 +371,6 @@ def cmd_layer_bisect(args: argparse.Namespace) -> int: return rc -def cmd_moe_apply(args: argparse.Namespace) -> int: - """Compare MoE apply() boundary dumps from two runs (default vs triton). - - Each dir holds apply_{kind}_{layer}_rank{R}.pt with keys x / router_logits / - output (written by maybe_dump_moe_apply_io). Matches files by basename across - --a (e.g. default / correct) and --b (e.g. triton / broken) and prints, per - file, cos for x, router_logits, output. The first stage where cos drops tells - you the divergence: - - router_logits drops -> routing bug - - x matches, output drops -> kernel / weight-layout / scale bug - """ - a_files = {os.path.basename(p): p for p in glob.glob(os.path.join(args.a, "*.pt"))} - b_files = {os.path.basename(p): p for p in glob.glob(os.path.join(args.b, "*.pt"))} - common = sorted(set(a_files) & set(b_files)) - if not common: - print(f"no matching files between {args.a} and {args.b}") - return 1 - print(f"{'file':45s} {'stage':14s} {'cos':>9s} {'max_abs':>10s} {'rel':>9s}") - for name in common: - a = torch.load(a_files[name], map_location="cpu", weights_only=False) - b = torch.load(b_files[name], map_location="cpu", weights_only=False) - for stage in ("x", "router_logits", "output"): - ta, tb = a.get(stage), b.get(stage) - if ta is None or tb is None: - continue - cos, ma, rel = cos_max(ta, tb) - print( - f"{name[:45]:45s} {stage:14s} {flag_for(cos, rel)} " - f"{cos:9.6f} {ma:10.3e} {rel:9.3e}" - ) - return 0 - - def cmd_schema(args: argparse.Namespace) -> int: """Show schema diff between two dump files.""" a = torch.load(args.a, map_location="cpu", weights_only=False) @@ -438,13 +405,6 @@ def main(argv: Optional[list[str]] = None) -> int: lb.add_argument("--threshold", type=float, default=COS_NUM_DRIFT) lb.set_defaults(func=cmd_layer_bisect) - ma = sub.add_parser( - "moe-apply", help="Compare MoE apply() dumps (default dir vs triton dir)" - ) - ma.add_argument("--a", required=True, help="dir A (e.g. default / correct run)") - ma.add_argument("--b", required=True, help="dir B (e.g. triton / broken run)") - ma.set_defaults(func=cmd_moe_apply) - sc = sub.add_parser("schema", help="Schema diff between two dump files") sc.add_argument("--a", required=True) sc.add_argument("--b", required=True) diff --git a/atom/utils/debug_helper/dump.py b/atom/utils/debug_helper/dump.py index 2f672724e..0122efbb7 100644 --- a/atom/utils/debug_helper/dump.py +++ b/atom/utils/debug_helper/dump.py @@ -10,8 +10,6 @@ ATOM_FWD_DUMP_LAYER_ATTR / ATOM_FWD_DUMP_ONE_SHOT ATOM_WEIGHT_DUMP_DIR / ATOM_WEIGHT_DUMP_LAYERS / ATOM_WEIGHT_DUMP_EXIT ATOM_DEBUG_TOPK / ATOM_DEBUG_TOPK_PATH - ATOM_M3_DUMP_DIR / ATOM_M3_DUMP_EXIT - ATOM_MINIMAX_DUMP_DIR / ATOM_MINIMAX_DUMP_ABORT Output file naming ------------------ @@ -35,7 +33,6 @@ from __future__ import annotations -import logging import os import sys from typing import Optional @@ -161,393 +158,6 @@ def _find_layer_id(mod_name: str) -> Optional[int]: return n -# === MiniMax-M3 per-layer attn/moe dump (call-site helper) =========== - -# Tracks which (layer, stage) pairs are already saved per step kind, so each is -# dumped exactly once for the first prefill step and once for the first decode step. -_M3_DONE: dict[str, set[tuple[int, str]]] = {"prefill": set(), "decode": set()} - - -def _is_dummy_run() -> bool: - """True if the current forward is a warmup / CUDAGraph-capture dummy run. - - Reads is_dummy_run off the forward context's Context. Returns False when no - context is available (e.g. unit tests calling the dump helper directly), so - standalone usage still dumps. - """ - try: - from atom.utils.forward_context import get_forward_context - - ctx = get_forward_context().context - return bool(getattr(ctx, "is_dummy_run", False)) - except Exception: - return False - - -def maybe_dump_minimax_m3_layer( - hidden_states: torch.Tensor, - layer_idx: int, - stage: str, - last_layer_idx: int, -) -> None: - """Print mean/var and dump a MiniMax-M3 layer hidden state. Env-gated no-op. - - No-op unless ATOM_M3_DUMP_DIR is set, so it is safe to leave wired into the - model forward in production (zero overhead on the unset path). Intended to be - called from MiniMaxM3DecoderLayer.forward right after the attention block - output (stage="attn") and after the MoE/MLP block output (stage="moe"). - - Warmup / CUDAGraph-capture dummy forwards are skipped (gated on the forward - context's is_dummy_run flag), so only real request data is dumped. - - Captures the first prefill step and the first decode step only. The step kind - is inferred from the token count: decode == 1 token (per the num_tokens==1 - rule), otherwise prefill. Each (layer, stage) is saved once per kind, so later - prefill/decode steps are ignored. - - Files: {ATOM_M3_DUMP_DIR}/{kind}_layer{LL}_{stage}_rank{R}.pt - - When ATOM_M3_DUMP_EXIT is set (default), the process exits right after the last - layer's MoE output of the first decode step. - - Requires eager execution (--enforce-eager / --level 0): with compilation on, - the model forward is traced by Dynamo and the .item()/torch.save calls here - are not traceable. Leaving the env unset keeps this a pure no-op, so compiled - runs are unaffected. - """ - dump_dir = envs.ATOM_M3_DUMP_DIR - if not dump_dir: - return - if not isinstance(hidden_states, torch.Tensor): - return - - # Skip warmup / capture dummy forwards — only dump real request data. The - # warmup pass sets is_dummy_run=True on the forward context (see - # model_runner.warmup_model). Absent a context (e.g. unit tests), treat as - # real so standalone calls still dump. - if _is_dummy_run(): - return - - kind = "decode" if hidden_states.shape[0] == 1 else "prefill" - key = (layer_idx, stage) - if key in _M3_DONE[kind]: - return - _M3_DONE[kind].add(key) - - os.makedirs(dump_dir, exist_ok=True) - rank = _get_rank() - logger = logging.getLogger("atom") - - tf = hidden_states.detach().float() - mean = tf.mean().item() - var = tf.var().item() - logger.info( - f"[M3_DUMP] {kind} layer{layer_idx:03d} {stage}: " - f"mean={mean:.6e} var={var:.6e} shape={tuple(hidden_states.shape)} " - f"rank={rank}" - ) - fname = os.path.join( - dump_dir, f"{kind}_layer{layer_idx:03d}_{stage}_rank{rank}.pt" - ) - torch.save( - { - "hidden": hidden_states.detach().cpu(), - "shape": tuple(hidden_states.shape), - "layer": layer_idx, - "stage": stage, - "kind": kind, - "mean": mean, - "var": var, - }, - fname, - ) - - # Exit after the last layer's MoE output of the first decode step. - if ( - envs.ATOM_M3_DUMP_EXIT - and kind == "decode" - and stage == "moe" - and layer_idx == last_layer_idx - ): - logger.info( - "[M3_DUMP] captured first decode step through last layer " - f"{layer_idx}; exiting." - ) - import torch.distributed as dist - - if dist.is_initialized(): - dist.barrier() - dist.destroy_process_group() - sys.exit(0) - - -# === fused_moe kernel I/O dump ======================================= - -# Tracks which (layer, kind) pairs already dumped, so each layer is captured once -# per first prefill step and once per first decode step. -_MOE_DONE: set[tuple[str, str]] = set() - - -def _to_cpu(obj): - """Detach + move tensors to CPU recursively; pass scalars/enums through.""" - if isinstance(obj, torch.Tensor): - return obj.detach().cpu() - if isinstance(obj, (list, tuple)): - return type(obj)(_to_cpu(o) for o in obj) - if isinstance(obj, dict): - return {k: _to_cpu(v) for k, v in obj.items()} - return obj - - -def maybe_dump_fused_moe_io( - layer_name: str, - kernel_inputs: dict, - output: Optional[torch.Tensor] = None, -) -> None: - """Dump fused_moe kernel inputs (and optionally output). Env-gated no-op. - - No-op unless ATOM_MOE_DUMP_DIR is set. Skips warmup / CUDAGraph-capture dummy - forwards (is_dummy_run). Captures the first prefill step and the first decode - step per layer (kind inferred from the activation token count: 1 == decode, - else prefill). Each (layer, kind) is saved once. - - Intended to bracket the fused_moe(...) call in Mxfp4MoEMethod.apply: call once - before the kernel with the full kwargs dict (output=None) to record inputs, - then again after with the same layer_name + the realized output tensor to - append the output to the same file. - - kernel_inputs should map arg name -> value (tensors, scalars, enums). Tensors - are detached + moved to CPU; non-tensors are stored verbatim as metadata, so - the saved file is self-contained for offline kernel replay. - - Files: {ATOM_MOE_DUMP_DIR}/moe_{kind}_{safe_layer_name}_rank{R}.pt - Requires eager mode (--enforce-eager): the .cpu()/torch.save calls are not - Dynamo-traceable; the env-unset no-op path keeps compiled runs unaffected. - """ - dump_dir = envs.ATOM_MOE_DUMP_DIR - if not dump_dir: - return - if _is_dummy_run(): - return - - # Determine token count from the activation input ("x"/"hidden_states"). - x = kernel_inputs.get("x") - if x is None: - x = kernel_inputs.get("hidden_states") - if not isinstance(x, torch.Tensor): - return - kind = "decode" if x.shape[0] == 1 else "prefill" - - # Optional per-layer-index filter (parse "model.layers.." from the name). - wanted = _parse_layer_set(envs.ATOM_MOE_DUMP_LAYERS) - if wanted is not None: - import re - - m = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", layer_name) - if m is None or int(m.group(1)) not in wanted: - return - - key = (layer_name, kind) - is_input_call = output is None - if is_input_call and key in _MOE_DONE: - return # already captured this (layer, kind) - - os.makedirs(dump_dir, exist_ok=True) - rank = _get_rank() - logger = logging.getLogger("atom") - safe = layer_name.replace("/", "_").replace(".", "_") or "moe" - fname = os.path.join(dump_dir, f"moe_{kind}_{safe}_rank{rank}.pt") - - if is_input_call: - _MOE_DONE.add(key) - pkt = { - "_layer_name": layer_name, - "_kind": kind, - "_rank": rank, - "inputs": {k: _to_cpu(v) for k, v in kernel_inputs.items()}, - } - torch.save(pkt, fname) - tensor_args = [ - k for k, v in kernel_inputs.items() if isinstance(v, torch.Tensor) - ] - logger.info( - f"[MOE_DUMP] {kind} {layer_name}: saved inputs " - f"({len(tensor_args)} tensors) shape_x={tuple(x.shape)} rank={rank}" - ) - return - - # Output call: append output to the existing input file (must follow input). - if not isinstance(output, torch.Tensor): - return - pkt = {} - if os.path.exists(fname): - # weights_only=False: the input pkt stores enums (ActivationType, - # QuantType) that are not in torch's safe-globals allowlist. - pkt = torch.load(fname, weights_only=False) - pkt["output"] = output.detach().cpu() - pkt["_output_shape"] = tuple(output.shape) - torch.save(pkt, fname) - of = output.detach().float() - logger.info( - f"[MOE_DUMP] {kind} {layer_name}: saved output " - f"mean={of.mean().item():.6e} var={of.var().item():.6e} " - f"shape={tuple(output.shape)} rank={rank}" - ) - - -# === MoE apply() boundary dump (default-vs-triton compare) =========== - -# (layer, kind) already dumped — capture first prefill + first decode per layer. -_MOE_APPLY_DONE: set[tuple[str, str]] = set() - - -def maybe_dump_moe_apply_io( - layer_name: str, - x: torch.Tensor, - router_logits: torch.Tensor, - output: torch.Tensor, -) -> None: - """Dump the MoE apply() boundary: x, router_logits, output. Env-gated no-op. - - No-op unless ATOM_MOE_APPLY_DUMP_DIR is set. These three tensors exist on - BOTH the aiter default and triton paths, so dumping from one run with - ATOM_USE_TRITON_MOE=0 (correct) and another with =1 (broken) lets the compare - CLI diff them per layer to localize divergence: - - x identical, router_logits identical, output differs -> kernel/layout bug - - router_logits differ -> routing bug - - Skips warmup/dummy forwards. Captures first prefill + first decode per layer. - Optional layer filter via ATOM_MOE_APPLY_DUMP_LAYERS (comma-separated). - File: {dir}/apply_{kind}_{safe_layer}_rank{R}.pt - Requires eager mode. - """ - dump_dir = envs.ATOM_MOE_APPLY_DUMP_DIR - if not dump_dir: - return - _dbg = os.getenv("ATOM_MOE_APPLY_DUMP_DEBUG") == "1" - # ATOM_MOE_APPLY_DUMP_ALL=1 bypasses the dummy/warmup skip — keep only the - # LAST write per (layer,kind) so the final real forward wins over warmups. - _skip_dummy = os.getenv("ATOM_MOE_APPLY_DUMP_ALL") != "1" - if _skip_dummy and _is_dummy_run(): - if _dbg: - print(f"[MOE_APPLY_DBG] skip dummy_run layer={layer_name}", flush=True) - return - if not isinstance(x, torch.Tensor): - return - kind = "decode" if x.shape[0] == 1 else "prefill" - - wanted = _parse_layer_set(envs.ATOM_MOE_APPLY_DUMP_LAYERS) - if wanted is not None: - import re - - m = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", layer_name) - if m is None or int(m.group(1)) not in wanted: - if _dbg: - print( - f"[MOE_APPLY_DBG] skip filter layer={layer_name!r} " - f"match={None if m is None else m.group(1)}", - flush=True, - ) - return - if _dbg: - print( - f"[MOE_APPLY_DBG] WRITING layer={layer_name!r} kind={kind} " - f"x={tuple(x.shape)}", - flush=True, - ) - - key = (layer_name, kind) - if _skip_dummy: - # Normal mode: first real call wins, dedup to avoid file churn. - if key in _MOE_APPLY_DONE: - return - _MOE_APPLY_DONE.add(key) - # ALL mode: overwrite each call so the LAST (real) forward wins. - - os.makedirs(dump_dir, exist_ok=True) - rank = _get_rank() - safe = layer_name.replace("/", "_").replace(".", "_") or "moe" - fname = os.path.join(dump_dir, f"apply_{kind}_{safe}_rank{rank}.pt") - torch.save( - { - "_layer_name": layer_name, - "_kind": kind, - "_rank": rank, - "x": x.detach().cpu(), - "router_logits": router_logits.detach().cpu(), - "output": output.detach().cpu(), - }, - fname, - ) - of = output.detach().float() - logging.getLogger("atom").info( - f"[MOE_APPLY] {kind} {layer_name}: x={tuple(x.shape)} " - f"out mean={of.mean().item():.6e} var={of.var().item():.6e} rank={rank}" - ) - - -# === MXFP4 raw expert weight stash (isolation test) ================== - -_MOE_RAWW_DONE: set[str] = set() - - -def maybe_dump_mxfp4_raw_weights(layer: "torch.nn.Module") -> None: - """Stash raw (pre-transform) MXFP4 expert weights for one MoE layer. - - No-op unless ATOM_MOE_RAWW_DUMP_DIR is set. Called at the top of - Mxfp4MoEMethod.process_weights_after_loading, BEFORE the CK shuffle / triton - swizzle, so the saved tensors are the common raw input both paths consume. - The offline isolation test applies both transforms to these and compares - kernels. Saves once per layer_name. - - File: {ATOM_MOE_RAWW_DUMP_DIR}/raww_{safe_layer_name}_rank{R}.pt - keys: w13_weight, w13_weight_scale, w2_weight, w2_weight_scale, - w13_bias, w2_bias, w13_input_scale, w2_input_scale, layer_name - """ - dump_dir = envs.ATOM_MOE_RAWW_DUMP_DIR - if not dump_dir: - return - layer_name = getattr(layer, "layer_name", "") or getattr(layer, "prefix", "") - - wanted = _parse_layer_set(envs.ATOM_MOE_RAWW_DUMP_LAYERS) - if wanted is not None: - import re - - m = re.search(r"(?:^|\.)layers\.(\d+)(?:\.|$)", layer_name) - if m is None or int(m.group(1)) not in wanted: - return - if layer_name in _MOE_RAWW_DONE: - return - _MOE_RAWW_DONE.add(layer_name) - - os.makedirs(dump_dir, exist_ok=True) - rank = _get_rank() - - def _g(name): - v = getattr(layer, name, None) - if isinstance(v, torch.Tensor): - return v.detach().cpu().clone() - return v - - pkt = { - "layer_name": layer_name, - "_rank": rank, - "w13_weight": _g("w13_weight"), - "w13_weight_scale": _g("w13_weight_scale"), - "w2_weight": _g("w2_weight"), - "w2_weight_scale": _g("w2_weight_scale"), - "w13_bias": _g("w13_bias"), - "w2_bias": _g("w2_bias"), - "w13_input_scale": _g("w13_input_scale"), - "w2_input_scale": _g("w2_input_scale"), - } - safe = layer_name.replace("/", "_").replace(".", "_") or "moe" - torch.save(pkt, os.path.join(dump_dir, f"raww_{safe}_rank{rank}.pt")) - logging.getLogger("atom").info( - f"[MOE_RAWW] saved raw MXFP4 weights for {layer_name} rank={rank}" - ) - - # === Weight dump ===================================================== diff --git a/atom/utils/envs.py b/atom/utils/envs.py index ba31a4c7a..3bd2a53ab 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -183,28 +183,6 @@ "ATOM_FWD_DUMP_LAYER_ATTR", "layer_id" ), "ATOM_FWD_DUMP_ONE_SHOT": lambda: os.getenv("ATOM_FWD_DUMP_ONE_SHOT", "1") == "1", - # MiniMax-M3 per-layer attn/moe hidden dump + mean/var print. Captures one - # prefill step and one decode step (decode = num_tokens == 1), then exits - # after the last layer's MoE of the first decode step. Requires eager mode - # (--enforce-eager) so the submodule forward hooks fire during decode. - "ATOM_M3_DUMP_DIR": lambda: os.getenv("ATOM_M3_DUMP_DIR", ""), - "ATOM_M3_DUMP_EXIT": lambda: os.getenv("ATOM_M3_DUMP_EXIT", "1") == "1", - # fused_moe kernel I/O dump — saves all kernel tensor args (activations, - # expert weights, scales, routing, biases) + the kernel output for one - # prefill and one decode step per layer, for offline kernel-correctness - # replay. Requires eager mode. ATOM_MOE_DUMP_LAYERS optionally restricts to a - # comma-separated set of layer indices (default: all). - "ATOM_MOE_DUMP_DIR": lambda: os.getenv("ATOM_MOE_DUMP_DIR", ""), - "ATOM_MOE_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_DUMP_LAYERS", ""), - # Raw (pre-transform) MXFP4 expert weight stash for the default-vs-triton - # fused_moe isolation test. Saves w13/w2 weight+scale+bias per MoE layer - # before process_weights_after_loading shuffles/swizzles them. - "ATOM_MOE_RAWW_DUMP_DIR": lambda: os.getenv("ATOM_MOE_RAWW_DUMP_DIR", ""), - "ATOM_MOE_RAWW_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_RAWW_DUMP_LAYERS", ""), - # MoE apply() boundary dump (x / router_logits / output) for comparing the - # default (ATOM_USE_TRITON_MOE=0) vs triton (=1) paths across two runs. - "ATOM_MOE_APPLY_DUMP_DIR": lambda: os.getenv("ATOM_MOE_APPLY_DUMP_DIR", ""), - "ATOM_MOE_APPLY_DUMP_LAYERS": lambda: os.getenv("ATOM_MOE_APPLY_DUMP_LAYERS", ""), # Per-rank weight dump + sys.exit(0) — for byte-equal weight comparison. "ATOM_WEIGHT_DUMP_DIR": lambda: os.getenv("ATOM_WEIGHT_DUMP_DIR", ""), "ATOM_WEIGHT_DUMP_LAYERS": lambda: os.getenv("ATOM_WEIGHT_DUMP_LAYERS", "0"), diff --git a/curl_minimax.sh b/curl_minimax.sh deleted file mode 100644 index c6f0f4345..000000000 --- a/curl_minimax.sh +++ /dev/null @@ -1,4 +0,0 @@ -curl -X POST "http://localhost:8014/v1/completions" \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "The capital of China is", "temperature": 0, "top_p": 1, "top_k": 1, "repetition_penalty": 1.0, "presence_penalty": 0, "frequency_penalty": 0, "stream": false, "ignore_eos": false, "n": 1, "seed": 123, "max_tokens": 20}' \ No newline at end of file diff --git a/serve_minimax.sh b/serve_minimax.sh deleted file mode 100644 index 536431f5c..000000000 --- a/serve_minimax.sh +++ /dev/null @@ -1,143 +0,0 @@ -# export AITER_QUICK_REDUCE_QUANTIZATION=INT4 -# export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1 -# export ATOM_USE_GLUON_PA_DECODE=1 -# export HIP_VISIBLE_DEVICES=6,7 - -# vllm serve /workspace/shared/data/amd_int/models/MiniMax-M2.5 \ -# --host localhost \ -# --port 8100 \ -# --async-scheduling \ -# --load-format fastsafetensors \ -# --tensor-parallel-size 2 \ -# --trust-remote-code \ -# --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \ -# --kv-cache-dtype fp8 \ -# --max-num-batched-tokens 16384 \ -# --max-model-len 16384 \ -# --gpu-memory-utilization 0.9 \ -# --no-enable-prefix-caching \ -# --profiler-config '{"profiler": "torch", "torch_profiler_dir": "./vllm_profile", "torch_profiler_with_stack": "False"}' \ - # --enforce-eager - -# model_path=/workspace/shared/data/amd_int/models/MiniMax-M3-MXFP4/ - -# model_path=/workspace/shared/data/amd_int/models/MiniMax-M3 -export HIP_VISIBLE_DEVICES=0,1,2,3 -export ATOM_FORCE_ATTN_TRITON=1 -export ATOM_M3_DUMP_DIR="./minimax_dump/" -# export HSA_ENABLE_SDMA=1 -# export HSA_USE_SVM=1 -# export HSA_XNACK=1 -# export ATOM_USE_TRITON_MOE=1 -# export ATOM_USE_TRITON_GEMM=1 -# export ENABLE_CK=0 -# export AITER_USE_OPUS_MOE_SORTING=1 -# export ATOM_USE_UNIFIED_ATTN=0 - - -# python -m atom.entrypoints.openai_server \ -# --model $model_path \ -# -tp 4 --server-port 8013 --trust-remote-code --gpu-memory-utilization 0.7 \ -# --block-size 128 \ -# --no-enable_prefix_caching \ - # --torch-profiler-dir ./trace --mark-trace -main_model=/workspace/shared/data/amd_int/models/MiniMax-M3-MXFP4 -# draft_model=/workspace/shared/data/amd_int/models/MiniMax-M3-EAGLE3 -# export HIP_VISIBLE_DEVICES=0,1,2,3 -python -m atom.entrypoints.openai_server --model $main_model \ - -tp 4 --server-port 8014 --trust-remote-code --gpu-memory-utilization 0.8 --block-size 128 --no-enable_prefix_caching \ - --max-num-batched-tokens 32768 --max-model-len 32768 --max-num-seqs 128 --enforce-eager --level 0 - # --kv_cache_dtype fp8 \ - # --torch-profiler-dir ./trace - # --method eagle3 --draft-model $draft_model --num-speculative-tokens 3 \ - - # --enforce-eager - - -#!/usr/bin/env bash -# set -euo pipefail - -# rm -f "${VLLM_CACHE_ROOT:-$HOME/.cache/vllm}"/modelinfos/*minimax_m3* 2>/dev/null || true -# ok -# SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -# LOG_FILE="$SCRIPT_DIR/minimax-m3-preview-2.log" - -# unset VL MM_BATCHED MM_ENCODER_ATTN - -# export HIP_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" -# # export HIP_VISIBLE_DEVICES="4,5,6,7" -# export ATOM_USE_TRITON_MOE="${ATOM_USE_TRITON_MOE:-1}" - -# MODEL="/shared/data/amd_int/models/MiniMax-M3" -# SERVED_NAME="${SERVED_NAME:-MiniMax-M3}" -# PORT="${PORT:-8000}" -# TP="${TP:-8}" -# GPU_MEM_UTIL="${GPU_MEM_UTIL:-0.90}" -# # GPU_MEM_UTIL="${GPU_MEM_UTIL:-0.8}" -# MAX_LEN="${MAX_LEN:-16384}" -# MAX_BATCHED_TOKENS="${MAX_BATCHED_TOKENS:-16384}" -# # MAX_SEQS="${MAX_SEQS:-8}" -# ATTN="TRITON_ATTN" -# LOAD_FORMAT="${LOAD_FORMAT:-auto}" -# SKIP_TOKENIZER_INIT="${SKIP_TOKENIZER_INIT:-0}" -# ENFORCE_EAGER="${ENFORCE_EAGER:-0}" -# EXTRA_ARGS=() -# if [[ "$SKIP_TOKENIZER_INIT" == "1" ]]; then -# EXTRA_ARGS+=(--skip-tokenizer-init) -# fi -# if [[ "$ENFORCE_EAGER" == "1" ]]; then -# EXTRA_ARGS+=(--enforce-eager) -# fi - -# echo "### serve: model=$MODEL" -# echo "### serve: devices=$HIP_VISIBLE_DEVICES tp=$TP port=$PORT max_len=$MAX_LEN" -# echo "### serve: log=$LOG_FILE" -# export VLLM_CUSTOM_SCOPES_FOR_PROFILING=1 -# export VLLM_BATCH_INVARIANT="${VLLM_BATCH_INVARIANT:-1}" -# vllm serve "$MODEL" \ -# --dtype bfloat16 \ -# --load-format "$LOAD_FORMAT" \ -# --host localhost \ -# --port "$PORT" \ -# --tensor-parallel-size "$TP" \ -# --gpu-memory-utilization "$GPU_MEM_UTIL" \ -# --max-model-len "$MAX_LEN" \ -# --max-num-batched-tokens "$MAX_BATCHED_TOKENS" \ -# --block-size 128 \ -# --no-enable-prefix-caching \ -# --language-model-only \ -# --no-trust-remote-code \ -# --compilation-config '{"cudagraph_mode": "FULL_DECODE_ONLY"}' \ -# "${EXTRA_ARGS[@]}" \ -# 2>&1 | tee "$LOG_FILE" - -# export HIP_VISIBLE_DEVICES=4,5,6,7 -# AITER_QUICK_REDUCE_QUANTIZATION=INT4 \ -# HSA_ENABLE_SDMA=1 \ -# HSA_USE_SVM=1 \ -# HSA_XNACK=1 \ -# AITER_DISABLE_KERNARG_PRELOAD=1 \ -# ATOM_USE_TRITON_MOE=1 \ -# ATOM_FORCE_ATTN_TRITON=1 \ -# ATOM_LOADER_USE_THREADPOOL=0 \ -# AITER_ROPE_TRITON_BACKEND=1 \ -# ATOM_ENABLE_DS_QKNORM_FUSION=0 \ -# ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION=0 \ -# ATOM_ENABLE_DS_QKNORM_QUANT_FUSION=0 \ -# ATOM_USE_TRITON_GEMM=1 \ -# ENABLE_DS_QKNORM_FUSION=0 \ -# ENABLE_CK=0 \ -# AITER_USE_OPUS_MOE_SORTING=1 \ -# ATOM_USE_UNIFIED_ATTN=0 \ -# python -m atom.entrypoints.openai_server \ -# --model /workspace/shared/data/amd_int/models/MiniMax-M3-MXFP4 \ -# --server-port 8013 \ -# --trust-remote-code \ -# -tp 4 \ -# --gpu-memory-utilization 0.8 \ -# --block-size 128 \ -# --max-model-len 32768 \ -# --max-num-seqs 128 \ -# --max-num-batched-tokens 32768 \ -# --torch-profiler-dir /app/trace \ -# --no-enable_prefix_caching \ No newline at end of file diff --git a/tests/test_mxfp4_shared_experts_swiglu.py b/tests/test_mxfp4_shared_experts_swiglu.py deleted file mode 100644 index dd8c8f592..000000000 --- a/tests/test_mxfp4_shared_experts_swiglu.py +++ /dev/null @@ -1,222 +0,0 @@ -# SPDX-License-Identifier: MIT -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - -"""Numerical test for the dense shared-expert path with SwiGLU-OAI activation. - -Regression: ``Mxfp4MoEMethod._apply_shared_experts_dense`` hard-asserted the -SiLU activation path, so MiniMax-M3 (``ActivationType.Swiglu`` *with* fused -shared experts) crashed with:: - - AssertionError: dense shared-expert GEMM only supports the SiLU activation path - -MiniMax-M3 does not interleave gate/up weights, so the dense GEMM output is -split ``[gate | up]`` — exactly what ``swiglu_oai_split`` consumes. The dense -shared expert must therefore replicate ``MiniMaxM3MLP.forward``: -gate_up GEMM -> swiglu_oai_split -> down GEMM, with the SwiGLU-OAI math -``gate * sigmoid(alpha*gate) * (up + beta)`` (alpha=1.702, beta=1.0), not SiLU. - -These tests run the *real* fixed code path (gemm_a16wfp4 + the activation -branch). The fix only changes the *activation*, so the reference reuses the -*same* ``gemm_a16wfp4`` for both matmuls and differs only in the activation -math (computed independently in plain torch). This isolates the activation and -avoids conflating it with the kernel's mxfp4/bf16 GEMM precision. The tests -prove: - * the SwiGLU branch matches the SwiGLU-OAI reference, and - * it is genuinely different from the SiLU reference (i.e. the fix changed - behaviour, it is not silently equivalent), and - * the SiLU branch (DeepSeek) is unchanged. -""" - -import types - -import pytest -import torch - -cuda_only = pytest.mark.skipif( - not torch.cuda.is_available(), reason="requires an AMD GPU" -) - -SCALE_GROUP_SIZE = 32 - -# e2m1 (fp4) decode table: sign | 2-bit exp | 1-bit mantissa. -_MXFP4_TABLE = [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, -] - - -def _mxfp4_to_f32(packed: torch.Tensor) -> torch.Tensor: - """Decode a uint8 tensor of two packed e2m1 nibbles to f32 (last dim x2).""" - x = packed.repeat_interleave(2, dim=-1) - x[..., ::2] = x[..., ::2] & 0xF - x[..., 1::2] = x[..., 1::2] >> 4 - table = torch.tensor(_MXFP4_TABLE, dtype=torch.float32, device=packed.device) - return table[x.long()] - - -def _e8m0_to_f32(scale: torch.Tensor) -> torch.Tensor: - return 2.0 ** (scale.to(torch.float32) - 127) - - -def _make_weight(n: int, k: int, *, seed: int): - """Random fp4-packed weight (n, k//2) + e8m0 scales (n, k//32). - - The bf16 dequantization is not returned: the reference reuses the kernel's - own GEMM, so we never need a from-scratch dequant matmul. - """ - g = torch.Generator(device="cuda").manual_seed(seed) - low = torch.randint(0, 16, (n, k // 2), dtype=torch.uint8, device="cuda", generator=g) - high = torch.randint(0, 16, (n, k // 2), dtype=torch.uint8, device="cuda", generator=g) - packed = low | (high << 4) - # e8m0 scales near 1.0 (bias 127) keep the dequant range sane. - scales = torch.randint( - 125, 130, (n, k // SCALE_GROUP_SIZE), dtype=torch.uint8, device="cuda", generator=g - ) - return packed, scales - - -def _ref_swiglu_oai(gate_up, alpha, beta, limit): - n = gate_up.shape[-1] // 2 - gate = gate_up[:, :n].to(torch.float32) - up = gate_up[:, n:].to(torch.float32) - if limit is not None: - gate = torch.clamp(gate, max=limit) - up = torch.clamp(up, min=-limit, max=limit) - return (gate * torch.sigmoid(alpha * gate) * (up + beta)).to(gate_up.dtype) - - -def _ref_silu(gate_up, limit): - n = gate_up.shape[-1] // 2 - gate = gate_up[:, :n].to(torch.float32) - up = gate_up[:, n:].to(torch.float32) - if limit > 0: - gate = torch.clamp(gate, max=limit) - up = torch.clamp(up, min=-limit, max=limit) - return (gate * torch.sigmoid(gate) * up).to(gate_up.dtype) - - -def _build_method_and_layer(hidden, inter, *, alpha, beta, limit): - from atom.config import LayerQuantConfig - from atom.model_ops.moe import Mxfp4MoEMethod, MoEActivationQuant - from aiter import QuantType - from unittest.mock import MagicMock - - qc = LayerQuantConfig( - quant_type=QuantType.per_1x32, - quant_dtype=torch.float4_e2m1fn_x2, - quant_method="quark", - ) - method = Mxfp4MoEMethod(qc, MagicMock()) - # Exercise the a16w4 (bf16 activation) path so we feed plain bf16 inputs. - method.act_quant = MoEActivationQuant.BF16 - - w13, s13 = _make_weight(2 * inter, hidden, seed=1) # (2I, H) - w2, s2 = _make_weight(hidden, inter, seed=2) # (H, I) - - layer = types.SimpleNamespace( - num_fused_shared_experts=1, - shared_w13_weight=w13.unsqueeze(0), - shared_w13_weight_scale=s13.unsqueeze(0), - shared_w2_weight=w2.unsqueeze(0), - shared_w2_weight_scale=s2.unsqueeze(0), - shared_w13_bias=None, - shared_w2_bias=None, - swiglu_limit=limit, - swiglu_alpha=alpha, - swiglu_beta=beta, - ) - return method, layer, (w13, s13), (w2, s2) - - -def _kernel_gemm(act, packed_scale): - """Reuse the exact same GEMM the dense path uses, so the only difference - between the dense path and the reference is the activation.""" - from aiter.ops.triton.gemm.basic.gemm_a16wfp4 import gemm_a16wfp4 - - weight, scale = packed_scale - return gemm_a16wfp4(act, weight, scale, dtype=torch.bfloat16) - - -@cuda_only -def test_swiglu_shared_expert_matches_reference(): - from aiter import ActivationType - from atom.model_ops.moe import Mxfp4MoEMethod - - if not _fp4_available(): - pytest.skip("MXFP4 not supported on this architecture") - - hidden, inter, M = 256, 256, 64 - alpha, beta, limit = 1.702, 1.0, 7.0 - method, layer, w13, w2 = _build_method_and_layer( - hidden, inter, alpha=alpha, beta=beta, limit=limit - ) - - torch.manual_seed(0) - x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") * 0.5 - - out = Mxfp4MoEMethod._apply_shared_experts_dense( - method, layer, x, ActivationType.Swiglu - ) - - # Reference reuses the SAME kernel GEMM; only the activation is computed - # independently (plain torch), isolating the fix from GEMM precision. - gate_up = _kernel_gemm(x, w13) - inter_ref = _ref_swiglu_oai(gate_up, alpha, beta, limit) - out_ref = _kernel_gemm(inter_ref, w2) - - # SiLU on the same gate_up must be clearly different (proves the branch - # matters and the fix is not silently equivalent to the old code). - inter_silu = _ref_silu(gate_up, limit) - out_silu = _kernel_gemm(inter_silu, w2) - - err_swiglu = (out.float() - out_ref.float()).abs().mean().item() - err_vs_silu = (out_ref.float() - out_silu.float()).abs().mean().item() - - torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-2, atol=1e-2) - assert err_vs_silu > 10 * max(err_swiglu, 1e-6), ( - f"swiglu vs silu too close to distinguish " - f"(err_swiglu={err_swiglu}, err_vs_silu={err_vs_silu})" - ) - - -@cuda_only -def test_silu_shared_expert_unchanged(): - from aiter import ActivationType - from atom.model_ops.moe import Mxfp4MoEMethod - - if not _fp4_available(): - pytest.skip("MXFP4 not supported on this architecture") - - hidden, inter, M = 256, 256, 64 - limit = 7.0 - method, layer, w13, w2 = _build_method_and_layer( - hidden, inter, alpha=1.702, beta=1.0, limit=limit - ) - - torch.manual_seed(0) - x = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda") * 0.5 - - out = Mxfp4MoEMethod._apply_shared_experts_dense( - method, layer, x, ActivationType.Silu - ) - - gate_up = _kernel_gemm(x, w13) - inter_ref = _ref_silu(gate_up, limit) - out_ref = _kernel_gemm(inter_ref, w2) - - torch.testing.assert_close(out.float(), out_ref.float(), rtol=1e-2, atol=1e-2) - - -def _fp4_available(): - try: - import aiter.ops.triton.utils._triton.arch_info as arch_info - - return arch_info.is_fp4_avail() - except Exception: - return torch.cuda.is_available() - - -if __name__ == "__main__": - import sys - - sys.exit(pytest.main([__file__, "-v"]))