diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 577d6588e8..768700fdb6 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -5,6 +5,10 @@ from atom.models.glm4_moe import Glm4MoeForCausalLM from atom.models.deepseek_v2 import DeepseekV3ForCausalLM, GlmMoeDsaForCausalLM from atom.models.minimax_m2 import MiniMaxM2ForCausalLM +from atom.models.minimax_m3 import ( + MiniMaxM3SparseForCausalLM, + MiniMaxM3SparseForConditionalGeneration, +) from atom.models.qwen3_5 import ( Qwen3_5MoeForConditionalGenerationTextOnly, Qwen3_5ForConditionalGenerationTextOnly, @@ -22,6 +26,8 @@ "DeepseekV32ForCausalLM": DeepseekV3ForCausalLM, "GlmMoeDsaForCausalLM": GlmMoeDsaForCausalLM, "MiniMaxM2ForCausalLM": MiniMaxM2ForCausalLM, + "MiniMaxM3SparseForCausalLM": MiniMaxM3SparseForCausalLM, + "MiniMaxM3SparseForConditionalGeneration": MiniMaxM3SparseForConditionalGeneration, "Qwen3_5MoeForConditionalGeneration": Qwen3_5MoeForConditionalGenerationTextOnly, "Qwen3_5ForConditionalGeneration": Qwen3_5ForConditionalGenerationTextOnly, } diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 596d350d4c..335e5aeba6 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -1583,6 +1583,17 @@ def _update_decode_page_table( def _should_use_native_dense_mha(self, layer) -> bool: sliding_window_size = getattr(layer, "sliding_window_size", None) + is_minimax_m3 = bool(getattr(layer, "_atom_minimax_m3_dense_mha", False)) + if ( + is_minimax_m3 + and not self.use_mla + and not layer.is_cross_attention + and layer.head_dim == 128 + and layer.qk_head_dim == 128 + and layer.v_head_dim == 128 + and (sliding_window_size is None or sliding_window_size <= -1) + ): + return True return ( not self.use_mla and not layer.is_cross_attention @@ -1748,6 +1759,10 @@ def forward_extend( if self.use_mla: return self._forward_extend_mla(q, k, v, layer, forward_batch) + if bool(getattr(layer, "_atom_minimax_m3_dense_mha", False)): + # M3 dense decode benefits from the native ragged path, but batched + # SGLang prefill is safer through the standard varlen extend path. + return self._forward_extend_mha(q, k, v, layer, forward_batch) if use_native_dense_mha: return self._forward_extend_native_dense_mha(q, layer, forward_batch) else: diff --git a/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py index 040a7dbf92..2e79ecd500 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py +++ b/atom/plugin/sglang/attention_backend/full_attention/radix_attention.py @@ -14,6 +14,7 @@ from atom.model_ops.attention_mla import MLAModules from atom.model_ops.base_attention import BaseAttention +from atom.model_ops.layernorm import GemmaRMSNorm, fused_qk_norm from atom.model_ops.utils import atom_parameter from atom.plugin.prepare import is_plugin_mode, is_sglang from atom.models.utils import maybe_prefix @@ -40,6 +41,8 @@ def __init__( per_layer_sliding_window: Optional[int] = None, rotary_emb: Optional[torch.nn.Module] = None, prefix: Optional[str] = None, + q_norm: Optional[torch.nn.Module] = None, + k_norm: Optional[torch.nn.Module] = None, **kwargs, ): super().__init__( @@ -55,10 +58,17 @@ def __init__( per_layer_sliding_window=per_layer_sliding_window, rotary_emb=rotary_emb, prefix=prefix, + q_norm=q_norm, + k_norm=k_norm, **kwargs, ) self.rotary_emb = rotary_emb + self.q_norm = q_norm + self.k_norm = k_norm + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim if is_sglang(): from sglang.srt.layers.radix_attention import RadixAttention @@ -141,9 +151,24 @@ def forward_impl_plugin_mode( save_kv_cache = kwargs.get("save_kv_cache", True) assert forward_batch is not None, "forward_batch is required for sglang" - # sglang's RadixAttention does not apply rope internally. - # Apply it here when the model passes rotary_emb at construction - # and hasn't already applied rope (e.g. fused qknorm path). + # sglang's RadixAttention does not apply q/k norm or rope internally. + # Apply them here to match ATOM native Attention semantics. + if self.q_norm is not None and self.k_norm is not None: + eps = getattr(self.q_norm, "variance_epsilon", None) or getattr( + self.q_norm, "eps", None + ) + add_unit_offset = isinstance(self.q_norm, GemmaRMSNorm) + query, key = fused_qk_norm( + query.view(-1, self.num_heads, self.head_dim), + key.view(-1, self.num_kv_heads, self.head_dim), + self.q_norm.weight, + self.k_norm.weight, + eps, + add_unit_offset=add_unit_offset, + ) + query = query.view(-1, self.num_heads * self.head_dim) + key = key.view(-1, self.num_kv_heads * self.head_dim) + if self.rotary_emb is not None and positions is not None: query, key = self.rotary_emb(positions, query, key) diff --git a/atom/plugin/sglang/attention_backend/minimax_m3_sparse.py b/atom/plugin/sglang/attention_backend/minimax_m3_sparse.py new file mode 100644 index 0000000000..9fbda6211e --- /dev/null +++ b/atom/plugin/sglang/attention_backend/minimax_m3_sparse.py @@ -0,0 +1,841 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +import triton +import triton.language as tl + +SPARSE_BLOCK_SIZE = 128 + + +@dataclass +class MiniMaxM3SGLangMetadata: + """Per-forward SGLang metadata for MiniMax-M3 sparse attention.""" + + is_decode: bool + seq_lens: torch.Tensor + block_table: torch.Tensor + max_seq_len: int + cu_seqlens_q: torch.Tensor | None = None + cu_seqlens_k: torch.Tensor | None = None + context_lens: torch.Tensor | None = None + max_query_len: int = 1 + + +def validate_minimax_m3_page_size(page_size: int) -> None: + """MiniMax-M3 sparse blocks must line up 1:1 with SGLang KV pages.""" + + if int(page_size) != SPARSE_BLOCK_SIZE: + raise ValueError( + "MiniMax-M3 sparse attention requires SGLang page size 128 " + f"(got {page_size}). Launch SGLang with --page-size 128." + ) + + +def _get_batch_size(forward_batch) -> int: + return int(getattr(forward_batch, "batch_size")) + + +def _slice_i32(tensor: torch.Tensor, batch_size: int) -> torch.Tensor: + return tensor[:batch_size].to(dtype=torch.int32) + + +def _get_query_lens(forward_batch, batch_size: int) -> torch.Tensor: + query_lens = getattr(forward_batch, "extend_seq_lens", None) + if query_lens is None: + query_lens = getattr(forward_batch, "seq_lens") + return _slice_i32(query_lens, batch_size) + + +def _get_prefix_lens( + forward_batch, + batch_size: int, + seq_lens: torch.Tensor, + query_lens: torch.Tensor, +) -> torch.Tensor: + prefix_lens = getattr(forward_batch, "extend_prefix_lens", None) + if prefix_lens is None: + return (seq_lens - query_lens).to(dtype=torch.int32) + return _slice_i32(prefix_lens, batch_size) + + +def _get_page_size(forward_batch) -> int: + return int(getattr(forward_batch.token_to_kv_pool, "page_size", 1)) + + +def _get_layer_id(layer) -> int: + if hasattr(layer, "layer_id"): + return int(layer.layer_id) + return int(layer.layer_num) + + +def _is_fp8_kv_cache_tensor(kv_cache: torch.Tensor) -> bool: + fp8_dtypes = ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2", None), + ) + return kv_cache.dtype in {dtype for dtype in fp8_dtypes if dtype is not None} + + +@triton.heuristics( + { + "BLOCK_SIZE_D": lambda args: triton.next_power_of_2(args["head_dim"]), + "BLOCK_SIZE_H": lambda args: triton.next_power_of_2(args["gqa_group_size"]), + "BLOCK_SIZE_T": lambda args: triton.next_power_of_2(args["max_topk"]), + "BLOCK_SIZE_QH": lambda args: args["BLOCK_SIZE_Q"] + * triton.next_power_of_2(args["gqa_group_size"]), + } +) +@triton.jit +def _sgl_m3_sparse_fwd_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + t_ptr, + o_ptr, + block_table_ptr, + cu_seqlens_q, + cu_seqblocks_q, + seq_lens, + prefix_lens, + num_kv_heads, + gqa_group_size, + head_dim, + max_topk, + num_q_loop, + sm_scale, + stride_qn, + stride_qh, + stride_qd, + stride_k_blk, + stride_k_pos, + stride_k_h, + stride_k_d, + stride_v_blk, + stride_v_pos, + stride_v_h, + stride_v_d, + stride_th, + stride_tn, + stride_tk, + stride_on, + stride_oh, + stride_od, + stride_bt_b, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_QH: tl.constexpr, + FP8_KV_CACHE: tl.constexpr, +): + sm_scale_log2e = sm_scale * 1.4426950409 + pid_q = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_b = tl.program_id(2) + pid_h = pid_kh * gqa_group_size + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + q_block_start = tl.load(cu_seqblocks_q + pid_b) + q_block_len = tl.load(cu_seqblocks_q + pid_b + 1) - q_block_start + seq_len = tl.load(seq_lens + pid_b) + prefix_len = tl.load(prefix_lens + pid_b) + if pid_q * num_q_loop >= q_block_len: + return + + real_q_loop = min(num_q_loop, q_block_len - pid_q * num_q_loop) + bt_row = block_table_ptr + pid_b * stride_bt_b + off_n = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + d_mask = off_d < head_dim + + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + t_ptr_j = t_ptr + (q_block_start + pid_q_j) * stride_tn + pid_kh * stride_th + off_t = tl.arange(0, BLOCK_SIZE_T) + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < max_topk, other=-1) + real_topk = tl.sum((topk_idx >= 0).to(tl.int32), axis=0) + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, gqa_group_size, head_dim), + strides=(stride_qn, stride_qh, stride_qd), + offsets=(pid_q_j * BLOCK_SIZE_Q, 0, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(2, 1, 0), + ) + q = tl.load(q_ptrs, boundary_check=(0, 1, 2), padding_option="zero") + off_q = ( + tl.arange(0, BLOCK_SIZE_Q)[:, None] + + pid_q_j * BLOCK_SIZE_Q + + prefix_len + - tl.arange(0, BLOCK_SIZE_K)[None, :] + ) + m_i = tl.full((BLOCK_SIZE_QH,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_QH,), float("-inf"), dtype=tl.float32) + acc_o = tl.zeros((BLOCK_SIZE_QH, BLOCK_SIZE_D), dtype=tl.float32) + q = tl.reshape(q, BLOCK_SIZE_QH, BLOCK_SIZE_D) + for _ in range(real_topk): + blk = tl.load(t_ptr_j).to(tl.int32) + t_ptr_j = t_ptr_j + stride_tk + c = blk * BLOCK_SIZE_K + page = tl.load(bt_row + blk).to(tl.int64) + pos = c + off_n + pos_mask = pos < seq_len + k = tl.load( + k_cache_ptr + + page * stride_k_blk + + off_n[None, :] * stride_k_pos + + pid_kh * stride_k_h + + off_d[:, None] * stride_k_d, + mask=d_mask[:, None] & pos_mask[None, :], + other=0.0, + ) + if FP8_KV_CACHE: + k = k.to(q.dtype) + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None, :] >= c, 0, float("-inf")) + qk = tl.reshape(qk, BLOCK_SIZE_QH, BLOCK_SIZE_K) + qk += tl.dot(q, k) * sm_scale_log2e + qk += tl.where(pos_mask[None, :], 0, float("-inf")) + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + acc_o = acc_o * tl.exp2(m_i - m_ij)[:, None] + v = tl.load( + v_cache_ptr + + page * stride_v_blk + + off_n[:, None] * stride_v_pos + + pid_kh * stride_v_h + + off_d[None, :] * stride_v_d, + mask=pos_mask[:, None] & d_mask[None, :], + other=0.0, + ) + if FP8_KV_CACHE: + v = v.to(q.dtype) + acc_o += tl.dot(p.to(v.dtype), v) + m_i = m_ij + lse_i = m_ij + tl.log2(tl.exp2(lse_i - m_ij) + l_ij) + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + acc_o = tl.reshape(acc_o, BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_D) + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, gqa_group_size, head_dim), + strides=(stride_on, stride_oh, stride_od), + offsets=(pid_q_j * BLOCK_SIZE_Q, 0, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(2, 1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1, 2)) + + +@triton.heuristics( + { + "BLOCK_SIZE_H": lambda args: max( + 16, triton.next_power_of_2(args["gqa_group_size"]) + ), + "BLOCK_SIZE_D": lambda args: triton.next_power_of_2(args["head_dim"]), + "BLOCK_SIZE_T": lambda args: triton.next_power_of_2(args["max_topk"]), + } +) +@triton.jit +def _sgl_m3_sparse_decode_kernel( + q_ptr, + k_cache_ptr, + v_cache_ptr, + t_ptr, + o_ptr, + lse_ptr, + block_table_ptr, + seq_lens, + batch_size, + gqa_group_size, + head_dim, + max_topk, + sm_scale, + stride_qn, + stride_qh, + stride_qd, + stride_k_blk, + stride_k_pos, + stride_k_h, + stride_k_d, + stride_v_blk, + stride_v_pos, + stride_v_h, + stride_v_d, + stride_th, + stride_tn, + stride_tk, + stride_o_c, + stride_o_b, + stride_o_h, + stride_o_d, + stride_l_c, + stride_l_b, + stride_l_h, + stride_bt_b, + BLOCK_SIZE_K: tl.constexpr, + NUM_TOPK_CHUNKS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, + FP8_KV_CACHE: tl.constexpr, +): + sm_scale_log2e = sm_scale * 1.4426950409 + pid_bc, pid_kh = tl.program_id(0), tl.program_id(1) + pid_b = pid_bc % batch_size + pid_c = pid_bc // batch_size + pid_h = pid_kh * gqa_group_size + chunk_size_topk = (max_topk + NUM_TOPK_CHUNKS - 1) // NUM_TOPK_CHUNKS + chunk_start_topk = pid_c * chunk_size_topk + chunk_end_compiletime = chunk_start_topk + chunk_size_topk + seq_len = tl.load(seq_lens + pid_b) + off_t = tl.arange(0, BLOCK_SIZE_T) + idx_base = t_ptr + pid_kh * stride_th + pid_b * stride_tn + topk_idx = tl.load(idx_base + off_t * stride_tk, mask=off_t < max_topk, other=-1) + real_topk = tl.sum((topk_idx >= 0).to(tl.int32), axis=0) + chunk_end_topk = tl.minimum(chunk_end_compiletime, real_topk) + + off_n = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + d_mask = off_d < head_dim + bt_row = block_table_ptr + pid_b * stride_bt_b + m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + acc_o = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) + q_ptrs = tl.make_block_ptr( + base=q_ptr + pid_b * stride_qn + pid_h * stride_qh, + shape=(gqa_group_size, head_dim), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + + cur_idx_ptr = idx_base + chunk_start_topk * stride_tk + for _ in tl.range(chunk_start_topk, chunk_end_topk): + blk = tl.load(cur_idx_ptr).to(tl.int32) + cur_idx_ptr = cur_idx_ptr + stride_tk + c = blk * BLOCK_SIZE_K + page = tl.load(bt_row + blk).to(tl.int64) + pos = c + off_n + pos_mask = pos < seq_len + k = tl.load( + k_cache_ptr + + page * stride_k_blk + + off_n[None, :] * stride_k_pos + + pid_kh * stride_k_h + + off_d[:, None] * stride_k_d, + mask=d_mask[:, None] & pos_mask[None, :], + other=0.0, + ) + if FP8_KV_CACHE: + k = k.to(q.dtype) + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(pos_mask[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * sm_scale_log2e + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + acc_o = acc_o * tl.exp2(m_i - m_ij)[:, None] + v = tl.load( + v_cache_ptr + + page * stride_v_blk + + off_n[:, None] * stride_v_pos + + pid_kh * stride_v_h + + off_d[None, :] * stride_v_d, + mask=pos_mask[:, None] & d_mask[None, :], + other=0.0, + ) + if FP8_KV_CACHE: + v = v.to(q.dtype) + acc_o += tl.dot(p.to(v.dtype), v) + m_i = m_ij + lse_i = m_ij + tl.log2(tl.exp2(lse_i - m_ij) + l_ij) + scale = tl.where(lse_i > float("-inf"), tl.exp2(m_i - lse_i), tl.zeros_like(lse_i)) + acc_o = acc_o * scale[:, None] + o_ptrs = tl.make_block_ptr( + base=o_ptr + pid_c * stride_o_c + pid_b * stride_o_b + pid_h * stride_o_h, + shape=(gqa_group_size, head_dim), + strides=(stride_o_h, stride_o_d), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + l_ptrs = ( + lse_ptr + + pid_c * stride_l_c + + pid_b * stride_l_b + + (pid_h + tl.arange(0, BLOCK_SIZE_H)) * stride_l_h + ) + tl.store( + l_ptrs, + lse_i, + mask=tl.arange(0, BLOCK_SIZE_H) < gqa_group_size, + ) + + +@triton.jit +def _sgl_m3_sparse_decode_merge_kernel( + o_partial, + lse_partial, + output, + batch_size, + num_heads, + head_dim, + num_chunks: tl.constexpr, + stride_oc, + stride_ob, + stride_oh, + stride_od, + stride_lc, + stride_lb, + stride_lh, + stride_on, + stride_oh_out, + stride_od_out, + BLOCK_SIZE_D: tl.constexpr, +): + pid_b, pid_h = tl.program_id(0), tl.program_id(1) + off_d = tl.arange(0, BLOCK_SIZE_D) + d_mask = off_d < head_dim + m = tl.full((), float("-inf"), dtype=tl.float32) + for c in range(num_chunks): + lse_value = tl.load( + lse_partial + c * stride_lc + pid_b * stride_lb + pid_h * stride_lh + ) + m = tl.maximum(m, lse_value) + acc = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + denom = tl.full((), 0.0, dtype=tl.float32) + for c in range(num_chunks): + lse_value = tl.load( + lse_partial + c * stride_lc + pid_b * stride_lb + pid_h * stride_lh + ) + w = tl.exp2(lse_value - m) + vals = tl.load( + o_partial + + c * stride_oc + + pid_b * stride_ob + + pid_h * stride_oh + + off_d * stride_od, + mask=d_mask, + other=0.0, + ) + acc += vals.to(tl.float32) * w + denom += w + acc = acc / denom + tl.store( + output + pid_b * stride_on + pid_h * stride_oh_out + off_d * stride_od_out, + acc.to(output.dtype.element_ty), + mask=d_mask, + ) + + +@torch.no_grad() +def minimax_m3_sparse_attn_split_kv( + q: torch.Tensor, + key_cache: torch.Tensor, # [num_blocks, page_size, num_kv_heads, head_dim] + value_cache: torch.Tensor, # [num_blocks, page_size, num_kv_heads, head_dim] + topk_idx: torch.Tensor, + block_table: torch.Tensor, + cu_seqlens_q: torch.Tensor, + seq_lens: torch.Tensor, + prefix_lens: torch.Tensor, + max_query_len: int, + num_kv_heads: int, + sm_scale: float, + output: torch.Tensor, +) -> None: + total_q, num_heads, head_dim = q.shape + del total_q + batch = cu_seqlens_q.shape[0] - 1 + topk = topk_idx.shape[-1] + gqa_group_size = num_heads // num_kv_heads + grid = (max_query_len, num_kv_heads, batch) + _sgl_m3_sparse_fwd_kernel[grid]( + q, + key_cache, + value_cache, + topk_idx, + output, + block_table, + cu_seqlens_q, + cu_seqlens_q, + seq_lens, + prefix_lens, + num_kv_heads, + gqa_group_size, + head_dim, + topk, + 1, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + value_cache.stride(0), + value_cache.stride(1), + value_cache.stride(2), + value_cache.stride(3), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + block_table.stride(0), + BLOCK_SIZE_Q=1, + BLOCK_SIZE_K=SPARSE_BLOCK_SIZE, + FP8_KV_CACHE=_is_fp8_kv_cache_tensor(key_cache), + num_stages=1, + ) + + +@torch.no_grad() +def minimax_m3_sparse_attn_decode_split_kv( + q: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + topk_idx: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor, + num_kv_heads: int, + sm_scale: float, + output: torch.Tensor, +) -> None: + batch, num_heads, head_dim = q.shape + max_topk = topk_idx.shape[-1] + gqa_group_size = num_heads // num_kv_heads + target = max(1, min(max_topk, 256 // max(1, batch * num_kv_heads))) + num_topk_chunks = 1 << (target.bit_length() - 1) + o_partial = torch.empty( + num_topk_chunks, batch, num_heads, head_dim, dtype=q.dtype, device=q.device + ) + lse_partial = torch.empty( + num_topk_chunks, batch, num_heads, dtype=torch.float32, device=q.device + ) + grid = (batch * num_topk_chunks, num_kv_heads) + _sgl_m3_sparse_decode_kernel[grid]( + q, + key_cache, + value_cache, + topk_idx, + o_partial, + lse_partial, + block_table, + seq_lens, + batch, + gqa_group_size, + head_dim, + max_topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + key_cache.stride(0), + key_cache.stride(1), + key_cache.stride(2), + key_cache.stride(3), + value_cache.stride(0), + value_cache.stride(1), + value_cache.stride(2), + value_cache.stride(3), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + block_table.stride(0), + BLOCK_SIZE_K=SPARSE_BLOCK_SIZE, + NUM_TOPK_CHUNKS=num_topk_chunks, + FP8_KV_CACHE=_is_fp8_kv_cache_tensor(key_cache), + num_stages=1, + ) + merge_grid = (batch, num_heads) + _sgl_m3_sparse_decode_merge_kernel[merge_grid]( + o_partial, + lse_partial, + output, + batch, + num_heads, + head_dim, + num_topk_chunks, + o_partial.stride(0), + o_partial.stride(1), + o_partial.stride(2), + o_partial.stride(3), + lse_partial.stride(0), + lse_partial.stride(1), + lse_partial.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + BLOCK_SIZE_D=triton.next_power_of_2(head_dim), + ) + + +def build_minimax_m3_block_table(forward_batch, page_size: int) -> torch.Tensor: + """Build physical block ids from SGLang's request-token table.""" + + validate_minimax_m3_page_size(page_size) + batch_size = _get_batch_size(forward_batch) + req_pool_indices = forward_batch.req_pool_indices[:batch_size] + req_to_token = forward_batch.req_to_token_pool.req_to_token + token_table = req_to_token[req_pool_indices, :].clone() + + if not forward_batch.forward_mode.is_decode_or_idle(): + query_lens = _get_query_lens(forward_batch, batch_size) + seq_lens = _slice_i32(forward_batch.seq_lens, batch_size) + prefix_lens = _get_prefix_lens(forward_batch, batch_size, seq_lens, query_lens) + out_cache_loc = forward_batch.out_cache_loc + offset = 0 + for req_idx in range(batch_size): + prefix_len = int(prefix_lens[req_idx].item()) + query_len = int(query_lens[req_idx].item()) + if query_len > 0: + token_table[req_idx, prefix_len : prefix_len + query_len] = ( + out_cache_loc[offset : offset + query_len] + ) + offset += query_len + + seq_lens = _slice_i32(forward_batch.seq_lens, batch_size) + max_seq_len = int(seq_lens.max().item()) if batch_size else 0 + max_blocks = (max_seq_len + page_size - 1) // page_size + block_table = token_table[:, : max_blocks * page_size : page_size] // page_size + return block_table.to(dtype=torch.int32).contiguous() + + +def build_minimax_m3_forward_metadata( + forward_batch, + block_table: torch.Tensor, + page_size: int, +) -> MiniMaxM3SGLangMetadata: + """Translate SGLang ForwardBatch fields into MiniMax-M3 sparse metadata.""" + + validate_minimax_m3_page_size(page_size) + batch_size = _get_batch_size(forward_batch) + seq_lens = _slice_i32(forward_batch.seq_lens, batch_size) + max_seq_len = int(seq_lens.max().item()) if batch_size else 0 + + if forward_batch.forward_mode.is_decode_or_idle(): + return MiniMaxM3SGLangMetadata( + is_decode=True, + seq_lens=seq_lens, + block_table=block_table, + max_seq_len=max_seq_len, + ) + + query_lens = _get_query_lens(forward_batch, batch_size) + context_lens = _get_prefix_lens(forward_batch, batch_size, seq_lens, query_lens) + cu_seqlens_q = torch.empty( + batch_size + 1, dtype=torch.int32, device=seq_lens.device + ) + cu_seqlens_k = torch.empty( + batch_size + 1, dtype=torch.int32, device=seq_lens.device + ) + cu_seqlens_q[0] = 0 + cu_seqlens_k[0] = 0 + torch.cumsum(query_lens, dim=0, out=cu_seqlens_q[1:]) + torch.cumsum(seq_lens, dim=0, out=cu_seqlens_k[1:]) + + return MiniMaxM3SGLangMetadata( + is_decode=False, + seq_lens=seq_lens, + block_table=block_table, + max_seq_len=max_seq_len, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + context_lens=context_lens, + max_query_len=int(query_lens.max().item()) if batch_size else 0, + ) + + +def _ensure_side_caches( + layer, + forward_batch, + index_key: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + page_size = _get_page_size(forward_batch) + validate_minimax_m3_page_size(page_size) + + k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( + _get_layer_id(layer) + ) + num_slots = int(k_buffer.shape[0]) + num_blocks = num_slots // page_size + if num_blocks <= 0: + raise RuntimeError("MiniMax-M3 sparse attention received an empty KV pool.") + + key_cache = k_buffer[: num_blocks * page_size].view( + num_blocks, page_size, layer.num_kv_heads, layer.head_dim + ) + value_cache = v_buffer[: num_blocks * page_size].view( + num_blocks, page_size, layer.num_kv_heads, layer.head_dim + ) + index_shape = (num_blocks, page_size, layer.idx_head_dim) + + index_cache = getattr(layer, "_sglang_m3_index_cache", None) + if ( + index_cache is None + or tuple(index_cache.shape) != index_shape + or index_cache.device != index_key.device + or index_cache.dtype != index_key.dtype + ): + index_cache = torch.empty( + index_shape, dtype=index_key.dtype, device=index_key.device + ) + layer._sglang_m3_index_cache = index_cache + + return key_cache, value_cache, index_cache + + +def _insert_sparse_cache( + layer, + forward_batch, + key: torch.Tensor, + value: torch.Tensor, + index_key: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + key_cache, value_cache, index_cache = _ensure_side_caches( + layer, forward_batch, index_key + ) + page_size = _get_page_size(forward_batch) + slot_mapping = forward_batch.out_cache_loc[: key.shape[0]].to(dtype=torch.long) + valid = slot_mapping >= 0 + + slots = slot_mapping[valid] + block_ids = torch.div(slots, page_size, rounding_mode="floor") + block_offsets = slots % page_size + + key = key.view(-1, layer.num_kv_heads, layer.head_dim)[valid] + value = value.view(-1, layer.num_kv_heads, layer.head_dim)[valid] + index_key = index_key.view(-1, layer.idx_head_dim)[valid] + key_cache[block_ids, block_offsets] = key.to(key_cache.dtype) + value_cache[block_ids, block_offsets] = value.to(value_cache.dtype) + index_cache[block_ids, block_offsets] = index_key.to(index_cache.dtype) + return key_cache, value_cache, index_cache + + +def minimax_m3_sparse_attention_for_sglang( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + index_query: torch.Tensor, + index_key: torch.Tensor, + layer, + forward_batch=None, + save_kv_cache: bool = True, +) -> torch.Tensor: + """Run MiniMax-M3 lightning-indexer sparse attention in SGLang plugin mode.""" + + if forward_batch is None: + from atom.plugin.sglang.runtime import get_current_forward_batch + + forward_batch = get_current_forward_batch() + if forward_batch is None: + raise RuntimeError( + "MiniMax-M3 sparse attention requires a SGLang ForwardBatch." + ) + + page_size = _get_page_size(forward_batch) + validate_minimax_m3_page_size(page_size) + if save_kv_cache: + key_cache, value_cache, index_cache = _insert_sparse_cache( + layer, forward_batch, key, value, index_key + ) + else: + key_cache, value_cache, index_cache = _ensure_side_caches( + layer, forward_batch, index_key + ) + + block_table = build_minimax_m3_block_table(forward_batch, page_size) + metadata = build_minimax_m3_forward_metadata(forward_batch, block_table, page_size) + + q = query.view(-1, layer.num_heads, layer.head_dim) + index_q = index_query.view(-1, layer.num_idx_heads, layer.idx_head_dim) + output = torch.empty_like(q) + + from atom.model_ops.minimax_m3.index_topk import ( + minimax_m3_index_topk, + minimax_m3_index_topk_decode, + ) + + if metadata.is_decode: + batch_size = metadata.seq_lens.shape[0] + topk_idx = minimax_m3_index_topk_decode( + index_q[:batch_size], + index_cache, + metadata.block_table, + metadata.seq_lens, + metadata.max_seq_len, + layer.topk_blocks, + layer.init_blocks, + layer.local_blocks, + layer.num_kv_heads, + layer.scaling, + ) + minimax_m3_sparse_attn_decode_split_kv( + q[:batch_size], + key_cache, + value_cache, + topk_idx, + metadata.block_table, + metadata.seq_lens, + layer.num_kv_heads, + layer.scaling, + output[:batch_size], + ) + if batch_size < output.shape[0]: + output[batch_size:].zero_() + else: + assert metadata.cu_seqlens_q is not None + assert metadata.context_lens is not None + num_tokens = int(metadata.cu_seqlens_q[-1].item()) + topk_idx = minimax_m3_index_topk( + index_q[:num_tokens], + index_cache, + metadata.block_table, + metadata.cu_seqlens_q, + metadata.seq_lens, + metadata.context_lens, + metadata.max_query_len, + metadata.max_seq_len, + layer.topk_blocks, + layer.init_blocks, + layer.local_blocks, + layer.num_kv_heads, + layer.scaling, + ) + minimax_m3_sparse_attn_split_kv( + q[:num_tokens], + key_cache, + value_cache, + topk_idx, + metadata.block_table, + metadata.cu_seqlens_q, + metadata.seq_lens, + metadata.context_lens, + metadata.max_query_len, + layer.num_kv_heads, + layer.scaling, + output[:num_tokens], + ) + if num_tokens < output.shape[0]: + output[num_tokens:].zero_() + + return output.reshape_as(query) diff --git a/atom/plugin/sglang/models/__init__.py b/atom/plugin/sglang/models/__init__.py index e69de29bb2..dc1e3135b4 100644 --- a/atom/plugin/sglang/models/__init__.py +++ b/atom/plugin/sglang/models/__init__.py @@ -0,0 +1,5 @@ +from atom.plugin.sglang.models.minimax_m3_processor import ( + register_minimax_m3_text_only_processor, +) + +register_minimax_m3_text_only_processor() diff --git a/atom/plugin/sglang/models/minimax_m3.py b/atom/plugin/sglang/models/minimax_m3.py new file mode 100644 index 0000000000..878fe6dabe --- /dev/null +++ b/atom/plugin/sglang/models/minimax_m3.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from types import MethodType + +import torch + +from atom.model_ops.layernorm import fused_qk_norm +from atom.models.minimax_m3 import MiniMaxM3Attention, MiniMaxM3SparseAttention +from atom.plugin.sglang.attention_backend.minimax_m3_sparse import ( + minimax_m3_sparse_attention_for_sglang, +) + + +def _gemma_qk_norm_for_sglang( + q: torch.Tensor, + k: torch.Tensor, + q_norm, + k_norm, + num_q_heads: int, + num_kv_heads: int, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q, k = fused_qk_norm( + q.view(-1, num_q_heads, head_dim), + k.view(-1, num_kv_heads, head_dim), + q_norm.weight, + k_norm.weight, + q_norm.variance_epsilon, + add_unit_offset=True, + ) + return ( + q.view(-1, num_q_heads * head_dim), + k.view(-1, num_kv_heads * head_dim), + ) + + +def _patch_minimax_m3_dense_attention_for_sglang(module: MiniMaxM3Attention) -> None: + inner_attn = getattr(getattr(module, "attn", None), "attn", None) + if inner_attn is not None: + setattr(inner_attn, "_atom_minimax_m3_dense_mha", True) + + +def _sparse_forward_for_sglang( + self: MiniMaxM3SparseAttention, + positions: torch.Tensor, + hidden_states: torch.Tensor, +) -> torch.Tensor: + qkv = self.qkv_proj(hidden_states) + if isinstance(qkv, tuple): + qkv = qkv[0] + + q, k, v, index_q, index_k = qkv.split( + [ + self.q_size, + self.kv_size, + self.kv_size, + self.index_q_size, + self.idx_head_dim, + ], + dim=-1, + ) + q, k = _gemma_qk_norm_for_sglang( + q, + k, + self.q_norm, + self.k_norm, + self.num_heads, + self.num_kv_heads, + self.head_dim, + ) + q, k = self.rotary_emb(positions, q, k) + + index_q, index_k = _gemma_qk_norm_for_sglang( + index_q, + index_k, + self.index_q_norm, + self.index_k_norm, + self.num_idx_heads, + 1, + self.idx_head_dim, + ) + index_q, index_k = self.index_rotary_emb(positions, index_q, index_k) + + attn_output = minimax_m3_sparse_attention_for_sglang( + q, + k, + v, + index_q, + index_k, + self, + ) + return self.o_proj(attn_output) + + +def _patch_minimax_m3_sparse_attention_for_sglang( + module: MiniMaxM3SparseAttention, +) -> None: + if getattr(module, "_atom_sglang_minimax_m3_sparse_patched", False): + return + # SGLang's token_to_kv_pool APIs are keyed by layer_id. The native ATOM + # layer uses layer_num, so expose both names for the plugin helper. + module.layer_id = module.layer_num + module.forward = MethodType(_sparse_forward_for_sglang, module) + module._atom_sglang_minimax_m3_sparse_patched = True + + +def setup_minimax_m3_for_sglang(model) -> None: + """Patch MiniMax-M3 modules for SGLang plugin mode.""" + + for module in model.modules(): + if isinstance(module, MiniMaxM3Attention): + _patch_minimax_m3_dense_attention_for_sglang(module) + elif isinstance(module, MiniMaxM3SparseAttention): + _patch_minimax_m3_sparse_attention_for_sglang(module) diff --git a/atom/plugin/sglang/models/minimax_m3_processor.py b/atom/plugin/sglang/models/minimax_m3_processor.py new file mode 100644 index 0000000000..e3984592af --- /dev/null +++ b/atom/plugin/sglang/models/minimax_m3_processor.py @@ -0,0 +1,55 @@ +"""Text-only processor registration for MiniMax-M3 in SGLang plugin mode.""" + +from __future__ import annotations + +try: + from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor +except Exception: + BaseMultimodalProcessor = object + + +class MiniMaxM3SparseForCausalLM: + pass + + +class MiniMaxM3SparseForConditionalGeneration: + pass + + +class MiniMaxM3TextOnlyProcessor(BaseMultimodalProcessor): + """SGLang processor placeholder for text-only MiniMax-M3 serving.""" + + models = [MiniMaxM3SparseForCausalLM, MiniMaxM3SparseForConditionalGeneration] + + async def process_mm_data_async( + self, + image_data, + audio_data, + input_text, + request_obj, + **kwargs, + ): + del image_data, audio_data, input_text, request_obj, kwargs + return None + + +def register_minimax_m3_text_only_processor() -> None: + """Let SGLang tokenizer init accept MiniMax-M3 text-only serving. + + MiniMax-M3 checkpoints advertise a conditional-generation architecture and + include multimodal sub-configs, so SGLang asks for a multimodal processor + before model workers start. The ATOM SGLang path currently supports only + the language model, so plain text requests need a processor placeholder + that rejects actual multimodal inputs. + """ + + try: + from sglang.srt.managers.multimodal_processor import PROCESSOR_MAPPING + except Exception: + return + + PROCESSOR_MAPPING.setdefault(MiniMaxM3SparseForCausalLM, MiniMaxM3TextOnlyProcessor) + PROCESSOR_MAPPING.setdefault( + MiniMaxM3SparseForConditionalGeneration, + MiniMaxM3TextOnlyProcessor, + ) diff --git a/atom/plugin/sglang/runtime/model_arch.py b/atom/plugin/sglang/runtime/model_arch.py index cd791fc485..849a41e11d 100644 --- a/atom/plugin/sglang/runtime/model_arch.py +++ b/atom/plugin/sglang/runtime/model_arch.py @@ -48,6 +48,28 @@ def _prepare_kimi_k25_config(atom_config: Any, model_arch: str) -> None: remap_kimi_k25_quant_config_for_sglang_plugin(atom_config, model_arch) +def _prepare_minimax_m3_config(atom_config: Any, model_arch: str) -> None: + from atom.models.minimax_m3 import ( + MiniMaxM3SparseForCausalLM, + MiniMaxM3SparseForConditionalGeneration, + ) + + quant_config = getattr(atom_config, "quant_config", None) + if quant_config is None: + return + + model_cls = ( + MiniMaxM3SparseForConditionalGeneration + if model_arch == "MiniMaxM3SparseForConditionalGeneration" + else MiniMaxM3SparseForCausalLM + ) + quant_config.remap_layer_name( + atom_config.hf_config, + packed_modules_mapping=model_cls.packed_modules_mapping, + quant_exclude_name_mapping=getattr(model_cls, "quant_exclude_name_mapping", {}), + ) + + def _install_deepseek_mla_adapters(model: Any) -> None: from atom.plugin.sglang.models.deepseek_mla import setup_deepseek_for_sglang @@ -69,6 +91,12 @@ def _install_deepseek_v4_adapters(model: Any) -> None: patch_deepseek_v4_attention_for_sglang(module) +def _install_minimax_m3_adapters(model: Any) -> None: + from atom.plugin.sglang.models.minimax_m3 import setup_minimax_m3_for_sglang + + setup_minimax_m3_for_sglang(model) + + MODEL_ADAPTER_SPECS = { "DeepseekV3ForCausalLM": SGLangModelAdapterSpec( install_adapters=_install_deepseek_mla_adapters, @@ -103,6 +131,16 @@ def _install_deepseek_v4_adapters(model: Any) -> None: "DeepseekV4ForCausalLM": SGLangModelAdapterSpec( install_adapters=_install_deepseek_v4_adapters, ), + "MiniMaxM3SparseForCausalLM": SGLangModelAdapterSpec( + uses_context_only_forward=True, + prepare_config=_prepare_minimax_m3_config, + install_adapters=_install_minimax_m3_adapters, + ), + "MiniMaxM3SparseForConditionalGeneration": SGLangModelAdapterSpec( + uses_context_only_forward=True, + prepare_config=_prepare_minimax_m3_config, + install_adapters=_install_minimax_m3_adapters, + ), } # Architectures whose SGLang EntryClass is generated by base_model_wrapper. @@ -117,6 +155,8 @@ def _install_deepseek_v4_adapters(model: Any) -> None: "Qwen3MoeForCausalLM", "Qwen3NextForCausalLM", "MiniMaxM2ForCausalLM", + "MiniMaxM3SparseForCausalLM", + "MiniMaxM3SparseForConditionalGeneration", "DeepseekV4ForCausalLM", ) }