Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ class MLAModules:
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
indexer: Optional[torch.nn.Module]
# Model-level sparse flag. A v3.2 / GLM-5.2 model runs sparse MLA on ALL its
# layers. GLM-5.2 IndexShare "shared" layers carry no indexer module yet must
# still run sparse attention (reusing the prior "full" layer's top-k), so
# sparsity must be derived from the model, not from whether this layer owns
# an indexer. Defaults keep non-sparse models unchanged.
is_sparse: bool = False
topk_tokens: Optional[int] = None


def dynamic_per_batched_tensor_quant(
Expand Down Expand Up @@ -231,10 +238,20 @@ def __init__(
self.one_scale = torch.tensor(1.0, dtype=torch.float32)
self._k_scale = self.one_scale
self._q_scale = self.one_scale
self.is_sparse_mla = mla_modules.indexer is not None
# Derive sparsity from the model-level flag, not from whether THIS layer
# owns an indexer: GLM-5.2 IndexShare "shared" layers have indexer=None
# but must still run sparse MLA, reusing the prior "full" layer's top-k.
# (`mla_modules.is_sparse` defaults False, so non-sparse models and the
# `indexer is not None` fallback keep their previous behavior.)
self.is_sparse_mla = mla_modules.is_sparse or (mla_modules.indexer is not None)
self.topk_tokens = (
mla_modules.indexer.topk_tokens if mla_modules.indexer is not None else None
mla_modules.indexer.topk_tokens
if mla_modules.indexer is not None
else mla_modules.topk_tokens
)
# Shared layers have no indexer buffer at construction; the metadata
# builder rebinds it to the shared `_sparse_kv_indices_gpu` at runtime,
# so the layer reads the prior full layer's selected indices.
self.sparse_kv_indices_buffer = (
mla_modules.indexer.sparse_kv_indices_buffer
if mla_modules.indexer is not None
Expand Down Expand Up @@ -874,6 +891,24 @@ def _forward_prefill_mla(
sm_scale=self.scale,
q_scale=self._q_scale,
kv_scale=self._k_scale,
work_meta_data=getattr(
attn_metadata, "sparse_prefill_work_meta_data", None
),
work_indptr=getattr(
attn_metadata, "sparse_prefill_work_indptr", None
),
work_info_set=getattr(
attn_metadata, "sparse_prefill_work_info_set", None
),
reduce_indptr=getattr(
attn_metadata, "sparse_prefill_reduce_indptr", None
),
reduce_final_map=getattr(
attn_metadata, "sparse_prefill_reduce_final_map", None
),
reduce_partial_map=getattr(
attn_metadata, "sparse_prefill_reduce_partial_map", None
),
)
else:
mla_prefill_fwd(
Expand Down
97 changes: 89 additions & 8 deletions atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,40 @@ def __init__(self, model_runner):
dtype=torch.int32,
device=self.device,
)
(
(spp_wmd_size, spp_wmd_type),
(spp_wi_size, spp_wi_type),
(spp_wis_size, spp_wis_type),
(spp_ri_size, spp_ri_type),
(spp_rfm_size, spp_rfm_type),
(spp_rpm_size, spp_rpm_type),
) = get_mla_metadata_info_v1(
self.max_num_batched_tokens,
1, # sparse prefill treats each query token as q_len=1
self.padded_num_attention_heads,
self.dtype_q,
self.dtype_kv,
is_sparse=True,
fast_mode=True,
)
Comment on lines +208 to +216
mla_metadata["sparse_prefill_work_meta_data"] = torch.empty(
spp_wmd_size, dtype=spp_wmd_type, device=self.device
)
mla_metadata["sparse_prefill_work_indptr"] = torch.empty(
spp_wi_size, dtype=spp_wi_type, device=self.device
)
mla_metadata["sparse_prefill_work_info_set"] = torch.empty(
spp_wis_size, dtype=spp_wis_type, device=self.device
)
mla_metadata["sparse_prefill_reduce_indptr"] = torch.empty(
spp_ri_size, dtype=spp_ri_type, device=self.device
)
mla_metadata["sparse_prefill_reduce_final_map"] = torch.empty(
spp_rfm_size, dtype=spp_rfm_type, device=self.device
)
mla_metadata["sparse_prefill_reduce_partial_map"] = torch.empty(
spp_rpm_size, dtype=spp_rpm_type, device=self.device
)

if self.is_sparse and max_seqlen_qo > 1:
# Allocate a second set of persistent work buffers for sparse MTP
Expand Down Expand Up @@ -744,21 +778,33 @@ def prepare_prefill(self, batch: ScheduledBatch):
self.prepare_block_tables(batch)
attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs)
counts = var["cu_seqlens_q"].np[1 : bs + 1] - var["cu_seqlens_q"].np[:bs]
local_offsets = np.concatenate(
[np.arange(s, dtype=np.int32) for s in counts]
)
if attn_metadata.has_cached:
# Full context (cached + new): use cu_seqlens_k for indexer
# Full context (cached + new): each query token can see the cached
# prefix plus previous query tokens in this chunk, not future chunk
# tokens.
seq_starts = var["cu_seqlens_k"].np[:bs]
seq_lens = var["cu_seqlens_k"].np[1 : bs + 1] - seq_starts
cached_lens = seq_lens - counts
repeated_seq_starts = np.repeat(seq_starts, counts)
repeated_cached_lens = np.repeat(cached_lens, counts)
var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat(
var["cu_seqlens_k"].np[:bs], counts
seq_starts, counts
)
var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = np.repeat(
var["cu_seqlens_k"].np[1 : bs + 1], counts
var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = (
repeated_seq_starts + repeated_cached_lens + local_offsets + 1
)
sparse_counts = repeated_cached_lens + local_offsets + 1
else:
var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = (
np.arange(sum_scheduled_tokens, dtype=np.int32) + 1
)
var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat(
var["cu_seqlens_q"].np[:bs], counts
)
sparse_counts = local_offsets + 1
attn_metadata.cu_seqlen_ks = var["cu_seqlen_ks"].copy_to_gpu(
sum_scheduled_tokens
)
Expand All @@ -780,15 +826,50 @@ def prepare_prefill(self, batch: ScheduledBatch):
)
var["sparse_kv_indptr"].np[0] = 0
var["sparse_kv_indptr"].np[1 : sum_scheduled_tokens + 1] = np.cumsum(
np.minimum(
np.concatenate([np.arange(1, s + 1) for s in counts]),
self.index_topk,
),
np.minimum(sparse_counts, self.index_topk),
dtype=np.int32,
)
attn_metadata.sparse_kv_indptr = var["sparse_kv_indptr"].copy_to_gpu(
sum_scheduled_tokens + 1
)
get_mla_metadata_v1(
attn_metadata.sparse_cu_seqlens_q,
attn_metadata.sparse_kv_indptr,
attn_metadata.kv_last_page_lens,
self.padded_num_attention_heads,
1, # nhead_kv
True,
var["sparse_prefill_work_meta_data"],
var["sparse_prefill_work_info_set"],
var["sparse_prefill_work_indptr"],
var["sparse_prefill_reduce_indptr"],
var["sparse_prefill_reduce_final_map"],
var["sparse_prefill_reduce_partial_map"],
page_size=self.block_size,
dtype_q=self.dtype_q,
dtype_kv=self.dtype_kv,
kv_granularity=max(self.block_size, 16),
max_seqlen_qo=1,
uni_seqlen_qo=1,
fast_mode=1,
max_split_per_batch=16,
)
attn_metadata.sparse_prefill_work_meta_data = var[
"sparse_prefill_work_meta_data"
]
attn_metadata.sparse_prefill_work_info_set = var[
"sparse_prefill_work_info_set"
]
attn_metadata.sparse_prefill_work_indptr = var["sparse_prefill_work_indptr"]
attn_metadata.sparse_prefill_reduce_indptr = var[
"sparse_prefill_reduce_indptr"
]
attn_metadata.sparse_prefill_reduce_final_map = var[
"sparse_prefill_reduce_final_map"
]
attn_metadata.sparse_prefill_reduce_partial_map = var[
"sparse_prefill_reduce_partial_map"
]

if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached:
# Populate kv_last_page_lens for full sequence (needed for MLA prefill with
Expand Down
25 changes: 21 additions & 4 deletions atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ def _supports_fused_indexer_kernel_config(config: PretrainedConfig) -> bool:
)


def _is_neox_rope_style(
    config: PretrainedConfig, interleave_attr: str, default_interleave: bool
) -> bool:
    """Return the ``is_neox_style`` flag implied by a config's interleave setting.

    The attribute named ``interleave_attr`` is looked up on ``config``. When
    the attribute is absent, or present but set to ``None``,
    ``default_interleave`` is used in its place. The result is the logical
    negation of that value's truthiness.

    Args:
        config: Object carrying the (optional) interleave attribute.
        interleave_attr: Name of the attribute to read from ``config``.
        default_interleave: Value used when the attribute is missing or None.

    Returns:
        ``True`` when the effective interleave value is falsy, else ``False``.
    """
    configured = getattr(config, interleave_attr, None)
    # A missing attribute and an explicit None both fall back to the default.
    effective = default_interleave if configured is None else configured
    return not effective


def _can_fuse_indexer_wk_weights_proj(
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
Expand Down Expand Up @@ -1595,14 +1604,16 @@ def forward(
weights = self.weights_proj(hidden_states)

if not self.use_qk_rope_cache_fusion:
q_pe, _ = torch.split(
q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
k = self.k_norm(k)
k_pe, _ = torch.split(
k_pe, k_nope = torch.split(
k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
q = torch.cat([q_pe, q_nope], dim=-1)
k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

q = q.view(-1, self.head_dim)
q_fp8, q_scale = self.quant_func(q, quant_dtype=dtypes.fp8)
Expand Down Expand Up @@ -1822,7 +1833,7 @@ def __init__(
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
is_neox_style=_is_neox_rope_style(config, "rope_interleave", True),
)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
Expand Down Expand Up @@ -1888,6 +1899,12 @@ def __init__(
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
indexer=self.indexer,
# v3.2 / GLM-5.2 runs sparse MLA on every layer. For GLM-5.2 IndexShare
# "shared" layers self.indexer is None, but they must still run sparse
# attention and reuse the prior full layer's top-k, so flag sparsity at
# the model level rather than per-layer.
is_sparse=self.is_v32,
topk_tokens=(config.index_topk if self.is_v32 else None),
)

self.mla_attn = Attention(
Expand Down
Loading