diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index ee98185f4..ea846ebed 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -169,6 +169,13 @@ class MLAModules: kv_b_proj: torch.nn.Module o_proj: torch.nn.Module indexer: Optional[torch.nn.Module] + # Model-level sparse flag. A v3.2 / GLM-5.2 model runs sparse MLA on ALL its + # layers. GLM-5.2 IndexShare "shared" layers carry no indexer module yet must + # still run sparse attention (reusing the prior "full" layer's top-k), so + # sparsity must be derived from the model, not from whether this layer owns + # an indexer. Defaults keep non-sparse models unchanged. + is_sparse: bool = False + topk_tokens: Optional[int] = None def dynamic_per_batched_tensor_quant( @@ -231,10 +238,20 @@ def __init__( self.one_scale = torch.tensor(1.0, dtype=torch.float32) self._k_scale = self.one_scale self._q_scale = self.one_scale - self.is_sparse_mla = mla_modules.indexer is not None + # Derive sparsity from the model-level flag, not from whether THIS layer + # owns an indexer: GLM-5.2 IndexShare "shared" layers have indexer=None + # but must still run sparse MLA, reusing the prior "full" layer's top-k. + # (`mla_modules.is_sparse` defaults False, so non-sparse models and the + # `indexer is not None` fallback keep their previous behavior.) + self.is_sparse_mla = mla_modules.is_sparse or (mla_modules.indexer is not None) self.topk_tokens = ( - mla_modules.indexer.topk_tokens if mla_modules.indexer is not None else None + mla_modules.indexer.topk_tokens + if mla_modules.indexer is not None + else mla_modules.topk_tokens ) + # Shared layers have no indexer buffer at construction; the metadata + # builder rebinds it to the shared `_sparse_kv_indices_gpu` at runtime, + # so the layer reads the prior full layer's selected indices. self.sparse_kv_indices_buffer = ( mla_modules.indexer.sparse_kv_indices_buffer if mla_modules.indexer is not None @@ -874,6 +891,24 @@ def _forward_prefill_mla( sm_scale=self.scale, q_scale=self._q_scale, kv_scale=self._k_scale, + work_meta_data=getattr( + attn_metadata, "sparse_prefill_work_meta_data", None + ), + work_indptr=getattr( + attn_metadata, "sparse_prefill_work_indptr", None + ), + work_info_set=getattr( + attn_metadata, "sparse_prefill_work_info_set", None + ), + reduce_indptr=getattr( + attn_metadata, "sparse_prefill_reduce_indptr", None + ), + reduce_final_map=getattr( + attn_metadata, "sparse_prefill_reduce_final_map", None + ), + reduce_partial_map=getattr( + attn_metadata, "sparse_prefill_reduce_partial_map", None + ), ) else: mla_prefill_fwd( diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 389502683..b924ea049 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -198,6 +198,40 @@ def __init__(self, model_runner): dtype=torch.int32, device=self.device, ) + ( + (spp_wmd_size, spp_wmd_type), + (spp_wi_size, spp_wi_type), + (spp_wis_size, spp_wis_type), + (spp_ri_size, spp_ri_type), + (spp_rfm_size, spp_rfm_type), + (spp_rpm_size, spp_rpm_type), + ) = get_mla_metadata_info_v1( + self.max_num_batched_tokens, + 1, # sparse prefill treats each query token as q_len=1 + self.padded_num_attention_heads, + self.dtype_q, + self.dtype_kv, + is_sparse=True, + fast_mode=True, + ) + mla_metadata["sparse_prefill_work_meta_data"] = torch.empty( + spp_wmd_size, dtype=spp_wmd_type, device=self.device + ) + mla_metadata["sparse_prefill_work_indptr"] = torch.empty( + spp_wi_size, dtype=spp_wi_type, device=self.device + ) + mla_metadata["sparse_prefill_work_info_set"] = torch.empty( + spp_wis_size, dtype=spp_wis_type, device=self.device + ) + mla_metadata["sparse_prefill_reduce_indptr"] = torch.empty( + spp_ri_size, dtype=spp_ri_type, device=self.device + ) + mla_metadata["sparse_prefill_reduce_final_map"] = torch.empty( + spp_rfm_size, dtype=spp_rfm_type, device=self.device + ) + mla_metadata["sparse_prefill_reduce_partial_map"] = torch.empty( + spp_rpm_size, dtype=spp_rpm_type, device=self.device + ) if self.is_sparse and max_seqlen_qo > 1: # Allocate a second set of persistent work buffers for sparse MTP @@ -744,14 +778,25 @@ def prepare_prefill(self, batch: ScheduledBatch): self.prepare_block_tables(batch) attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs) counts = var["cu_seqlens_q"].np[1 : bs + 1] - var["cu_seqlens_q"].np[:bs] + local_offsets = np.concatenate( + [np.arange(s, dtype=np.int32) for s in counts] + ) if attn_metadata.has_cached: - # Full context (cached + new): use cu_seqlens_k for indexer + # Full context (cached + new): each query token can see the cached + # prefix plus previous query tokens in this chunk, not future chunk + # tokens. + seq_starts = var["cu_seqlens_k"].np[:bs] + seq_lens = var["cu_seqlens_k"].np[1 : bs + 1] - seq_starts + cached_lens = seq_lens - counts + repeated_seq_starts = np.repeat(seq_starts, counts) + repeated_cached_lens = np.repeat(cached_lens, counts) var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( - var["cu_seqlens_k"].np[:bs], counts + seq_starts, counts ) - var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = np.repeat( - var["cu_seqlens_k"].np[1 : bs + 1], counts + var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = ( + repeated_seq_starts + repeated_cached_lens + local_offsets + 1 ) + sparse_counts = repeated_cached_lens + local_offsets + 1 else: var["cu_seqlen_ke"].np[:sum_scheduled_tokens] = ( np.arange(sum_scheduled_tokens, dtype=np.int32) + 1 @@ -759,6 +804,7 @@ def prepare_prefill(self, batch: ScheduledBatch): var["cu_seqlen_ks"].np[:sum_scheduled_tokens] = np.repeat( var["cu_seqlens_q"].np[:bs], counts ) + sparse_counts = local_offsets + 1 attn_metadata.cu_seqlen_ks = var["cu_seqlen_ks"].copy_to_gpu( sum_scheduled_tokens ) @@ -780,15 +826,50 @@ def prepare_prefill(self, batch: ScheduledBatch): ) var["sparse_kv_indptr"].np[0] = 0 var["sparse_kv_indptr"].np[1 : sum_scheduled_tokens + 1] = np.cumsum( - np.minimum( - np.concatenate([np.arange(1, s + 1) for s in counts]), - self.index_topk, - ), + np.minimum(sparse_counts, self.index_topk), dtype=np.int32, ) attn_metadata.sparse_kv_indptr = var["sparse_kv_indptr"].copy_to_gpu( sum_scheduled_tokens + 1 ) + get_mla_metadata_v1( + attn_metadata.sparse_cu_seqlens_q, + attn_metadata.sparse_kv_indptr, + attn_metadata.kv_last_page_lens, + self.padded_num_attention_heads, + 1, # nhead_kv + True, + var["sparse_prefill_work_meta_data"], + var["sparse_prefill_work_info_set"], + var["sparse_prefill_work_indptr"], + var["sparse_prefill_reduce_indptr"], + var["sparse_prefill_reduce_final_map"], + var["sparse_prefill_reduce_partial_map"], + page_size=self.block_size, + dtype_q=self.dtype_q, + dtype_kv=self.dtype_kv, + kv_granularity=max(self.block_size, 16), + max_seqlen_qo=1, + uni_seqlen_qo=1, + fast_mode=1, + max_split_per_batch=16, + ) + attn_metadata.sparse_prefill_work_meta_data = var[ + "sparse_prefill_work_meta_data" + ] + attn_metadata.sparse_prefill_work_info_set = var[ + "sparse_prefill_work_info_set" + ] + attn_metadata.sparse_prefill_work_indptr = var["sparse_prefill_work_indptr"] + attn_metadata.sparse_prefill_reduce_indptr = var[ + "sparse_prefill_reduce_indptr" + ] + attn_metadata.sparse_prefill_reduce_final_map = var[ + "sparse_prefill_reduce_final_map" + ] + attn_metadata.sparse_prefill_reduce_partial_map = var[ + "sparse_prefill_reduce_partial_map" + ] if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached: # Populate kv_last_page_lens for full sequence (needed for MLA prefill with diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index ff47a5fcc..b802d84d7 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -174,6 +174,15 @@ def _supports_fused_indexer_kernel_config(config: PretrainedConfig) -> bool: ) +def _is_neox_rope_style( + config: PretrainedConfig, interleave_attr: str, default_interleave: bool +) -> bool: + interleave = getattr(config, interleave_attr, default_interleave) + if interleave is None: + interleave = default_interleave + return not bool(interleave) + + def _can_fuse_indexer_wk_weights_proj( config: PretrainedConfig, quant_config: Optional[QuantizationConfig], @@ -1595,14 +1604,16 @@ def forward( weights = self.weights_proj(hidden_states) if not self.use_qk_rope_cache_fusion: - q_pe, _ = torch.split( + q_pe, q_nope = torch.split( q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) k = self.k_norm(k) - k_pe, _ = torch.split( + k_pe, k_nope = torch.split( k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 ) - q_pe, k_pe = rotary_emb(positions, q_pe, k_pe) + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) q = q.view(-1, self.head_dim) q_fp8, q_scale = self.quant_func(q, quant_dtype=dtypes.fp8) @@ -1822,7 +1833,7 @@ def __init__( max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, - is_neox_style=False, + is_neox_style=_is_neox_rope_style(config, "rope_interleave", True), ) if rope_scaling: mscale_all_dim = rope_scaling.get("mscale_all_dim", False) @@ -1888,6 +1899,12 @@ def __init__( kv_b_proj=self.kv_b_proj, o_proj=self.o_proj, indexer=self.indexer, + # v3.2 / GLM-5.2 runs sparse MLA on every layer. For GLM-5.2 IndexShare + # "shared" layers self.indexer is None, but they must still run sparse + # attention and reuse the prior full layer's top-k, so flag sparsity at + # the model level rather than per-layer. + is_sparse=self.is_v32, + topk_tokens=(config.index_topk if self.is_v32 else None), ) self.mla_attn = Attention(