Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions atom/plugin/vllm/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch
from vllm.v1.attention.backends.mla.prefill.base import MLAPrefillBackend
from atom.model_ops.minimax_m3.sparse_attn import SPARSE_BLOCK_SIZE


class AiterMhaBackendForVllm:
Expand Down Expand Up @@ -300,6 +301,159 @@ def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)


class MiniMaxM3SparseAttentionBackend:
"""vLLM-facing sparse MHA backend surface for MiniMax-M3."""

accept_output_buffer: bool = True
supported_dtypes: list = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: list = ["bfloat16", "fp8", "fp8_e4m3"]
forward_includes_kv_cache_update: bool = True

@staticmethod
def get_name() -> str:
return "MINIMAX_M3_SPARSE"

@staticmethod
def get_supported_kernel_block_sizes():
return [SPARSE_BLOCK_SIZE]

@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
return block_size is None or block_size == SPARSE_BLOCK_SIZE

@classmethod
def get_kv_cache_block_dim(
cls,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> int:
sentinel = 1234567
shape = cls.get_kv_cache_shape(
sentinel,
block_size,
num_kv_heads,
head_size,
cache_dtype_str=cache_dtype_str,
)
return shape.index(sentinel)

@classmethod
def get_preferred_block_size(cls, default_block_size: int) -> int:
return SPARSE_BLOCK_SIZE

@staticmethod
def get_builder_cls() -> Type:
from atom.plugin.vllm.attention.metadata import (
MinimaxM3SparseAttentionMetadataBuilder,
)

return MinimaxM3SparseAttentionMetadataBuilder

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
return [128]

@classmethod
def is_sparse(cls) -> bool:
return True

@classmethod
def is_mla(cls) -> bool:
return False

@classmethod
def is_ssm(cls) -> bool:
return False

@staticmethod
def get_required_kv_cache_layout():
return None

@classmethod
def supports_alibi_sqrt(cls) -> bool:
return False

@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size != SPARSE_BLOCK_SIZE:
raise ValueError(
f"MiniMax-M3 sparse block size must be {SPARSE_BLOCK_SIZE}."
)
return (2, num_blocks, block_size, num_kv_heads, head_size)

@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
) -> tuple[int, ...]:
raise NotImplementedError

@staticmethod
def get_impl_cls():
from atom.plugin.vllm.attention.minimax_m3_attnetion import (
MiniMaxM3SparseAttentionForVllm,
)

return MiniMaxM3SparseAttentionForVllm

@classmethod
def full_cls_name(cls) -> tuple[str, str]:
return (cls.__module__, cls.__qualname__)


class SparseMHAIndexerBackend(AiterMlaBackendForVllm):
    """vLLM-facing key-only indexer backend surface for MiniMax-M3.

    Overrides the naming, block-size, head-size and cache-layout queries of
    ``AiterMlaBackendForVllm``; everything else is inherited.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of the indexer backend."""
        return "MINIMAX_M3_SPARSE_INDEXER"

    @staticmethod
    def get_supported_kernel_block_sizes():
        """Return the kernel block sizes supported: only ``SPARSE_BLOCK_SIZE``."""
        return [SPARSE_BLOCK_SIZE]

    @classmethod
    def get_preferred_block_size(cls, default_block_size: int) -> int:
        """Return ``SPARSE_BLOCK_SIZE``; ``default_block_size`` is ignored."""
        return SPARSE_BLOCK_SIZE

    @staticmethod
    def get_builder_cls() -> Type:
        """Return the metadata builder class used with the indexer."""
        # Lazy import: resolved at call time, not at module import time.
        from atom.plugin.vllm.attention.metadata import (
            MinimaxM3SparseAttentionMetadataBuilder,
        )

        return MinimaxM3SparseAttentionMetadataBuilder

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes the indexer accepts."""
        return [64, 128, 256]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the three-axis cache shape ``(num_blocks, block_size, head_size)``.

        ``num_kv_heads`` and ``cache_dtype_str`` are accepted for signature
        compatibility but do not appear in the result.
        """
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order(
        include_num_layers_dimension: bool = False,
    ) -> tuple[int, ...]:
        """Return the identity ordering of the three cache axes.

        Raises:
            NotImplementedError: if ``include_num_layers_dimension`` is truthy;
                no ordering with a layers axis is defined.
        """
        if not include_num_layers_dimension:
            return (0, 1, 2)
        raise NotImplementedError


