diff --git a/atom/plugin/vllm/attention/backend.py b/atom/plugin/vllm/attention/backend.py index f2c4c734a..abfa254f3 100644 --- a/atom/plugin/vllm/attention/backend.py +++ b/atom/plugin/vllm/attention/backend.py @@ -2,6 +2,7 @@ import torch from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend +from atom.model_ops.minimax_m3.sparse_attn import SPARSE_BLOCK_SIZE class AiterMhaBackendForVllm: @@ -300,6 +301,159 @@ def full_cls_name(cls) -> tuple[str, str]: return (cls.__module__, cls.__qualname__) +class MiniMaxM3SparseAttentionBackend: + """vLLM-facing sparse MHA backend surface for MiniMax-M3.""" + + accept_output_buffer: bool = True + supported_dtypes: list = [torch.float16, torch.bfloat16] + supported_kv_cache_dtypes: list = ["bfloat16", "fp8", "fp8_e4m3"] + forward_includes_kv_cache_update: bool = True + + @staticmethod + def get_name() -> str: + return "MINIMAX_M3_SPARSE" + + @staticmethod + def get_supported_kernel_block_sizes(): + return [SPARSE_BLOCK_SIZE] + + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + return block_size is None or block_size == SPARSE_BLOCK_SIZE + + @classmethod + def get_kv_cache_block_dim( + cls, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> int: + sentinel = 1234567 + shape = cls.get_kv_cache_shape( + sentinel, + block_size, + num_kv_heads, + head_size, + cache_dtype_str=cache_dtype_str, + ) + return shape.index(sentinel) + + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + return SPARSE_BLOCK_SIZE + + @staticmethod + def get_builder_cls() -> Type: + from atom.plugin.vllm.attention.metadata import ( + MinimaxM3SparseAttentionMetadataBuilder, + ) + + return MinimaxM3SparseAttentionMetadataBuilder + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [128] + + @classmethod + def is_sparse(cls) -> bool: + return True + + @classmethod + def is_mla(cls) -> bool: + return False + + 
@classmethod + def is_ssm(cls) -> bool: + return False + + @staticmethod + def get_required_kv_cache_layout(): + return None + + @classmethod + def supports_alibi_sqrt(cls) -> bool: + return False + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 sparse block size must be {SPARSE_BLOCK_SIZE}." + ) + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_kv_cache_stride_order( + include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + raise NotImplementedError + + @staticmethod + def get_impl_cls(): + from atom.plugin.vllm.attention.minimax_m3_attnetion import ( + MiniMaxM3SparseAttentionForVllm, + ) + + return MiniMaxM3SparseAttentionForVllm + + @classmethod + def full_cls_name(cls) -> tuple[str, str]: + return (cls.__module__, cls.__qualname__) + + +class SparseMHAIndexerBackend(AiterMlaBackendForVllm): + """vLLM-facing key-only indexer backend surface for MiniMax-M3.""" + + @staticmethod + def get_name() -> str: + return "MINIMAX_M3_SPARSE_INDEXER" + + @staticmethod + def get_supported_kernel_block_sizes(): + return [SPARSE_BLOCK_SIZE] + + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + return SPARSE_BLOCK_SIZE + + @staticmethod + def get_builder_cls() -> Type: + from atom.plugin.vllm.attention.metadata import ( + MinimaxM3SparseAttentionMetadataBuilder, + ) + + return MinimaxM3SparseAttentionMetadataBuilder + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [64, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order( + 
include_num_layers_dimension: bool = False, + ) -> tuple[int, ...]: + if include_num_layers_dimension: + raise NotImplementedError + return (0, 1, 2) + + class GDNAttentionBackend: @staticmethod def get_name() -> str: diff --git a/atom/plugin/vllm/attention/layer.py b/atom/plugin/vllm/attention/layer.py index 71c54b40d..1185e61c3 100644 --- a/atom/plugin/vllm/attention/layer.py +++ b/atom/plugin/vllm/attention/layer.py @@ -9,6 +9,49 @@ ) from atom.plugin.vllm.attention import ops as _atom_vllm_attention_ops # noqa: F401 +_MINIMAX_M3_MODEL_TYPES = {"minimax_m3", "minimax_m3_text", "minimax_m3_vl"} + + +def _is_minimax_m3_model(atom_config) -> bool: + hf_config = getattr(atom_config, "hf_config", None) + model_type = getattr(hf_config, "model_type", "") + text_config = getattr(hf_config, "text_config", None) + text_model_type = getattr(text_config, "model_type", "") + return ( + model_type in _MINIMAX_M3_MODEL_TYPES + or text_model_type in _MINIMAX_M3_MODEL_TYPES + ) + + +def _minimax_m3_attention_cls_for_vllm(atom_config, kwargs): + if not _is_minimax_m3_model(atom_config): + return None + impl_cls = kwargs.get("impl_cls") + if impl_cls is not None: + from atom.model_ops.attention_mha import ( + SparseMHAPagedAttentionImpl as AtomSparseMHAPagedAttentionImpl, + ) + + if impl_cls is AtomSparseMHAPagedAttentionImpl: + from atom.plugin.vllm.attention.minimax_m3_attnetion import ( + MiniMaxM3SparseAttentionForVllm, + ) + + return MiniMaxM3SparseAttentionForVllm + return None + + if ( + kwargs.get("rotary_emb") is not None + and kwargs.get("q_norm") is not None + and kwargs.get("k_norm") is not None + ): + from atom.plugin.vllm.attention.minimax_m3_attnetion import ( + MiniMaxM3DenseAttentionForVllm, + ) + + return MiniMaxM3DenseAttentionForVllm + return None + class AttentionForVllm: """Factory for ATOM-owned attention layers running under vLLM.""" @@ -30,4 +73,10 @@ def __new__( *args, mla_modules=mla_modules, **kwargs ) return AttentionForVllmMLA(*args, 
mla_modules=mla_modules, **kwargs) + minimax_m3_attention_cls = _minimax_m3_attention_cls_for_vllm( + atom_config, kwargs + ) + if minimax_m3_attention_cls is not None: + return minimax_m3_attention_cls(*args, **kwargs) + kwargs.pop("impl_cls", None) return AttentionForVllmMHA(*args, **kwargs) diff --git a/atom/plugin/vllm/attention/metadata.py b/atom/plugin/vllm/attention/metadata.py index 30d8a018a..fece1b262 100644 --- a/atom/plugin/vllm/attention/metadata.py +++ b/atom/plugin/vllm/attention/metadata.py @@ -399,6 +399,190 @@ class AiterMlaSparseMetadataForVllm: reduce_partial_map: torch.Tensor | None = None +@dataclass +class MinimaxM3SparsePrefillMetadata: + qo_indptr: torch.Tensor + cu_seqlens_q: torch.Tensor + seq_lens: torch.Tensor + context_lens: torch.Tensor + block_table: torch.Tensor + max_query_len: int + max_seq_len: int + + +@dataclass +class MinimaxM3SparseDecodeMetadata: + seq_lens: torch.Tensor + block_table: torch.Tensor + max_query_len: int = 1 + + +@dataclass +class MinimaxM3SparseMetadata: + seq_lens: torch.Tensor + max_seq_len: int + slot_mapping: torch.Tensor + num_actual_tokens: int + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + block_table: torch.Tensor + max_query_len: int + prefill: MinimaxM3SparsePrefillMetadata | None = None + decode: MinimaxM3SparseDecodeMetadata | None = None + + +class MinimaxM3SparseAttentionMetadataBuilder(AttentionMetadataBuilder): + # MiniMax-M3 sparse attention owns dynamic index/topk/cache state that must + # be refreshed outside full cudagraph capture. 
+ _cudagraph_support = AttentionCGSupport.NEVER + reorder_batch_threshold = 1 + + def __init__( + self, + kv_cache_spec=None, + layer_names=None, + config=None, + device=None, + model_runner=None, + ): + del model_runner + super().__init__(kv_cache_spec, layer_names, config, device) + logger.info("init MinimaxM3SparseAttentionMetadataBuilder") + from vllm.config import VllmConfig + + assert isinstance(config, VllmConfig) + self.vllm_config = config + self.model_config = config.model_config + self.cache_config = config.cache_config + self.scheduler_config = config.scheduler_config + self.block_size = kv_cache_spec.block_size + self._init_reorder_batch_threshold(1, supports_spec_as_decode=True) + + def build( + self, + common_prefix_len: int = 0, + common_attn_metadata=None, + fast_build: bool = False, + ): + del fast_build + if common_prefix_len > 0: + raise ValueError("ATOM does not support cascade attention yet") + assert common_attn_metadata is not None + + from atom.model_ops.minimax_m3.sparse_attn import SPARSE_BLOCK_SIZE + from vllm.v1.attention.backends.utils import ( + split_decodes_prefills_and_extends, + ) + + ( + num_decodes, + num_extends, + num_prefills, + num_decode_tokens, + _num_extend_tokens, + _num_prefill_tokens, + ) = split_decodes_prefills_and_extends( + common_attn_metadata=common_attn_metadata, + decode_threshold=getattr(self, "reorder_batch_threshold", 1) or 1, + ) + num_tokens = common_attn_metadata.num_actual_tokens + num_prefills_total = num_extends + num_prefills + num_prefill_tokens = num_tokens - num_decode_tokens + seq_lens = common_attn_metadata.seq_lens + block_table = common_attn_metadata.block_table_tensor + + prefill_metadata: MinimaxM3SparsePrefillMetadata | None = None + if num_prefills_total > 0: + # MiniMax-M3 sparse attention uses the prefill kernel for any mixed + # decode+prefill batch, because it builds per-token causal sparse + # block tables. Only pure decode batches use the decode kernel. 
+ # The vLLM scheduler orders request rows as decode, extend, prefill. + # The prefill metadata below must therefore start after decode rows + # and must shift query_start_loc back to this phase's local token + # slice; otherwise sparse prefill reads decode requests as prefixes. + prefill_start = num_decodes + prefill_stop = prefill_start + num_prefills_total + prefill_token_start = num_decode_tokens + prefill_seq_lens = seq_lens[prefill_start:prefill_stop] + prefill_query_start = common_attn_metadata.query_start_loc[ + prefill_start : prefill_stop + 1 + ].to(torch.int32) + prefill_query_start = prefill_query_start - prefill_token_start + context_lens = common_attn_metadata.compute_num_computed_tokens()[ + prefill_start:prefill_stop + ] + seq_lens_cpu = getattr( + common_attn_metadata, "seq_lens_cpu_upper_bound", None + ) + if seq_lens_cpu is None: + seq_lens_cpu = getattr(common_attn_metadata, "_seq_lens_cpu", None) + if seq_lens_cpu is None: + seq_lens_cpu = seq_lens.cpu() + prefill_max_seq_len = int( + seq_lens_cpu[prefill_start:prefill_stop].max().item() + ) + query_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + prefill_max_query_len = max( + 1, int(query_lens_cpu[prefill_start:prefill_stop].max().item()) + ) + if self.block_size != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 sparse block size must be {SPARSE_BLOCK_SIZE}." 
+ ) + qo_indptr = torch.arange( + num_prefill_tokens + 1, dtype=torch.int32, device=seq_lens.device + ) + prefill_metadata = MinimaxM3SparsePrefillMetadata( + qo_indptr=qo_indptr, + cu_seqlens_q=prefill_query_start, + seq_lens=prefill_seq_lens, + context_lens=context_lens, + block_table=block_table[prefill_start:prefill_stop], + max_query_len=prefill_max_query_len, + max_seq_len=prefill_max_seq_len, + ) + + decode_metadata: MinimaxM3SparseDecodeMetadata | None = None + if num_decodes > 0: + query_lens_cpu = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) + decode_max_query_len = max( + 1, int(query_lens_cpu[:num_decodes].max().item()) + ) + decode_metadata = MinimaxM3SparseDecodeMetadata( + seq_lens=seq_lens[:num_decodes], + block_table=block_table[:num_decodes], + max_query_len=decode_max_query_len, + ) + + return MinimaxM3SparseMetadata( + seq_lens=seq_lens, + max_seq_len=common_attn_metadata.max_seq_len, + slot_mapping=common_attn_metadata.slot_mapping, + num_actual_tokens=num_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills_total, + num_prefill_tokens=num_prefill_tokens, + block_table=block_table, + max_query_len=common_attn_metadata.max_query_len, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + def build_for_cudagraph_capture(self, common_attn_metadata=None): + return self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) + + # vLLM metadata builders class AiterMhaMetadataBuilderForVllm(AttentionMetadataBuilder): """vLLM-only MHA metadata builder.""" diff --git a/atom/plugin/vllm/attention/minimax_m3_attnetion.py b/atom/plugin/vllm/attention/minimax_m3_attnetion.py new file mode 100644 index 000000000..d689a7eb0 --- /dev/null +++ b/atom/plugin/vllm/attention/minimax_m3_attnetion.py @@ -0,0 +1,887 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +"""MiniMax-M3 attention adapters for ATOM vLLM plugin mode. + +This module keeps the ATOM ``MiniMaxM3Attention`` layer intact: qkv/o +projections, per-head QK norms, RoPE objects, and checkpoint weight names stay +owned by ``atom.models.minimax_m3``. Sparse layers use the MiniMax-M3-specific +runtime below; dense layers use vLLM's Triton custom-op backend after applying +MiniMax-M3's q/k norm + RoPE transform. +""" + +from typing import Optional + +import aiter +import torch +from aiter import dtypes +from torch import nn + +from atom.config import get_current_atom_config +from atom.model_ops.minimax_m3.sparse_attn import ( + ASM_PAGE_SIZE, + PAGES_PER_SPARSE_BLOCK, + SPARSE_BLOCK_SIZE, +) +from atom.plugin.vllm.attention.backend import ( + MiniMaxM3SparseAttentionBackend, + SparseMHAIndexerBackend, +) +from atom.plugin.vllm.attention.layer_common import ( + _register_vllm_static_forward_context, +) +from atom.utils import mark_spliting_op +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase + + +def minimax_m3_sparse_attention_fake( + qkv: torch.Tensor, + positions: torch.Tensor, + layer_name: str, + output_hidden_size: int, +) -> torch.Tensor: + del positions, layer_name + return qkv.new_empty((qkv.shape[0], output_hidden_size)) + + +@mark_spliting_op( + is_custom=True, + gen_fake=minimax_m3_sparse_attention_fake, + mutates_args=[], +) +def minimax_m3_sparse_attention( + qkv: torch.Tensor, + positions: torch.Tensor, + layer_name: str, + output_hidden_size: int, +) -> torch.Tensor: + from vllm.forward_context import get_forward_context + + layer = get_forward_context().no_compile_layers[layer_name] + output = qkv.new_empty((qkv.shape[0], output_hidden_size)) + return layer._forward_with_output(qkv, positions, output) + + +class MiniMaxM3SparseIndexerCache(nn.Module, AttentionLayerBase): + """Key-only index cache owned by MiniMax-M3 sparse attention.""" + + def __init__( + self, + 
*, + layer_name: str, + head_dim: int, + kv_cache_dtype: str, + ) -> None: + from vllm.v1.attention.backend import AttentionType + from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype + + super().__init__() + atom_config = get_current_atom_config() + vllm_config = atom_config.plugin_config.vllm_config + self.layer_name = layer_name + self.prefix = layer_name + self.attn_type = AttentionType.DECODER + self.attn_backend = SparseMHAIndexerBackend + self.kv_cache_dtype = kv_cache_dtype + self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( + kv_cache_dtype, vllm_config.model_config + ) + self.num_kv_heads = 1 + self.head_size = head_dim + self.head_size_v = head_dim + self.sliding_window = -1 + self.kv_cache = torch.tensor([]) + _register_vllm_static_forward_context(self) + + @property + def impl(self): + return self + + def get_attn_backend(self): + return self.attn_backend + + def get_kv_cache_spec(self, vllm_config): + from vllm.v1.kv_cache_interface import MLAAttentionSpec + + block_size = vllm_config.cache_config.block_size + if block_size != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 sparse index block size must be {SPARSE_BLOCK_SIZE}." + ) + + return MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + ) + + +AttentionLayerBase.register(MiniMaxM3SparseIndexerCache) + + +class MiniMaxM3SparseAttentionForVllm(nn.Module, AttentionLayerBase): + """MiniMax-M3 sparse attention backend for ATOM models under vLLM. + + This intentionally depends only on the generic ATOM vLLM attention stack + under ``atom.plugin.vllm.attention``. Do not depend on model-local MiniMax-M3 + backend modules here: that model directory is not part of the long-term ATOM + backend surface. 
+ """ + + is_indexed_sparse_attention = True + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]] = None, + kv_cache_dtype: str = "bf16", + layer_num: int = 0, + use_mla: bool = False, + rotary_emb: Optional[nn.Module] = None, + prefix: Optional[str] = None, + q_norm: Optional[nn.Module] = None, + k_norm: Optional[nn.Module] = None, + cache_config=None, + quant_config=None, + index_q_norm: Optional[nn.Module] = None, + index_k_norm: Optional[nn.Module] = None, + index_rotary_emb: Optional[nn.Module] = None, + index_q_size: int = 0, + index_head_dim: int = 0, + topk: int = 0, + init_blocks: int = 0, + local_blocks: int = 0, + skip_index_topk: bool = False, + sparse_layer_ordinal: int = -1, + impl_cls=None, + **kwargs, + ) -> None: + super().__init__() + del ( + alibi_slopes, + use_mla, + quant_config, + index_rotary_emb, + sparse_layer_ordinal, + impl_cls, + kwargs, + ) + from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype + + atom_config = get_current_atom_config() + if atom_config is None or atom_config.plugin_config is None: + raise RuntimeError("atom_config with vLLM plugin_config is required") + + # ATOM's MiniMax-M3 sparse layer historically passes CacheConfig through + # the kv_cache_dtype argument name used by atom.model_ops.base_attention. + if cache_config is None and hasattr(kv_cache_dtype, "cache_dtype"): + cache_config = kv_cache_dtype + cache_dtype = ( + cache_config.cache_dtype if cache_config is not None else kv_cache_dtype + ) + if cache_config is not None: + block_size = getattr(cache_config, "block_size", SPARSE_BLOCK_SIZE) + if block_size != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 sparse block size must be {SPARSE_BLOCK_SIZE}." 
+ ) + self.layer_name = prefix if prefix is not None else f"M3_SPARSE_{layer_num}" + self.attn_backend = MiniMaxM3SparseAttentionBackend + self.kv_cache_dtype = cache_dtype + self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( + cache_dtype, atom_config.plugin_config.vllm_config.model_config + ) + self.kv_cache = torch.tensor([]) + self.k_scale = self.v_scale = None + self.kv_scale = torch.tensor(1.0, dtype=torch.float32) + + self.num_heads = num_heads + self.head_dim = head_dim + self.head_size = head_dim + self.head_size_v = head_dim + self.scale = scale + self.num_kv_heads = num_kv_heads + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.layer_num = layer_num + self.rotary_emb = rotary_emb + self.q_norm = q_norm + self.k_norm = k_norm + self.index_q_norm = index_q_norm + self.index_k_norm = index_k_norm + self.index_q_size = index_q_size + self.index_head_dim = index_head_dim + self.num_idx_heads = num_kv_heads + self.topk = topk + self.init_blocks = init_blocks + self.local_blocks = local_blocks + self.skip_index_topk = skip_index_topk + self._cached_topk: tuple | None = None + self._cached_topk_key: tuple | None = None + + if self.head_dim != 128: + raise ValueError("MiniMax-M3 sparse attention requires head_dim == 128.") + if index_q_norm is None or index_k_norm is None: + raise ValueError("MiniMax-M3 sparse attention requires index norms.") + if index_head_dim <= 0 or index_q_size <= 0 or topk <= 0: + raise ValueError( + "MiniMax-M3 sparse attention requires index dimensions/topk." 
+ ) + + self.index_cache_layer = MiniMaxM3SparseIndexerCache( + layer_name=f"{self.layer_name}.index_cache", + head_dim=index_head_dim, + kv_cache_dtype="auto", + ) + _register_vllm_static_forward_context(self) + + @property + def impl(self): + return self + + def get_attn_backend(self): + return self.attn_backend + + def get_kv_cache_spec(self, vllm_config): + from vllm.v1.kv_cache_interface import FullAttentionSpec + + block_size = vllm_config.cache_config.block_size + if block_size != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 sparse block size must be {SPARSE_BLOCK_SIZE}." + ) + + return FullAttentionSpec( + block_size=block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_dim, + head_size_v=self.head_dim, + dtype=self.kv_cache_torch_dtype, + ) + + @staticmethod + def _main_metadata(): + metadata = get_forward_context().attn_metadata + return metadata + + def _metadata_for_layer(self): + metadata = self._main_metadata() + if not isinstance(metadata, dict): + return metadata, metadata + return metadata.get(self.layer_name), metadata.get( + self.index_cache_layer.layer_name + ) + + def _validate_bound_sparse_state(self, main_metadata, index_metadata) -> None: + if main_metadata is None: + raise ValueError("MiniMax-M3 sparse attention metadata is required.") + if index_metadata is None: + raise ValueError("MiniMax-M3 sparse index metadata is required.") + if self.kv_cache.numel() == 0 or self.index_cache_layer.kv_cache.numel() == 0: + # vLLM profiling calls can run before cache binding; the caller + # handles this by returning zero outputs. + return + + if self.kv_cache.ndim != 5: + raise ValueError( + "MiniMax-M3 sparse KV cache must have shape " + "[2, num_blocks, block_size, num_kv_heads, head_dim]." 
+ ) + if self.kv_cache.shape[0] != 2: + raise ValueError("MiniMax-M3 sparse KV cache must store K and V.") + if self.kv_cache.shape[2] != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 sparse KV block size must be {SPARSE_BLOCK_SIZE}." + ) + if self.kv_cache.shape[3] != self.num_kv_heads: + raise ValueError("MiniMax-M3 sparse KV cache head count mismatch.") + if self.kv_cache.shape[4] != self.head_dim: + raise ValueError("MiniMax-M3 sparse KV cache head dim mismatch.") + + if self.index_cache_layer.kv_cache.ndim != 3: + raise ValueError( + "MiniMax-M3 sparse index cache must have shape " + "[num_blocks, block_size, index_head_dim]." + ) + if self.index_cache_layer.kv_cache.shape[1] != SPARSE_BLOCK_SIZE: + raise ValueError( + f"MiniMax-M3 index cache block size must be {SPARSE_BLOCK_SIZE}." + ) + if self.index_cache_layer.kv_cache.shape[2] != self.index_head_dim: + raise ValueError("MiniMax-M3 index cache head dim mismatch.") + + def _ensure_fp8_scales(self, kv_cache: torch.Tensor): + if self.kv_cache_dtype != "fp8": + return None, None + _kv, num_blocks, block_size, num_kv_heads, _head_dim = kv_cache.shape + expected_shape = (num_blocks, num_kv_heads, block_size) + if ( + self.k_scale is None + or self.v_scale is None + or self.k_scale.shape != expected_shape + or self.k_scale.device != kv_cache.device + ): + self.kv_scale = torch.zeros( + 2, + num_blocks, + num_kv_heads, + block_size, + dtype=dtypes.fp32, + device=kv_cache.device, + ) + self.k_scale = self.kv_scale[0] + self.v_scale = self.kv_scale[1] + return self.k_scale, self.v_scale + + def _page16_shuffle_cache_for_sparse_kernel( + self, + ) -> tuple[torch.Tensor, torch.Tensor, object, object]: + _kv, num_blocks, block_size, num_kv_heads, head_dim = self.kv_cache.shape + if block_size != SPARSE_BLOCK_SIZE: + raise ValueError("MiniMax-M3 sparse cache must use page size 128.") + k_cache, v_cache = self.kv_cache.unbind(0) + if self.kv_cache_dtype == "fp8": + target_dtype = 
dtypes.d_dtypes[self.kv_cache_dtype] + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + x = 16 // k_cache.element_size() + num_phys16 = num_blocks * PAGES_PER_SPARSE_BLOCK + k_cache = k_cache.view( + num_phys16, + num_kv_heads, + head_dim // x, + ASM_PAGE_SIZE, + x, + ) + v_cache = v_cache.view( + num_phys16, + num_kv_heads, + ASM_PAGE_SIZE // x, + head_dim, + x, + ) + if self.kv_cache_dtype == "fp8": + k_scale = self.k_scale.view(num_phys16, num_kv_heads, ASM_PAGE_SIZE) + v_scale = self.v_scale.view(num_phys16, num_kv_heads, ASM_PAGE_SIZE) + else: + k_scale = v_scale = None + return k_cache, v_cache, k_scale, v_scale + + def _insert_qkv_and_index( + self, + qkv: torch.Tensor, + positions: torch.Tensor, + main_metadata, + index_metadata, + ) -> tuple[torch.Tensor, torch.Tensor]: + from atom.models.minimax_m3 import _minimax_m3_cos_sin_cache + + if self.kv_cache.numel() == 0 or self.index_cache_layer.kv_cache.numel() == 0: + num_tokens = qkv.shape[0] + return ( + qkv.new_zeros((num_tokens, self.q_size)), + qkv.new_zeros((num_tokens, self.index_q_size)), + ) + + qkv = qkv.contiguous() + num_tokens = qkv.shape[0] + q_out = qkv.new_empty((num_tokens, self.q_size)) + index_q = qkv.new_empty((num_tokens, self.index_q_size)) + self._ensure_fp8_scales(self.kv_cache) + k_cache, v_cache, k_scale, v_scale = ( + self._page16_shuffle_cache_for_sparse_kernel() + ) + kv_cache_dtype = self.kv_cache_dtype if self.kv_cache_dtype == "fp8" else "auto" + + aiter.fused_qknorm_idxrqknorm( + qkv, + self.q_norm.weight, + self.k_norm.weight, + _minimax_m3_cos_sin_cache(self.rotary_emb, qkv), + positions, + self.num_heads, + self.num_kv_heads, + self.rotary_emb.rotary_dim, + self.q_norm.variance_epsilon, + self.index_q_norm.weight, + self.index_k_norm.weight, + self.num_idx_heads, + slot_mapping=main_metadata.slot_mapping, + kv_cache_k=k_cache, + kv_cache_v=v_cache, + index_cache=self.index_cache_layer.kv_cache, + block_size=k_cache.shape[3], + q_out=q_out, + 
index_q_out=index_q, + index_slot_mapping=index_metadata.slot_mapping, + kv_cache_dtype=kv_cache_dtype, + k_scale=k_scale if self.kv_cache_dtype == "fp8" else None, + v_scale=v_scale if self.kv_cache_dtype == "fp8" else None, + asm_layout=True, + ) + return q_out, index_q + + def _topk_cache_key(self, phase: str, index_q: torch.Tensor, metadata) -> tuple: + return ( + phase, + tuple(index_q.shape), + index_q.dtype, + index_q.device, + tuple(metadata.block_table.shape), + tuple(metadata.seq_lens.shape), + self.topk, + self.init_blocks, + self.local_blocks, + ) + + def _load_cached_topk(self, key: tuple): + if self.skip_index_topk and self._cached_topk_key == key: + return self._cached_topk + return None + + def _store_cached_topk(self, key: tuple, topk_idx) -> None: + if self.skip_index_topk: + self._cached_topk_key = key + self._cached_topk = topk_idx + + def _decode_topk( + self, + index_q: torch.Tensor, + main_metadata, + index_metadata, + ): + from atom.model_ops.minimax_m3.index_topk import minimax_m3_index_topk_decode + + num_decode_tokens = main_metadata.num_decode_tokens + decode_md = main_metadata.decode + index_decode_md = ( + index_metadata.decode if index_metadata is not None else decode_md + ) + max_query_len = max(1, int(getattr(decode_md, "max_query_len", 1) or 1)) + key = self._topk_cache_key( + "decode", index_q[:num_decode_tokens], index_decode_md + ) + cached = self._load_cached_topk(key) + if cached is not None: + return cached + topk_idx = minimax_m3_index_topk_decode( + index_q[:num_decode_tokens].view( + -1, self.num_idx_heads, self.index_head_dim + ), + self.index_cache_layer.kv_cache, + index_decode_md.block_table, + index_decode_md.seq_lens, + getattr(index_metadata, "max_seq_len", main_metadata.max_seq_len), + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + self.scale, + emit_sparse_block_table=True, + max_query_len=max_query_len, + ) + self._store_cached_topk(key, topk_idx) + return topk_idx + + def 
_prefill_topk( + self, + index_q: torch.Tensor, + start: int, + stop: int, + main_metadata, + index_metadata, + ): + from atom.model_ops.minimax_m3.index_topk import minimax_m3_index_topk + + prefill_md = main_metadata.prefill + index_prefill_md = ( + index_metadata.prefill if index_metadata is not None else prefill_md + ) + return minimax_m3_index_topk( + index_q[start:stop].view(-1, self.num_idx_heads, self.index_head_dim), + self.index_cache_layer.kv_cache, + index_prefill_md.block_table, + index_prefill_md.cu_seqlens_q, + index_prefill_md.seq_lens, + index_prefill_md.context_lens, + index_prefill_md.max_query_len, + index_prefill_md.max_seq_len, + self.topk, + self.init_blocks, + self.local_blocks, + self.num_kv_heads, + self.scale, + emit_sparse_block_table=True, + ) + + def _run_decode_sparse_attention( + self, + q: torch.Tensor, + index_q: torch.Tensor, + out: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale, + v_scale, + main_metadata, + index_metadata, + ) -> None: + from atom.model_ops.minimax_m3.sparse_attn import ( + minimax_m3_sparse_attn_decode_asm, + ) + + num_decode_tokens = getattr(main_metadata, "num_decode_tokens", 0) + if num_decode_tokens <= 0 or main_metadata.decode is None: + return + topk_idx, sparse_bt, sparse_ctx = self._decode_topk( + index_q, main_metadata, index_metadata + ) + decode_md = main_metadata.decode + minimax_m3_sparse_attn_decode_asm( + q[:num_decode_tokens], + k_cache, + v_cache, + topk_idx, + decode_md.block_table, + decode_md.seq_lens, + self.num_kv_heads, + self.scale, + out[:num_decode_tokens], + k_scale=k_scale, + v_scale=v_scale, + sparse_bt=sparse_bt, + sparse_ctx=sparse_ctx, + ) + + def _run_prefill_sparse_attention( + self, + q: torch.Tensor, + index_q: torch.Tensor, + out: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale, + v_scale, + main_metadata, + index_metadata, + ) -> None: + from atom.model_ops.minimax_m3.sparse_attn import ( + 
minimax_m3_sparse_attn_prefill_asm, + ) + + num_decode_tokens = getattr(main_metadata, "num_decode_tokens", 0) + num_prefill_tokens = getattr(main_metadata, "num_prefill_tokens", 0) + if num_prefill_tokens <= 0 or main_metadata.prefill is None: + return + start = num_decode_tokens + stop = start + num_prefill_tokens + topk_idx, sparse_bt, sparse_ctx = self._prefill_topk( + index_q, start, stop, main_metadata, index_metadata + ) + prefill_md = main_metadata.prefill + minimax_m3_sparse_attn_prefill_asm( + q[start:stop], + k_cache, + v_cache, + topk_idx, + prefill_md.block_table, + None, + None, + prefill_md.qo_indptr, + self.num_kv_heads, + self.scale, + out[start:stop], + k_scale=k_scale, + v_scale=v_scale, + cu_seqlens_q=prefill_md.cu_seqlens_q, + prefix_lens=prefill_md.context_lens, + sparse_bt=sparse_bt, + sparse_ctx=sparse_ctx, + ) + + def _run_sparse_attention( + self, + query: torch.Tensor, + index_q: torch.Tensor, + output: torch.Tensor, + main_metadata, + index_metadata, + ) -> torch.Tensor: + q = query.view(-1, self.num_heads, self.head_dim) + out = output.view(-1, self.num_heads, self.head_dim) + k_cache, v_cache, k_scale, v_scale = ( + self._page16_shuffle_cache_for_sparse_kernel() + ) + self._run_decode_sparse_attention( + q, + index_q, + out, + k_cache, + v_cache, + k_scale, + v_scale, + main_metadata, + index_metadata, + ) + self._run_prefill_sparse_attention( + q, + index_q, + out, + k_cache, + v_cache, + k_scale, + v_scale, + main_metadata, + index_metadata, + ) + return output + + def _forward_with_output( + self, + qkv: torch.Tensor, + positions: Optional[torch.Tensor] = None, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + main_metadata, index_metadata = self._metadata_for_layer() + num_tokens = qkv.shape[0] + if output is None: + output = qkv.new_empty((num_tokens, self.q_size)) + if main_metadata is None or positions is None: + return output.fill_(0) + actual_tokens = min( + getattr(main_metadata, "num_actual_tokens", 
class MiniMaxM3DenseAttentionForVllm(nn.Module, AttentionLayerBase):
    """MiniMax-M3 dense attention using vLLM's Triton backend contract.

    The layer consumes the packed ``qkv`` projection, applies the fused
    q/k-norm + rotary kernel, and then dispatches to the
    ``torch.ops.vllm.unified_kv_cache_update`` and
    ``torch.ops.vllm.unified_attention_with_output`` custom ops, identifying
    itself by ``layer_name``.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]] = None,
        kv_cache_dtype: str = "bf16",
        layer_num: int = 0,
        use_mla: bool = False,
        rotary_emb: Optional[nn.Module] = None,
        prefix: Optional[str] = None,
        q_norm: Optional[nn.Module] = None,
        k_norm: Optional[nn.Module] = None,
        cache_config=None,
        quant_config=None,
        **kwargs,
    ) -> None:
        """Build the layer and register it with vLLM.

        Args:
            num_heads: Number of query heads.
            head_dim: Size of each head (used for q, k and v).
            scale: Softmax scale forwarded to ``TritonAttentionImpl``.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Optional ALiBi slopes forwarded to the impl.
            kv_cache_dtype: Fallback cache dtype string, used only when the
                plugin config carries no vLLM cache config.
            layer_num: Used to derive ``layer_name`` when ``prefix`` is None.
            use_mla: Accepted and ignored.
            rotary_emb: Rotary embedding module read by ``_qk_norm_rope``.
            prefix: Layer name; defaults to ``M3_DENSE_{layer_num}``.
            q_norm: Query norm module read by ``_qk_norm_rope``.
            k_norm: Key norm module read by ``_qk_norm_rope``.
            cache_config: Accepted and ignored; the cache config is read from
                the ATOM plugin config instead.
            quant_config: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Raises:
            RuntimeError: If no ATOM config with a vLLM ``plugin_config`` is
                active.
        """
        super().__init__()
        # These arguments exist for call-site compatibility only.
        del use_mla, cache_config, quant_config, kwargs
        from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
        from vllm.v1.attention.backend import AttentionType
        from vllm.v1.attention.backends.triton_attn import (
            TritonAttentionBackend,
            TritonAttentionImpl,
        )

        atom_config = get_current_atom_config()
        if atom_config is None or atom_config.plugin_config is None:
            raise RuntimeError("atom_config with vLLM plugin_config is required")
        vllm_config = atom_config.plugin_config.vllm_config
        cache_config = atom_config.plugin_config.vllm_cache_config
        cache_dtype = (
            cache_config.cache_dtype if cache_config is not None else kv_cache_dtype
        )
        self.layer_name = prefix if prefix is not None else f"M3_DENSE_{layer_num}"
        self.attn_type = AttentionType.DECODER
        self.attn_backend = TritonAttentionBackend
        self.kv_cache_dtype = cache_dtype
        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            cache_dtype, vllm_config.model_config
        )
        self.calculate_kv_scales = (
            cache_config.calculate_kv_scales if cache_config is not None else False
        )
        self.quant_config = None
        # Empty placeholder until a real cache tensor is bound to the layer.
        self.kv_cache = torch.tensor([])
        self.has_sink = False
        self.dtype = torch.get_default_dtype()

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.head_size = head_dim
        self.head_size_v = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.q_size = num_heads * head_dim
        self.kv_size = num_kv_heads * head_dim
        self.rotary_emb = rotary_emb
        self.q_norm = q_norm
        self.k_norm = k_norm
        self.impl = TritonAttentionImpl(
            num_heads,
            head_dim,
            scale,
            num_kv_heads,
            alibi_slopes,
            None,  # sliding_window
            cache_dtype,
            None,  # logits_soft_cap
            self.attn_type,
            None,  # kv_sharing_target_layer_name
        )
        from vllm.model_executor.layers.attention.attention import _init_kv_cache_quant

        _init_kv_cache_quant(self, None, self.layer_name)
        _register_vllm_static_forward_context(self)

    @property
    def layer_name(self):
        """Name under which this layer is passed to the vLLM custom ops."""
        return self._layer_name

    @layer_name.setter
    def layer_name(self, value):
        self._layer_name = value

    def get_attn_backend(self):
        """Return the attention backend class (``TritonAttentionBackend``)."""
        return self.attn_backend

    def get_kv_cache_spec(self, vllm_config):
        """Describe this layer's KV cache as a ``FullAttentionSpec``.

        The block size comes from ``vllm_config.cache_config``; head counts,
        head sizes and dtype come from the values stored at construction.
        """
        from vllm.v1.kv_cache_interface import FullAttentionSpec, get_kv_quant_mode

        return FullAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            dtype=self.kv_cache_torch_dtype,
            kv_quant_mode=get_kv_quant_mode(self.kv_cache_dtype),
        )

    def process_weights_after_loading(
        self, act_dtype: torch.dtype = torch.bfloat16
    ) -> None:
        """Finalize the impl after weight loading and set default quant scales."""
        from vllm.model_executor.layers.attention.attention import (
            set_default_quant_scales,
        )

        self.impl.process_weights_after_loading(act_dtype)
        set_default_quant_scales(self, register_buffer=False)

    def _qk_norm_rope(
        self,
        qkv: torch.Tensor,
        positions: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the fused q/k norm + rotary kernel and split packed ``qkv``.

        Returns:
            Contiguous ``(query, key, value)`` tensors split along the last
            dimension with sizes ``q_size``, ``kv_size``, ``kv_size``.
        """
        from atom.models.minimax_m3 import _minimax_m3_cos_sin_cache

        qkv = qkv.contiguous()
        # The kernel's return value is unused, so it is expected to update
        # ``qkv`` in place -- TODO confirm against the aiter kernel docs.
        aiter.fused_qknorm_idxrqknorm(
            qkv,
            self.q_norm.weight,
            self.k_norm.weight,
            _minimax_m3_cos_sin_cache(self.rotary_emb, qkv),
            positions,
            self.num_heads,
            self.num_kv_heads,
            self.rotary_emb.rotary_dim,
            self.q_norm.variance_epsilon,
            num_index_heads=0,
        )
        return tuple(
            tensor.contiguous()
            for tensor in qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        q_scale: Optional[torch.Tensor] = None,
        qkv: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run dense attention on a packed ``qkv`` tensor.

        ``query``, ``key``, ``value``, ``q_scale`` and extra keyword
        arguments are accepted for interface compatibility and discarded;
        the real inputs are ``qkv`` and ``positions``.

        Returns:
            Attention output of shape ``(num_tokens, num_heads * head_size_v)``.

        Raises:
            ValueError: If ``qkv`` or ``positions`` is None.
        """
        del query, key, value, q_scale, kwargs
        if qkv is None:
            raise ValueError("MiniMax-M3 dense vLLM attention requires packed qkv.")
        if positions is None:
            raise ValueError("positions is required for MiniMax-M3 dense attention.")

        # Imported once and encoded once; the same encoded name is used by
        # every custom op below.
        from vllm.model_executor.layers.attention.attention import _encode_layer_name

        query, key, value = self._qk_norm_rope(qkv, positions)
        encoded = _encode_layer_name(self.layer_name)

        # key/value always come from the split above, so no None guard is
        # needed here.
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, encoded)
            # Scales are computed at most once per layer instance.
            self.calculate_kv_scales = False

        output_shape = torch.Size((query.shape[0], self.num_heads * self.head_size_v))
        output = torch.empty(output_shape, dtype=query.dtype, device=query.device)
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size_v)
        output = output.view(-1, self.num_heads, self.head_size_v)

        # The cache-update result is passed to the attention op as a
        # dependency token.
        kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(key, value, encoded)
        torch.ops.vllm.unified_attention_with_output(
            query,
            key,
            value,
            output,
            encoded,
            kv_cache_dummy_dep=kv_cache_dummy_dep,
        )
        return output.view(-1, self.num_heads * self.head_size_v)
