From 4a236c6e2f16cf291958dd4dff08ff2819809b71 Mon Sep 17 00:00:00 2001 From: "Ying.zhou2" Date: Mon, 29 Jun 2026 07:46:57 -0500 Subject: [PATCH] enable deepseek v4 kv cache fp8 --- atom/model_ops/attentions/deepseek_v4_attn.py | 105 ++++++++- atom/model_ops/v4_kernels/__init__.py | 9 +- atom/model_ops/v4_kernels/fused_compress.py | 22 +- atom/model_ops/v4_kernels/paged_decode.py | 127 +++++++++- atom/model_ops/v4_kernels/state_writes.py | 179 ++++++++++++++ atom/model_ops/v4_kernels/v4_quant.py | 188 +++++++++++++++ atom/models/deepseek_v4.py | 218 ++++++++++++++---- 7 files changed, 785 insertions(+), 63 deletions(-) create mode 100644 atom/model_ops/v4_kernels/v4_quant.py diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index 6521825c06..d95b56a07a 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -42,6 +42,7 @@ import numpy as np import torch from aiter import dtypes +from aiter.jit.utils.chip_info import get_gfx from atom.model_engine.scheduler import ScheduledBatch from atom.model_ops.attentions.backends import ( AttentionBackend, @@ -296,6 +297,7 @@ def __init__(self, model_runner): self.index_head_dim = getattr(hf, "index_head_dim", 128) self.window_size = getattr(hf, "sliding_window", 128) self.index_topk = getattr(hf, "index_topk", 1024) + self.rope_head_dim = getattr(hf, "qk_rope_head_dim", 64) # MTP-portion of compress_ratios. `prepare_mtp_decode`'s direct-kernel # fast path only handles SWA (ratio=0) draft layers; non-zero ratios # would also need n_committed_{csa,hca} + HCA compress tail rebuilt. @@ -316,8 +318,23 @@ def __init__(self, model_runner): self.k2_hca = self.block_size // 128 # = 1 self._state_dtype = torch.float32 # fp32 required for softmax-pool - self._swa_dtype = torch.bfloat16 # SWA window matches KV dtype - self._classical_dtype = torch.bfloat16 # CSA Main / HCA Main KV is BF16 + # KV cache dtype gate. fp8 → 2buff native layout (nope fp8 + inline e8m0 + # scale in a 512B entry; parallel bf16 rope pool). bf16 → unchanged. + self._kv_fp8 = model_runner.kv_cache_dtype == "fp8" + if self._kv_fp8: + # nope pool is fp8 (e4m3); SWA and classical (CSA/HCA Main) share it. + self._swa_dtype = dtypes.fp8 + self._classical_dtype = dtypes.fp8 + self._rope_dtype = torch.bfloat16 # rope pool is always bf16 + # aiter prefill (op4) / decode (op5) hard-check gfx950 internally. + assert get_gfx() == "gfx950", ( + "DeepSeek-V4 --kv_cache_dtype fp8 requires gfx950 (MI350); " + f"got {get_gfx()!r}. Use --kv_cache_dtype bf16 on this platform." + ) + else: + self._swa_dtype = torch.bfloat16 # SWA window matches KV dtype + self._classical_dtype = torch.bfloat16 # CSA/HCA Main KV is BF16 + self._rope_dtype = torch.bfloat16 # unused in bf16 path (symmetry) # CSA Indexer cache is FP8 + 4-byte fp32 scale per row, aligned to 16 # bytes (matches V3.2 sparse MLA pattern; avoids torch inductor # unaligned-access slowdowns). Written by `indexer_k_quant_and_cache`, @@ -408,7 +425,7 @@ def compute_per_req_cache_bytes(self) -> int: for every Compressor instance (CSA Main / CSA Indexer / HCA Main). """ elem_state = self._state_dtype.itemsize # fp32 = 4 - elem_swa = self._swa_dtype.itemsize # bf16 = 2 + elem_swa = self._swa_dtype.itemsize # fp8 = 1 or bf16 = 2 # Tail buffers (kv_state + score_state pair per Compressor instance). csa_main = self._numel(self.csa_main_state_shape) * 2 * elem_state csa_idx = self._numel(self.csa_idx_state_shape) * 2 * elem_state @@ -416,6 +433,11 @@ def compute_per_req_cache_bytes(self) -> int: # SWA window per layer. Cache holds `win_with_spec = win + mtp_k` # slots so MTP draft tokens don't alias verified-token slots. swa_per_layer = self.win_with_spec * self.head_dim * elem_swa + if self._kv_fp8: + # 2buff: parallel bf16 rope pool [win_with_spec, rope_head_dim]. + swa_per_layer += ( + self.win_with_spec * self.rope_head_dim * self._rope_dtype.itemsize + ) return ( len(self.csa_layers) * (csa_main + csa_idx) + len(self.hca_layers) * hca_main @@ -442,10 +464,15 @@ def compute_block_bytes(self) -> int: read by `cp_gather_indexer_k_quant_cache`). - HCA Main: k2=1 entry × head_dim BF16 """ - elem_bf16 = self._classical_dtype.itemsize - csa_main_per_block = self.k1_csa * self.head_dim * elem_bf16 + elem_nope = self._classical_dtype.itemsize # fp8 = 1 or bf16 = 2 + csa_main_per_block = self.k1_csa * self.head_dim * elem_nope csa_idx_per_block = self.k1_csa * self._aligned_index_dim # fp8 = 1B - hca_main_per_block = self.k2_hca * self.head_dim * elem_bf16 + hca_main_per_block = self.k2_hca * self.head_dim * elem_nope + if self._kv_fp8: + # 2buff: parallel bf16 rope pool per compress entry. + elem_rope = self._rope_dtype.itemsize + csa_main_per_block += self.k1_csa * self.rope_head_dim * elem_rope + hca_main_per_block += self.k2_hca * self.rope_head_dim * elem_rope return ( len(self.csa_layers) * (csa_main_per_block + csa_idx_per_block) + len(self.hca_layers) * hca_main_per_block @@ -507,7 +534,9 @@ def allocate_per_req_cache(self, num_slots: int) -> dict[str, object]: """ assert self._swa_dtype == self._classical_dtype, ( "unified_kv requires SWA dtype == classical KV dtype " - f"(got SWA={self._swa_dtype}, classical={self._classical_dtype})" + f"(got SWA={self._swa_dtype}, classical={self._classical_dtype}). " + "fp8 path must set both to dtypes.fp8 (rope lives in a separate " + "bf16 pool); a genuine mismatch corrupts the unified layout." ) device = self.model_runner.device num_blocks = self.model_runner.num_physical_kvcache_blocks @@ -536,6 +565,30 @@ def allocate_per_req_cache(self, num_slots: int) -> dict[str, object]: ) ) + # ---- 2buff fp8: parallel per-layer rope pool (bf16) ------------------ + # Same [swa_pages + compress_pages] page count as unified_kv, but width + # = rope_head_dim (64) and dtype bf16 (rope is never quantized). bf16 + # path: list of None (no rope pool; rope stays inline in unified_kv). + unified_kv_rope: list[Optional[torch.Tensor]] = [] + if self._kv_fp8: + for layer_id in range(self.num_layers): + ratio = ratios[layer_id] + if ratio == 4: + compress_pages = num_blocks * self.k1_csa + elif ratio == 128: + compress_pages = num_blocks * self.k2_hca + else: + compress_pages = 0 # Dense + unified_kv_rope.append( + torch.zeros( + (swa_pages + compress_pages, self.rope_head_dim), + dtype=self._rope_dtype, + device=device, + ) + ) + else: + unified_kv_rope = [None] * self.num_layers + # ---- Compressor state tensors (compute-contiguous) ------------------ csa_main_kv = self._zero_state( (n_csa, num_slots, *self.csa_main_state_shape), device @@ -580,6 +633,7 @@ def allocate_per_req_cache(self, num_slots: int) -> dict[str, object]: return { "v4_unified_kv": unified_kv, + "v4_unified_kv_rope": unified_kv_rope, "v4_csa_main_kv_state": csa_main_kv, "v4_csa_main_score_state": csa_main_score, "v4_csa_idx_kv_state": csa_idx_kv, @@ -626,6 +680,16 @@ def build_kv_cache_tensor(self, layer_id: int, module): module.swa_kv = unified[:swa_pages].view( num_slots, self.win_with_spec, self.head_dim ) + module.kv_fp8 = self._kv_fp8 + if self._kv_fp8: + rope = runner.v4_unified_kv_rope[module.layer_id] + module.unified_kv_rope = rope + module.swa_kv_rope = rope[:swa_pages].view( + num_slots, self.win_with_spec, self.rope_head_dim + ) + else: + module.unified_kv_rope = None + module.swa_kv_rope = None return None if isinstance(module, _V4Indexer): @@ -685,6 +749,10 @@ def build_kv_cache_tensor(self, layer_id: int, module): storage_offset=scale_fp32_offset, ) ) + # Indexer-inner cache is always fp8 (independent of + # kv_cache_dtype); it has no separate rope pool. + module.write_mode = "indexer_fp8" + module.kv_cache_rope = None elif ratio == 4: pos = self.layer_id_to_csa_pos[layer_id_from_prefix] module.kv_state = runner.v4_csa_main_kv_state[pos] @@ -698,6 +766,15 @@ def build_kv_cache_tensor(self, layer_id: int, module): module.kv_cache = unified[swa_pages:].view( num_blocks, self.k1_csa, self.head_dim ) + if self._kv_fp8: + rope = runner.v4_unified_kv_rope[layer_id_from_prefix] + module.kv_cache_rope = rope[swa_pages:].view( + num_blocks, self.k1_csa, self.rope_head_dim + ) + module.write_mode = "main_2buff_fp8" + else: + module.kv_cache_rope = None + module.write_mode = "bf16" elif ratio == 128: pos = self.layer_id_to_hca_pos[layer_id_from_prefix] module.kv_state = runner.v4_hca_main_kv_state[pos] @@ -707,6 +784,15 @@ def build_kv_cache_tensor(self, layer_id: int, module): module.kv_cache = unified[swa_pages:].view( num_blocks, self.k2_hca, self.head_dim ) + if self._kv_fp8: + rope = runner.v4_unified_kv_rope[layer_id_from_prefix] + module.kv_cache_rope = rope[swa_pages:].view( + num_blocks, self.k2_hca, self.rope_head_dim + ) + module.write_mode = "main_2buff_fp8" + else: + module.kv_cache_rope = None + module.write_mode = "bf16" else: raise ValueError( f"Unknown V4 compress_ratio={ratio} on Compressor at " @@ -725,6 +811,11 @@ def get_kv_transfer_tensors(self): runner = self.model_runner if not hasattr(runner, "v4_unified_kv"): return None + if self._kv_fp8: + # PD disaggregation with 2buff fp8 KV cache is not yet supported: + # the byte-region math below assumes a single bf16 unified pool and + # ignores the parallel rope pool. Disable KV transfer for fp8. + return None num_slots = runner.max_per_req_cache_slots swa_pages = num_slots * self.win_with_spec diff --git a/atom/model_ops/v4_kernels/__init__.py b/atom/model_ops/v4_kernels/__init__.py index 0b86fff078..599a1e857d 100644 --- a/atom/model_ops/v4_kernels/__init__.py +++ b/atom/model_ops/v4_kernels/__init__.py @@ -43,11 +43,18 @@ qk_norm_rope_maybe_quant, qk_norm_rope_maybe_quant_reference, ) -from atom.model_ops.v4_kernels.state_writes import update_compressor_states, swa_write +from atom.model_ops.v4_kernels.state_writes import ( + qk_norm_rope_quant_2buff, + swa_write_2buff_prepacked, + update_compressor_states, + swa_write, +) __all__ = [ "update_compressor_states", "swa_write", + "swa_write_2buff_prepacked", + "qk_norm_rope_quant_2buff", "fused_compress_attn", "fused_compress_attn_reference", "sparse_attn_v4_paged_decode", diff --git a/atom/model_ops/v4_kernels/fused_compress.py b/atom/model_ops/v4_kernels/fused_compress.py index a178dd2985..1685d79e83 100644 --- a/atom/model_ops/v4_kernels/fused_compress.py +++ b/atom/model_ops/v4_kernels/fused_compress.py @@ -403,7 +403,13 @@ def fused_compress_attn( ] = None, # fp32 [NB, k_per_block]; required when quant=True use_ue8m0: bool = True, # round scale to power-of-2 (UE8M0); only when quant=True preshuffle: bool = True, # MFMA 16x16 preshuffled FP8 layout; only when quant=True - fp8_max: Optional[float] = None, # E4M3 max; required when quant=True + fp8_max: Optional[float] = None, # E4M3 max; required when quant or main_2buff_fp8 + # V4-Main native fp8 2buff path (CSA/HCA Main under --kv_cache_dtype fp8). + # Distinct from `quant` (Indexer-inner per-row preshuffle): writes per-64-tile + # e8m0 nope-fp8 + inline dup-scale into `kv_cache` (fp8 [NB,k,512]) and bf16 + # rope into `kv_cache_rope` (bf16 [NB,k,64]) via the flydsl group_fp8 scatter. + main_2buff_fp8: bool = False, + kv_cache_rope: Optional[torch.Tensor] = None, # bf16 [NB,k_per_block,64] ) -> None: """Batched fused per-source-position pool + RMSNorm + RoPE + cache scatter, dispatched via SGLang-style packed plan. @@ -462,6 +468,8 @@ def fused_compress_attn( and _shape_key == (512, 64, 128, False) ) if _hca_use: + # main_2buff_fp8: native group_fp8 2buff scatter (nope-fp8 into + # kv_cache, bf16 rope into k_rope_cache). Otherwise plain bf16. flydsl_hca_compress_attn( kv_in=kv_in, score_in=score_in, @@ -480,9 +488,15 @@ def fused_compress_attn( ratio=ratio, head_dim=head_dim, rope_head_dim=rope_head_dim, + quant=main_2buff_fp8, + k_rope_cache=kv_cache_rope if main_2buff_fp8 else None, + quant_group_size=64, ) return if _flydsl_use: + # main_2buff_fp8: CSA Main native group_fp8 2buff (nope-fp8 + inline + # e8m0 into kv_cache, bf16 rope into k_rope_cache; scale carried inline + # so cache_scale stays None). Indexer-inner uses per_row_fp8 preshuffle. flydsl_fused_compress_attn( kv_in=kv_in, score_in=score_in, @@ -502,10 +516,12 @@ def fused_compress_attn( ratio=ratio, head_dim=head_dim, rope_head_dim=rope_head_dim, - quant=quant, + quant=quant or main_2buff_fp8, cache_scale=cache_scale, use_ue8m0=use_ue8m0, - preshuffle=preshuffle, + preshuffle=preshuffle and not main_2buff_fp8, + quant_mode="group_fp8" if main_2buff_fp8 else "per_row_fp8", + k_rope_cache=kv_cache_rope if main_2buff_fp8 else None, ) return diff --git a/atom/model_ops/v4_kernels/paged_decode.py b/atom/model_ops/v4_kernels/paged_decode.py index e91acd947e..dc5d0c119f 100644 --- a/atom/model_ops/v4_kernels/paged_decode.py +++ b/atom/model_ops/v4_kernels/paged_decode.py @@ -902,6 +902,109 @@ def sparse_attn_v4_paged_decode_reference( return _sparse_attn_ragged_torch(q, unified_kv, attn_sink, topk_idxs, softmax_scale) +def _sparse_attn_v4_paged_decode_asm( + unified_kv: torch.Tensor, + kv_indices: torch.Tensor, + kv_indptr: torch.Tensor, + attn_sink: torch.Tensor, + softmax_scale: float, + unified_kv_rope: torch.Tensor, + q_packed_in: torch.Tensor, + q_rope_in: torch.Tensor, + num_kv_splits: int | None = None, +) -> torch.Tensor: + """Native 2buff fp8 V4 decode via the aiter assembly kernel + ``mla_decode_fwd_v4_nm`` (ROCm/aiter#3112, mi350/gfx950). + + ``unified_kv`` is the packed 512-byte fp8 NoPE pool ``[P, 512]`` and + ``unified_kv_rope`` is the bf16 RoPE pool ``[P, 64]`` — consumed with NO + requant (caller stored the 2buff layout natively via op1). Q is supplied + pre-packed (``q_packed_in``/``q_rope_in``) by the fused decode KV-write + (op1 ``fused_qk_norm_rope_group_quant``); no per-call quantize. + + The flat-CSR ``unified_kv`` is made compatible with the kernel's paged + format by treating each token as a 1-token page (``page_size=1``): + ``kv_page_indices = kv_indices``, ``kv_indptr`` is the existing CSR indptr, + ``kv_last_page_lens = ones(N)``, ``qo_indptr = arange(N+1)``, + ``max_seqlen_q = 1``. No translation layer needed. + + softmax_scale is ignored by the kernel (hardcodes 1/sqrt(512)); passed + through for API parity. + """ + import aiter + import aiter.mla + + from atom.model_ops.v4_kernels.v4_quant import ( + V4_DIM_NOPE, + V4_DIM_QK_PACKED, + V4_DIM_ROPE, + ) + + assert q_packed_in.dim() == 3 and q_packed_in.shape[-1] == V4_DIM_QK_PACKED, ( + f"asm v4 nm: q_packed_in must be [N, H, {V4_DIM_QK_PACKED}] fp8, " + f"got {tuple(q_packed_in.shape)}" + ) + assert q_rope_in.dim() == 3 and q_rope_in.shape[-1] == V4_DIM_ROPE, ( + f"asm v4 nm: q_rope_in must be [N, H, {V4_DIM_ROPE}] bf16, " + f"got {tuple(q_rope_in.shape)}" + ) + q_packed = q_packed_in + q_rope = q_rope_in + device = q_packed.device + N, H, _ = q_packed.shape + + assert unified_kv.dim() == 2 and unified_kv.shape[-1] == V4_DIM_QK_PACKED, ( + f"asm v4 nm: native fp8 unified_kv must be [P, {V4_DIM_QK_PACKED}] fp8, " + f"got {tuple(unified_kv.shape)}" + ) + assert unified_kv_rope.dim() == 2 and unified_kv_rope.shape[-1] == V4_DIM_ROPE, ( + f"asm v4 nm: native fp8 unified_kv_rope must be [P, {V4_DIM_ROPE}] bf16, " + f"got {tuple(unified_kv_rope.shape)}" + ) + + nhead_kv = 1 + page_size = 1 + # Kernel expects KV as [num_page, page_size=1, num_kv_heads=1, dim]. + kv_packed = unified_kv.view(-1, page_size, nhead_kv, V4_DIM_QK_PACKED) + kv_rope = unified_kv_rope.view(-1, page_size, nhead_kv, V4_DIM_ROPE) + + # ---- per-token paged layout (one slot per token, decode=1) ---------- + qo_indptr = torch.arange(N + 1, dtype=torch.int32, device=device) + kv_last_page_lens = torch.ones(N, dtype=torch.int32, device=device) + kv_indptr_i32 = kv_indptr.to(torch.int32) + kv_page_indices_i32 = kv_indices.to(torch.int32).contiguous() + max_seqlen_q = 1 + + output = torch.empty( + (N, H, V4_DIM_NOPE + V4_DIM_ROPE), dtype=torch.bfloat16, device=device + ) + out_16_nosplit = 1 if num_kv_splits == 1 else 0 + + logits, _ = aiter.mla.mla_decode_fwd_v4_nm( + q_packed, + q_rope, + kv_packed, + kv_rope, + output, + qo_indptr, + kv_indptr_i32, + kv_page_indices_i32, + kv_last_page_lens, + max_seqlen_q, + sink=attn_sink, + sm_scale=softmax_scale, + out_16_nosplit=out_16_nosplit, + num_kv_splits=num_kv_splits, + ) + # Result lands in `output` for the bf16-direct write (out_16_nosplit=1) or + # the stage2 LSE merge (resolved splits > 1). Only the single-pass fp32 + # path leaves it in logits[:, 0]. logits.shape[1] = resolved split count. + resolved_splits = logits.shape[1] + if out_16_nosplit != 0 or resolved_splits > 1: + return output.to(torch.bfloat16) + return logits[:, 0].to(torch.bfloat16) + + def sparse_attn_v4_paged_decode( q: torch.Tensor, unified_kv: torch.Tensor, @@ -910,12 +1013,32 @@ def sparse_attn_v4_paged_decode( attn_sink: torch.Tensor, softmax_scale: float, kv_scales: torch.Tensor | None = None, + unified_kv_rope: torch.Tensor | None = None, + q_packed_in: torch.Tensor | None = None, + q_rope_in: torch.Tensor | None = None, ) -> torch.Tensor: """V4 decode sparse attention over a unified KV pool with paged indices. - When ``kv_scales`` is provided, ``unified_kv`` must be fp8 (e4m3fnuz) and - will be dequantized in-kernel using 1xGROUP_SIZE (default 64) block scales. + Native 2buff fp8 (``unified_kv_rope`` provided): routes to the aiter asm + kernel (op5) with pre-packed fp8 Q (``q_packed_in``/``q_rope_in``); the + fp8 NoPE pool + bf16 RoPE pool are read with no requant. + + Otherwise (bf16): the existing Triton / reference path. When ``kv_scales`` + is provided, ``unified_kv`` must be fp8 (e4m3fnuz) and is dequantized + in-kernel using 1xGROUP_SIZE (default 64) block scales (legacy 1buff, + unreachable from the model). """ + if unified_kv_rope is not None: + return _sparse_attn_v4_paged_decode_asm( + unified_kv, + kv_indices, + kv_indptr, + attn_sink, + softmax_scale, + unified_kv_rope, + q_packed_in, + q_rope_in, + ) if os.environ.get("ATOM_USE_TRITON_ATTN", "1") == "1": return _sparse_attn_v4_paged_decode_triton( q, diff --git a/atom/model_ops/v4_kernels/state_writes.py b/atom/model_ops/v4_kernels/state_writes.py index 068b5d30a6..674ecb3f6a 100644 --- a/atom/model_ops/v4_kernels/state_writes.py +++ b/atom/model_ops/v4_kernels/state_writes.py @@ -51,6 +51,8 @@ GPU index buffer needed (no DMA race window). """ +from typing import Optional + import torch import triton import triton.language as tl @@ -215,6 +217,183 @@ def swa_write_reference( swa_kv[slot, ring_idx] = src_kv +def swa_write_2buff_prepacked( + k_packed: torch.Tensor, + k_rope: torch.Tensor, + positions: torch.Tensor, + cu_seqlens_q: torch.Tensor, + state_slot_per_seq: torch.Tensor, + swa_kv_nope: torch.Tensor, + swa_kv_rope: torch.Tensor, + cache_size: int, + write_per_batch: int, +) -> None: + """Native 2buff fp8 SWA write: ring-scatter the LAST + ``min(tok_n_b, write_per_batch)`` tokens of every seq into the two + fp8/bf16 SWA rings. The K is ALREADY in the 2buff layout (nope-fp8 + ``[T,512]`` + rope-bf16 ``[T,64]``) — produced upstream by aiter op + ``fused_qk_norm_rope_group_quant`` (the prefill K branch). This is a + pure dtype-agnostic scatter (reuses ``_swa_write_kernel`` once per ring); + NO torch quantization happens here. + + Args: + k_packed: [T, 512] fp8 — op1-quantized K nope+inline-scale+pad. + k_rope: [T, 64] bf16 — op1-rotated K-PE (not quantized). + swa_kv_nope: [num_slots, cache_size, 512] fp8 ring (2buff nope pool). + swa_kv_rope: [num_slots, cache_size, 64] bf16 ring (rope pool). + (other args as ``swa_write``.) + """ + from atom.model_ops.v4_kernels.v4_quant import ( + V4_DIM_QK_PACKED, + V4_DIM_ROPE, + ) + + assert ( + k_packed.dim() == 2 and k_packed.shape[1] == V4_DIM_QK_PACKED + ), f"k_packed must be [T,{V4_DIM_QK_PACKED}] fp8, got {tuple(k_packed.shape)}" + assert ( + k_rope.dim() == 2 and k_rope.shape[1] == V4_DIM_ROPE + ), f"k_rope must be [T,{V4_DIM_ROPE}] bf16, got {tuple(k_rope.shape)}" + assert swa_kv_nope.dim() == 3 and swa_kv_nope.shape[2] == V4_DIM_QK_PACKED + assert swa_kv_rope.dim() == 3 and swa_kv_rope.shape[2] == V4_DIM_ROPE + assert swa_kv_nope.shape[1] == cache_size and swa_kv_rope.shape[1] == cache_size + + swa_write( + k_packed.contiguous(), + positions, + cu_seqlens_q, + state_slot_per_seq, + swa_kv_nope, + cache_size, + write_per_batch, + ) + swa_write( + k_rope.contiguous(), + positions, + cu_seqlens_q, + state_slot_per_seq, + swa_kv_rope, + cache_size, + write_per_batch, + ) + + +def qk_norm_rope_quant_2buff( + q: torch.Tensor, + kv_pre: torch.Tensor, + kv_weight: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + positions: torch.Tensor, + n_local_heads: int, + head_dim: int, + rope_head_dim: int, + eps: float, + swa_scatter: bool = False, + state_slot_per_seq: Optional[torch.Tensor] = None, + batch_id_per_token: Optional[torch.Tensor] = None, + swa_kv_nope: Optional[torch.Tensor] = None, + swa_kv_rope: Optional[torch.Tensor] = None, +): + """Single aiter HIP op (op1): Q+K RMSNorm + RoPE + fp8 2buff group-quant, + with an OPTIONAL fused SWA ring-scatter of the K row. All-native — NO torch + quantization, NO separate scatter launch. Sole entry point for the fp8 + 2buff Q/K quantization on both the decode and prefill paths. + + - ``swa_scatter=False`` (prefill): quantize only. The returned + ``k_packed``/``k_rope`` feed op4 (``pa_sparse_prefill_fp8_opus``) and are + scattered into the SWA rings AFTER attention via + ``swa_write_2buff_prepacked`` (prefix must be read before it is + overwritten). + - ``swa_scatter=True`` (decode): op1 also scatters the K row into both SWA + rings in the same launch (``state_slot_per_seq`` / ``batch_id_per_token`` + / ``swa_kv_nope`` / ``swa_kv_rope`` required). The returned K buffers are + throwaway. op1's scatter addresses identically to ``swa_write``: K lands + at ``swa_*[state_slot_per_seq[batch_id_per_token[t]], + positions[t] % cache_size, :]`` (ring derived from the buffer's dim-1 + size inside the kernel; ``batch_id_per_token[t] < 0`` CG-pad tokens are + skipped). + + Args: + q: [T, H*D] bf16 — post-``wq_b`` Q (heads packed in last dim). + kv_pre: [T, D] bf16 — post-``wkv_a`` split KV row (pre norm/rope). + kv_weight: [D] bf16 — KV-side RMSNorm weight (Q is weightless). + cos_cache, sin_cache: RoPE tables, [max_pos, rope_head_dim/2]. + positions: [T'] int64 — absolute token positions (T' >= T). + swa_scatter: enable the fused decode SWA ring-scatter (see above). + state_slot_per_seq: [bs] int32 — per-seq SWA ring slot (scatter only). + batch_id_per_token: [T'] int32 — token→seq map, -1 on CG-pad tokens + (scatter only). + swa_kv_nope: [num_slots, cache_size, 512] fp8 ring (scatter only). + swa_kv_rope: [num_slots, cache_size, 64] bf16 ring (scatter only). + + Returns: + (q_packed, q_rope, k_packed, k_rope): + q_packed [T, H, 512] fp8 — Q nope fp8 + 14B dup e8m0 scale + pad. + q_rope [T, H, 64] bf16 — rotated Q-PE (not quantized). + k_packed [T, 1, 512] fp8 — K nope fp8 + scale + pad (throwaway when + ``swa_scatter`` — already in the rings). + k_rope [T, 1, 64] bf16 — rotated K-PE. + """ + import aiter + from aiter import dtypes + + from atom.model_ops.v4_kernels.v4_quant import ( + V4_DIM_QK, + V4_DIM_QK_PACKED, + V4_DIM_ROPE, + ) + + assert ( + q.dim() == 2 and q.shape[1] == n_local_heads * head_dim + ), f"q must be [T, H*D={n_local_heads * head_dim}], got {tuple(q.shape)}" + assert kv_pre.dim() == 2 and kv_pre.shape[1] == head_dim + assert head_dim == V4_DIM_QK, f"fused 2buff path requires head_dim={V4_DIM_QK}" + assert rope_head_dim == V4_DIM_ROPE + + T = q.shape[0] + q3 = q.view(T, n_local_heads, head_dim) + kv3 = kv_pre.view(T, 1, head_dim) + + swa_kwargs = {} + if swa_scatter: + assert ( + state_slot_per_seq is not None + and batch_id_per_token is not None + and swa_kv_nope is not None + and swa_kv_rope is not None + ), "swa_scatter requires state_slot_per_seq/batch_id_per_token/swa_kv_*" + assert state_slot_per_seq.dtype == torch.int32 + assert batch_id_per_token.dtype == torch.int32 + assert swa_kv_nope.dim() == 3 and swa_kv_nope.shape[2] == V4_DIM_QK_PACKED + assert swa_kv_rope.dim() == 3 and swa_kv_rope.shape[2] == V4_DIM_ROPE + swa_kwargs = dict( + swa_nope_scale_buff=swa_kv_nope, + swa_rope_buff=swa_kv_rope, + state_slot_mapping=state_slot_per_seq, + batch_id_per_token=batch_id_per_token[:T], + ) + + # Fused Q+K norm+rope+fp8 group-quant (single HIP launch). fp8 Q output so + # op4/op5 consume it directly. When swa_scatter, op1 also writes the K row + # into both SWA rings (no separate swa_write). Buffers auto-allocated + # (zeroed pad). + return aiter.fused_qk_norm_rope_group_quant( + q3, + kv3, + kv_weight, + positions[:T], + cos_cache, + sin_cache, + eps, + is_neox=False, + q_out_dtype=dtypes.fp8, + quant_group_size=64, + scale_dtype="e8m0", + **swa_kwargs, + ) + + # === Unified Compressor state save (plan path) ========================== # Paper §3.6.1: per-request fixed-size state cache for "uncompressed tail # tokens + previous block as overlap context (B-side, eq 11)". ATOM keeps diff --git a/atom/model_ops/v4_kernels/v4_quant.py b/atom/model_ops/v4_kernels/v4_quant.py new file mode 100644 index 0000000000..c5ae45ecd8 --- /dev/null +++ b/atom/model_ops/v4_kernels/v4_quant.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""V4 MLA on-the-fly quantization helpers for the hipkittens decode kernel. + +The hipkittens kernel ``aiter.mla.mla_v40_decode_fwd`` consumes Q/KV in a +two-buffer layout: + + nope_scale_buff: [..., 512] FP8 (1 byte/elem) + bytes [0 , 448): NOPE FP8 (per-token feature, position-agnostic) + bytes [448 , 462): 14 duplicated E8M0 scales (7 per-64-elt scales x2) + bytes [462 , 512): 50 bytes unused trailing pad + rope_buff: [..., 64] BF16 (per-token RoPE-rotated, kept BF16) + +ATOM currently stores ``unified_kv`` as a single contiguous bf16 tensor of +shape ``[..., 512]`` (NoPE 448 + RoPE 64 concatenated). For the PoC we +convert bf16 -> (fp8+scale+pad, bf16-rope) on every decode call. A future +phase 2 may change ``unified_kv`` to store the packed layout directly so the +runtime quantization is skipped. + +All constants and pack arithmetic mirror +``/mnt/raid0/ruitang3/git_repo/aiter/op_tests/test_mla_v4_persistent.py`` +(``V4_*`` constants and ``pack_v4_nope_scale``/``quantize_v4_nope_bpad8``). +""" + +from __future__ import annotations + +from typing import Tuple + +import torch +from aiter import dtypes + +V4_DIM_NOPE = 448 +V4_DIM_ROPE = 64 +V4_DIM_QK = V4_DIM_NOPE + V4_DIM_ROPE # 512 +V4_TILE = 64 +V4_NUM_TILES = V4_DIM_NOPE // V4_TILE # 7 +V4_DIM_SCALE_DUP = V4_DIM_NOPE // (V4_TILE // 2) # 14 +V4_DIM_QK_PACKED = 512 +V4_PACK_OFF_NOPE = 0 +V4_PACK_OFF_SCALE = V4_DIM_NOPE # 448 +V4_PACK_OFF_PAD = V4_DIM_NOPE + V4_DIM_SCALE_DUP # 462 + + +def _fp32_pow2_to_e8m0(pow2_fp32: torch.Tensor) -> torch.Tensor: + """Pack a power-of-2 fp32 scale into a 1-byte E8M0 exponent + (byte B encodes 2^(B-127); B=0 -> 0.0, B=255 -> INF).""" + safe = torch.where(pow2_fp32 > 0, pow2_fp32, torch.ones_like(pow2_fp32)) + biased = torch.log2(safe).round().to(torch.int32) + 127 + biased = torch.clamp(biased, 0, 254) + biased = torch.where(pow2_fp32 > 0, biased, torch.zeros_like(biased)) + return biased.to(torch.uint8) + + +def _cast_scale_inv_to_ue8m0_pow2(scales_inv: torch.Tensor) -> torch.Tensor: + """amax/FP8_AMAX -> ceil-log2 -> power-of-2 fp32.""" + return torch.pow(2.0, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to( + torch.float32 + ) + + +def _duplicate_each_lastdim(x: torch.Tensor) -> torch.Tensor: + """[..., N] -> [..., 2*N] with each element written twice.""" + return x.unsqueeze(-1).expand(*x.shape, 2).reshape(*x.shape[:-1], x.shape[-1] * 2) + + +def quantize_v4_nope_bpad8( + nope_src: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Per-tile (64-elt) E8M0 quantization of the NOPE segment. + + Returns ``(nope_fp8 [..., 448], scale_e8m0 [..., 7] uint8)``. + """ + fp8_amax = float(torch.finfo(dtypes.fp8).max) + nope_fp32 = nope_src.float() + leading = nope_fp32.shape[:-1] + tiled = nope_fp32.reshape(*leading, V4_NUM_TILES, V4_TILE) + active_scale_pow2 = _cast_scale_inv_to_ue8m0_pow2( + tiled.abs().amax(dim=-1) / fp8_amax + ) + nope_fp8 = ( + (tiled / active_scale_pow2.unsqueeze(-1)) + .to(dtypes.fp8) + .reshape(*leading, V4_DIM_NOPE) + ) + scale_e8m0 = _fp32_pow2_to_e8m0(active_scale_pow2) # [..., 7] uint8 + return nope_fp8, scale_e8m0 + + +def pack_v4_nope_scale( + nope_fp8: torch.Tensor, scale_e8m0: torch.Tensor +) -> torch.Tensor: + """Pack NOPE + duplicated E8M0 scale into a single 512-byte/token FP8 tensor. + + nope_fp8: [..., 448] FP8 + scale_e8m0: [..., 7] uint8 (E8M0 byte per quant tile) + returns: [..., 512] FP8 (NoPE | dup-scale x14 | pad x50) + """ + leading = nope_fp8.shape[:-1] + assert nope_fp8.shape[-1] == V4_DIM_NOPE + assert scale_e8m0.shape[-1] == V4_NUM_TILES + assert scale_e8m0.shape[:-1] == leading + + packed = torch.zeros( + (*leading, V4_DIM_QK_PACKED), dtype=torch.uint8, device=nope_fp8.device + ) + packed[..., V4_PACK_OFF_NOPE : V4_PACK_OFF_NOPE + V4_DIM_NOPE] = nope_fp8.view( + torch.uint8 + ) + packed[..., V4_PACK_OFF_SCALE : V4_PACK_OFF_SCALE + V4_DIM_SCALE_DUP] = ( + _duplicate_each_lastdim(scale_e8m0) + ) + return packed.view(dtypes.fp8) + + +def _e8m0_to_fp32_pow2(scale_e8m0: torch.Tensor) -> torch.Tensor: + """Inverse of ``_fp32_pow2_to_e8m0``: E8M0 byte B -> fp32 2^(B-127). + + B == 0 decodes to 0.0 (the zero-scale sentinel produced by the forward + path for all-zero tiles).""" + biased = scale_e8m0.to(torch.int32) + pow2 = torch.pow(2.0, (biased - 127).float()) + return torch.where(biased > 0, pow2, torch.zeros_like(pow2)) + + +def dequantize_v4_2buff_to_bf16( + packed_fp8: torch.Tensor, + rope_bf16: torch.Tensor, +) -> torch.Tensor: + """Inverse of ``quantize_bf16_to_v4_2buff``. + + Takes the two-buffer layout ``(packed_fp8 [..., 512], rope_bf16 [..., 64])`` + and reconstructs the bf16 ``[..., 512]`` row (NoPE 448 + RoPE 64). + + The NoPE half is dequantized per 64-elt tile: ``fp8_val * 2^(B-127)`` where + ``B`` is the tile's E8M0 scale byte (read from the dup-scale region; we use + the first of each duplicated pair). Round-trips + ``quantize_bf16_to_v4_2buff`` to within fp8 per-tile quantization error. + """ + assert packed_fp8.shape[-1] == V4_DIM_QK_PACKED, ( + f"dequantize_v4_2buff_to_bf16: packed last dim must be " + f"{V4_DIM_QK_PACKED}, got {tuple(packed_fp8.shape)}" + ) + assert rope_bf16.shape[-1] == V4_DIM_ROPE, ( + f"dequantize_v4_2buff_to_bf16: rope last dim must be {V4_DIM_ROPE}, " + f"got {tuple(rope_bf16.shape)}" + ) + leading = packed_fp8.shape[:-1] + packed_u8 = packed_fp8.view(torch.uint8) + + nope_fp8 = packed_u8[..., V4_PACK_OFF_NOPE : V4_PACK_OFF_NOPE + V4_DIM_NOPE].view( + dtypes.fp8 + ) + nope_fp32 = nope_fp8.float().reshape(*leading, V4_NUM_TILES, V4_TILE) + + # Dup-scale region holds each of the 7 tile scales twice; take the even + # entries to recover the 7 per-tile E8M0 bytes. + scale_dup = packed_u8[..., V4_PACK_OFF_SCALE : V4_PACK_OFF_SCALE + V4_DIM_SCALE_DUP] + scale_e8m0 = scale_dup[..., 0::2] # [..., 7] + scale_pow2 = _e8m0_to_fp32_pow2(scale_e8m0) # [..., 7] fp32 + + nope_bf16 = ( + (nope_fp32 * scale_pow2.unsqueeze(-1)) + .reshape(*leading, V4_DIM_NOPE) + .to(torch.bfloat16) + ) + rope = rope_bf16.to(torch.bfloat16) + return torch.cat([nope_bf16, rope], dim=-1) + + +def quantize_bf16_to_v4_2buff( + bf16_src: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """End-to-end helper: bf16 [..., 512] -> (packed_fp8 [..., 512], rope_bf16 [..., 64]). + + Splits the input on the NoPE/RoPE boundary, quantizes the NoPE half via + ``quantize_v4_nope_bpad8`` + ``pack_v4_nope_scale``, and keeps the RoPE + half in bf16 (contiguous). + """ + assert bf16_src.shape[-1] == V4_DIM_QK, ( + f"quantize_bf16_to_v4_2buff: last dim must be {V4_DIM_QK}, " + f"got {tuple(bf16_src.shape)}" + ) + nope_src = bf16_src[..., :V4_DIM_NOPE] + rope_src = bf16_src[..., V4_DIM_NOPE:].to(torch.bfloat16).contiguous() + nope_fp8, scale_e8m0 = quantize_v4_nope_bpad8(nope_src) + packed_fp8 = pack_v4_nope_scale(nope_fp8, scale_e8m0) + return packed_fp8, rope_src diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index b98e2dedf0..e023ab572b 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -89,10 +89,12 @@ fused_compress_attn, inverse_rope_inplace, qk_norm_rope_maybe_quant, + qk_norm_rope_quant_2buff, scale_indexer_weights, sparse_attn_v4_paged_decode, sparse_attn_v4_paged_prefill, swa_write, + swa_write_2buff_prepacked, update_compressor_states, ) from atom.utils import envs, mark_spliting_op @@ -886,6 +888,13 @@ def __init__( # of `self.kv_cache`. Bound by the V4 builder when `kv_cache.dtype` is # FP8 (Indexer-inner Compressor); None for BF16 cache (Main path). self.cache_scale: Optional[torch.Tensor] = None + # Native 2buff fp8 Main path (CSA/HCA Main under --kv_cache_dtype fp8): + # parallel bf16 rope pool [NB, k_per_block, rope_head_dim] and a write + # mode tag. Bound by DeepseekV4AttentionMetadataBuilder; "main_2buff_fp8" + # routes the compress scatter to the flydsl group_fp8 path, "bf16" / + # "indexer_fp8" keep the existing behavior. + self.kv_cache_rope: Optional[torch.Tensor] = None + self.write_mode: str = "bf16" # State cache (per paper §3.6.1 "uncompressed tail + B-side overlap # window" portion). Indexed as a single ring buffer of size @@ -1012,19 +1021,27 @@ def forward( # fwd's data for the NEXT fwd's overlap — must run AFTER the fused # kernel. cos_cache, sin_cache = self.rotary_emb.cos_cache, self.rotary_emb.sin_cache - # Quant path triggers when the bound cache is FP8 (Indexer-inner). - # `self.cache_scale` is bound alongside `self.kv_cache` by the V4 - # builder when the cache is FP8 (strided fp32 view of the per-block - # scale region). - is_quant = self.kv_cache is not None and self.kv_cache.dtype != torch.bfloat16 + # Three scatter modes, discriminated by the bound `write_mode`: + # - "indexer_fp8": per-row fp8 + preshuffle (Indexer-inner Compressor). + # - "main_2buff_fp8": CSA/HCA Main under --kv_cache_dtype fp8 — native + # group_fp8 2buff (nope-fp8 + inline e8m0 into `kv_cache`, bf16 rope + # into `kv_cache_rope`). Routed to the flydsl group_fp8 path. + # - "bf16": plain bf16 Main scatter. + # `self.cache_scale` is bound alongside an fp8 `kv_cache` by the V4 + # builder (indexer-inner only; group_fp8 carries scale inline so + # cache_scale stays None). + main_2buff_fp8 = self.write_mode == "main_2buff_fp8" + is_quant = self.write_mode == "indexer_fp8" # Skip the kernel's cache scatter during warmup (kv_cache/block_tables # not yet bound). if block_tables is None or self.kv_cache is None: scatter_kv_cache = None scatter_block_tables = None + scatter_kv_cache_rope = None else: scatter_kv_cache = self.kv_cache scatter_block_tables = block_tables + scatter_kv_cache_rope = self.kv_cache_rope if main_2buff_fp8 else None fused_compress_attn( kv_in=kv, score_in=score, @@ -1048,7 +1065,13 @@ def forward( cache_scale=self.cache_scale if is_quant else None, use_ue8m0=(self.scale_fmt == "ue8m0"), preshuffle=True, - fp8_max=(torch.finfo(self.kv_cache.dtype).max if is_quant else None), + fp8_max=( + torch.finfo(self.kv_cache.dtype).max + if (is_quant or main_2buff_fp8) and self.kv_cache is not None + else None + ), + main_2buff_fp8=main_2buff_fp8, + kv_cache_rope=scatter_kv_cache_rope, ) update_compressor_states( kv, @@ -1611,6 +1634,12 @@ def __init__( self.layer_name = prefix atom_config = get_current_atom_config() atom_config.compilation_config.static_forward_context[self.layer_name] = self + # Frozen bool: when KV cache dtype is fp8, route writes/attention to the + # native 2buff fp8 aiter operators (op1 fused norm+rope+group-quant, + # op4 fp8 prefill, op5 asm decode). Dynamo specializes on this constant + # so the bf16 path traces unchanged. Buffers (swa_kv_rope / unified_kv_rope) + # are bound onto the module by DeepseekV4AttentionMetadataBuilder. + self.kv_fp8 = atom_config.kv_cache_dtype == "fp8" def process_weights_after_loading(self) -> None: """Dequant wo_a (FP8 + e8m0 block scale) → BF16 in place. @@ -1809,30 +1838,78 @@ def forward_impl( # For decode, write_per_batch (= min(max_seqlen_q, cache_size)) >= # tokens-per-seq, so the fused per-token scatter (gated on batch_id>=0) # covers exactly the tokens the old standalone swa_write did. - q_sa, kv, q_scale, kv_scale = qk_norm_rope_maybe_quant( - q, - kv_pre, - self.kv_norm.weight, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - positions, - self.n_local_heads, - self.head_dim, - rd, - self.eps, - quant_q=False, - quant_k=False, - swa_kv=self.swa_kv if is_decode else None, - state_slot_mapping=state_slot_mapping if is_decode else None, - batch_id_per_token=attn_md.batch_id_per_token if is_decode else None, - swa_cu_seqlens_q=attn_md.cu_seqlens_q if is_decode else None, - swa_cache_size=cache_size if is_decode else None, - swa_write_per_batch=( - min(attn_md.max_seqlen_q, cache_size) if is_decode else None - ), - ) - if _V4_USE_REF_QUANT: - act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) + # Native 2buff fp8 Q/K buffers (populated only on the fp8 path; both + # remain None for bf16). When fp8: aiter op1 + # `fused_qk_norm_rope_group_quant` does Q+K norm+rope+group-quant in one + # launch, emitting the 2buff layout (nope-fp8 [.,512] + rope-bf16 [.,64]) + # that op4 (prefill) and op5 (decode) consume directly — no torch quant. + q_packed = q_rope_q = k_packed = k_rope = None + if self.kv_fp8: + if is_decode: + # Decode: op1 fuses the SWA ring-scatter of K and returns the + # pre-packed fp8 Q for the asm decode kernel (op5). K buffers + # are throwaway (already scattered into the rings). + q_packed, q_rope_q, _, _ = qk_norm_rope_quant_2buff( + q, + kv_pre, + self.kv_norm.weight, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + positions, + self.n_local_heads, + self.head_dim, + rd, + self.eps, + swa_scatter=True, + state_slot_per_seq=state_slot_mapping, + batch_id_per_token=attn_md.batch_id_per_token, + swa_kv_nope=self.swa_kv, + swa_kv_rope=self.swa_kv_rope, + ) + q_sa = kv = q_scale = kv_scale = None + else: + # Prefill: op1 quantizes Q and the per-fwd extend K WITHOUT the + # ring-scatter (SWA tail is written after attn). The fp8 Q/K go + # to op4; the extend K (k_packed/k_rope) is scattered post-attn. + q_packed, q_rope_q, k_packed, k_rope = qk_norm_rope_quant_2buff( + q, + kv_pre, + self.kv_norm.weight, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + positions, + self.n_local_heads, + self.head_dim, + rd, + self.eps, + swa_scatter=False, + ) + q_sa = kv = q_scale = kv_scale = None + else: + q_sa, kv, q_scale, kv_scale = qk_norm_rope_maybe_quant( + q, + kv_pre, + self.kv_norm.weight, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + positions, + self.n_local_heads, + self.head_dim, + rd, + self.eps, + quant_q=False, + quant_k=False, + swa_kv=self.swa_kv if is_decode else None, + state_slot_mapping=state_slot_mapping if is_decode else None, + batch_id_per_token=attn_md.batch_id_per_token if is_decode else None, + swa_cu_seqlens_q=attn_md.cu_seqlens_q if is_decode else None, + swa_cache_size=cache_size if is_decode else None, + swa_write_per_batch=( + min(attn_md.max_seqlen_q, cache_size) if is_decode else None + ), + ) + if _V4_USE_REF_QUANT: + act_quant_inplace(kv[..., :-rd], 64, self.scale_fmt) # HCA if use_async_compress: @@ -1881,6 +1958,9 @@ def forward_impl( kv_indptr, self.attn_sink, self.softmax_scale, + unified_kv_rope=self.unified_kv_rope if self.kv_fp8 else None, + q_packed_in=q_packed, + q_rope_in=q_rope_q, ) # [S, H, head_dim] else: # Two-source paged prefill: prefix from `unified_kv` (per-ratio @@ -1897,29 +1977,67 @@ def forward_impl( kv_indptr_prefix = attn_md.kv_indptr_prefix_hca else: raise ValueError(f"Unsupported compress_ratio {ratio}") - o = sparse_attn_v4_paged_prefill( - q_sa, - self.unified_kv, - kv_indices_prefix, - kv_indptr_prefix, - kv, - attn_md.kv_indices_extend, - attn_md.kv_indptr_extend, - self.attn_sink, - self.softmax_scale, - ) # [S, H, head_dim] + if self.kv_fp8: + # Native fp8 prefill (op4): feed the 2buff fp8 prefix pool + # directly (nope-fp8 + rope-bf16) plus op1-quantized fp8 Q and + # extend K. No dequant of the prefix, no torch quant. gfx950. + from aiter.ops.pa_sparse_prefill_opus import ( + pa_sparse_prefill_fp8_opus, + ) + + o = pa_sparse_prefill_fp8_opus( + q_packed, + q_rope_q, + self.unified_kv, + self.unified_kv_rope, + kv_indices_prefix, + kv_indptr_prefix, + k_packed.view(k_packed.shape[0], -1), + k_rope.view(k_rope.shape[0], -1), + attn_md.kv_indices_extend, + attn_md.kv_indptr_extend, + self.attn_sink, + self.softmax_scale, + ) # [S, H, head_dim] bf16 + else: + o = sparse_attn_v4_paged_prefill( + q_sa, + self.unified_kv, + kv_indices_prefix, + kv_indptr_prefix, + kv, + attn_md.kv_indices_extend, + attn_md.kv_indptr_extend, + self.attn_sink, + self.softmax_scale, + ) # [S, H, head_dim] # swa_write AFTER attn so chunked-prefill prefix SWA reads see # prior-chunk's ring contents (current swa_write would overwrite # ring slots `pos % cache_size` for positions in this chunk's tail). - swa_write( - kv, - positions, - attn_md.cu_seqlens_q, - state_slot_mapping, - self.swa_kv, - cache_size, - min(attn_md.max_seqlen_q, cache_size), - ) + if self.kv_fp8: + # Scatter the op1-quantized extend K (already 2buff fp8) into the + # two SWA rings — pure copy, no re-quantization. + swa_write_2buff_prepacked( + k_packed.view(k_packed.shape[0], -1), + k_rope.view(k_rope.shape[0], -1), + positions, + attn_md.cu_seqlens_q, + state_slot_mapping, + self.swa_kv, + self.swa_kv_rope, + cache_size, + min(attn_md.max_seqlen_q, cache_size), + ) + else: + swa_write( + kv, + positions, + attn_md.cu_seqlens_q, + state_slot_mapping, + self.swa_kv, + cache_size, + min(attn_md.max_seqlen_q, cache_size), + ) # Inverse RoPE on output's rope dims to remove absolute-position # contribution carried in by the value-side RoPE of the KV entries.