class GDNAttentionBackend:
@staticmethod
def get_name() -> str:
Expand Down
49 changes: 49 additions & 0 deletions atom/plugin/vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,49 @@
)
from atom.plugin.vllm.attention import ops as _atom_vllm_attention_ops # noqa: F401

_MINIMAX_M3_MODEL_TYPES = {"minimax_m3", "minimax_m3_text", "minimax_m3_vl"}


def _is_minimax_m3_model(atom_config) -> bool:
    """Return True when ``atom_config`` describes a MiniMax-M3 model.

    Both ``hf_config.model_type`` and ``hf_config.text_config.model_type``
    are checked against ``_MINIMAX_M3_MODEL_TYPES``. Any missing attribute
    along the way is tolerated and simply does not match.
    """
    hf_config = getattr(atom_config, "hf_config", None)
    # Collect the top-level and the nested text-config model type; an absent
    # attribute falls back to "" which is never a registered model type.
    candidate_types = (
        getattr(hf_config, "model_type", ""),
        getattr(getattr(hf_config, "text_config", None), "model_type", ""),
    )
    return any(name in _MINIMAX_M3_MODEL_TYPES for name in candidate_types)


def _minimax_m3_attention_cls_for_vllm(atom_config, kwargs):
    """Pick the MiniMax-M3 attention class for a layer, or None.

    Args:
        atom_config: configuration object inspected by
            ``_is_minimax_m3_model``.
        kwargs: keyword arguments intended for the attention layer. Only
            read, never modified.

    Returns:
        ``MiniMaxM3SparseAttentionForVllm`` when ``impl_cls`` is the ATOM
        sparse paged-attention implementation;
        ``MiniMaxM3DenseAttentionForVllm`` when no ``impl_cls`` is given and
        ``rotary_emb``, ``q_norm`` and ``k_norm`` are all present;
        ``None`` in every other case, including non-MiniMax-M3 models.
    """
    if not _is_minimax_m3_model(atom_config):
        return None

    requested_impl = kwargs.get("impl_cls")

    if requested_impl is None:
        # Dense path: only taken when every fused-op input is supplied.
        required_keys = ("rotary_emb", "q_norm", "k_norm")
        if any(kwargs.get(key) is None for key in required_keys):
            return None

        from atom.plugin.vllm.attention.minimax_m3_attnetion import (
            MiniMaxM3DenseAttentionForVllm,
        )

        return MiniMaxM3DenseAttentionForVllm

    # An explicit implementation was requested; only the ATOM sparse paged
    # implementation maps onto a MiniMax-M3 specific class.
    from atom.model_ops.attention_mha import (
        SparseMHAPagedAttentionImpl as AtomSparseMHAPagedAttentionImpl,
    )

    if requested_impl is not AtomSparseMHAPagedAttentionImpl:
        return None

    from atom.plugin.vllm.attention.minimax_m3_attnetion import (
        MiniMaxM3SparseAttentionForVllm,
    )

    return MiniMaxM3SparseAttentionForVllm


class AttentionForVllm:
"""Factory for ATOM-owned attention layers running under vLLM."""
Expand All @@ -30,4 +73,10 @@ def __new__(
*args, mla_modules=mla_modules, **kwargs
)
return AttentionForVllmMLA(*args, mla_modules=mla_modules, **kwargs)
minimax_m3_attention_cls = _minimax_m3_attention_cls_for_vllm(
atom_config, kwargs
)
if minimax_m3_attention_cls is not None:
return minimax_m3_attention_cls(*args, **kwargs)
kwargs.pop("impl_cls", None)
return AttentionForVllmMHA(*args, **kwargs)
Loading
Loading