class MiniMaxM3Config(PretrainedConfig):
    """Minimal local config shim for MiniMax-M3 VL checkpoints.

    Registered under the ``minimax_m3_vl`` model type. It stores the text and
    vision sub-configs and mirrors the text config's ``hidden_size`` at the
    top level.
    """

    model_type = "minimax_m3_vl"

    def __init__(
        self,
        text_config: dict | PretrainedConfig | None = None,
        vision_config: dict | None = None,
        **kwargs,
    ):
        """Store the sub-configs, then defer to ``PretrainedConfig``.

        Args:
            text_config: Text-model config; a plain dict is wrapped in a
                ``PretrainedConfig``, anything else is stored as given.
            vision_config: Vision config, stored as given.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        resolved_text = (
            PretrainedConfig(**text_config)
            if isinstance(text_config, dict)
            else text_config
        )

        self.text_config = resolved_text
        self.vision_config = vision_config
        # None when there is no text config or it has no hidden_size.
        self.hidden_size = getattr(resolved_text, "hidden_size", None)

        super().__init__(**kwargs)
+ """ + from vllm.model_executor.layers.quantization import register_quantization_config + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + ) + + @register_quantization_config("mxfp8") + class AtomMxfp8Config(QuantizationConfig): + @classmethod + def from_config(cls, config): + return cls() + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_name(cls): + return "mxfp8" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> QuantizeMethodBase | None: + return None + + def register_platform() -> Optional[str]: if disable_vllm_plugin: @@ -56,6 +130,9 @@ def register_platform() -> Optional[str]: # branch runs and requires vllm.model_executor.models.qwen3_5, which may be # absent. Backbone is set in register_model() for real vLLM runs. + _register_hf_configs() + _register_mxfp8_quantization_config() + # return the ATOM platform to vllm return "atom.plugin.vllm.platform.ATOMPlatform" diff --git a/recipes/atom_vllm/MiniMax-M3.md b/recipes/atom_vllm/MiniMax-M3.md new file mode 100644 index 000000000..b62450f06 --- /dev/null +++ b/recipes/atom_vllm/MiniMax-M3.md @@ -0,0 +1,107 @@ +# MiniMax-M3 with ATOM vLLM Plugin Backend + +This recipe shows how to run MiniMax-M3 sparse checkpoints with the ATOM vLLM +plugin backend. For background on the plugin backend, see +[ATOM vLLM Plugin Backend](../../docs/vllm_plugin_backend_guide.md). + +MiniMax-M3 uses the ATOM-owned model implementation and vLLM attention adapters +for both dense and sparse attention layers. 
+ +## Step 1: Pull the OOT Docker Image + +```bash +docker pull rocm/atom-dev:vllm-latest +``` + +## Step 2: Launch vLLM Server + +The ATOM vLLM plugin backend keeps the standard vLLM CLI, server APIs, and +general usage flow compatible with upstream vLLM. For general server options and +API usage, refer to the [official vLLM documentation](https://docs.vllm.ai/en/latest/). + +The example below serves the MXFP8 checkpoint on four GPUs. Use your local +checkpoint path or the corresponding model ID for `MODEL`. + +```bash +MODEL=/path/to/MiniMax-M3-MXFP8 +TP=4 +PORT=8001 + +vllm serve "${MODEL}" \ + --dtype auto \ + --load-format auto \ + --host localhost \ + --port "${PORT}" \ + --tensor-parallel-size "${TP}" \ + --gpu-memory-utilization 0.85 \ + --max-model-len 32768 \ + --max-num-batched-tokens 32768 \ + --block-size 128 \ + --no-async-scheduling \ + --kv-cache-dtype auto \ + --no-enable-prefix-caching \ + --language-model-only \ + --no-trust-remote-code \ + --additional-config '{"online_quant_config": {"global_quant_config": "ptpc_fp8", "exclude_layer": ["lm_head", "model.embed_tokens", "vision_tower", "multi_modal_projector", "patch_merge_mlp", "*block_sparse_moe"]}}' \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' +``` + +For the MXFP4 checkpoint, change `MODEL` and omit the `--additional-config` online +quantization setting; `TP` and `PORT` are reused from the previous example: + +```bash +MODEL=/path/to/MiniMax-M3-MXFP4 + +vllm serve "${MODEL}" \ + --dtype auto \ + --load-format auto \ + --host localhost \ + --port "${PORT}" \ + --tensor-parallel-size "${TP}" \ + --gpu-memory-utilization 0.85 \ + --max-model-len 32768 \ + --max-num-batched-tokens 32768 \ + --block-size 128 \ + --no-async-scheduling \ + --kv-cache-dtype auto \ + --no-enable-prefix-caching \ + --language-model-only \ + --no-trust-remote-code \ + --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' +``` + +To validate FP8 KV cache, set `--kv-cache-dtype fp8` in either command. 
+ +Notes: +- Keep `--block-size 128`; MiniMax-M3 sparse attention assumes 128-token sparse + blocks. +- `--no-trust-remote-code` is expected because ATOM registers the MiniMax-M3 + model classes used by the vLLM plugin path. +- `--language-model-only` serves the language model path for MiniMax-M3 VL + checkpoints. + +## Step 3: Accuracy Validation + +The accuracy can be verified on GSM8K with the chat-completions API: + +```bash +BS=65 + +lm_eval \ + --model local-chat-completions \ + --model_args "model=${MODEL},base_url=http://localhost:${PORT}/v1/chat/completions,num_concurrent=32,max_gen_toks=2048" \ + --tasks gsm8k \ + --num_fewshot 5 \ + --batch_size "${BS}" \ + --apply_chat_template \ + --fewshot_as_multiturn +``` + +Reference average results from five local GSM8K runs are shown below. + +| Config | `flexible-extract` avg | `strict-match` avg | +| --- | ---: | ---: | +| MXFP8 | 0.9503 | 0.9510 | +| MXFP4 | 0.9399 | 0.9407 | +| MXFP8-kv_fp8 | 0.9480 | 0.9487 | +| MXFP4-kv_fp8 | 0.9439 | 0.9445 |