Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 98 additions & 7 deletions atom/model_ops/attentions/deepseek_v4_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import numpy as np
import torch
from aiter import dtypes
from aiter.jit.utils.chip_info import get_gfx
from atom.model_engine.scheduler import ScheduledBatch
from atom.model_ops.attentions.backends import (
AttentionBackend,
Expand Down Expand Up @@ -296,6 +297,7 @@ def __init__(self, model_runner):
self.index_head_dim = getattr(hf, "index_head_dim", 128)
self.window_size = getattr(hf, "sliding_window", 128)
self.index_topk = getattr(hf, "index_topk", 1024)
self.rope_head_dim = getattr(hf, "qk_rope_head_dim", 64)
# MTP-portion of compress_ratios. `prepare_mtp_decode`'s direct-kernel
# fast path only handles SWA (ratio=0) draft layers; non-zero ratios
# would also need n_committed_{csa,hca} + HCA compress tail rebuilt.
Expand All @@ -316,8 +318,23 @@ def __init__(self, model_runner):
self.k2_hca = self.block_size // 128 # = 1

self._state_dtype = torch.float32 # fp32 required for softmax-pool
self._swa_dtype = torch.bfloat16 # SWA window matches KV dtype
self._classical_dtype = torch.bfloat16 # CSA Main / HCA Main KV is BF16
# KV cache dtype gate. fp8 → 2buff native layout (nope fp8 + inline e8m0
# scale in a 512B entry; parallel bf16 rope pool). bf16 → unchanged.
self._kv_fp8 = model_runner.kv_cache_dtype == "fp8"
if self._kv_fp8:
# nope pool is fp8 (e4m3); SWA and classical (CSA/HCA Main) share it.
self._swa_dtype = dtypes.fp8
self._classical_dtype = dtypes.fp8
self._rope_dtype = torch.bfloat16 # rope pool is always bf16
# aiter prefill (op4) / decode (op5) hard-check gfx950 internally.
assert get_gfx() == "gfx950", (
"DeepSeek-V4 --kv_cache_dtype fp8 requires gfx950 (MI350); "
f"got {get_gfx()!r}. Use --kv_cache_dtype bf16 on this platform."
)
else:
self._swa_dtype = torch.bfloat16 # SWA window matches KV dtype
self._classical_dtype = torch.bfloat16 # CSA/HCA Main KV is BF16
self._rope_dtype = torch.bfloat16 # unused in bf16 path (symmetry)
# CSA Indexer cache is FP8 + 4-byte fp32 scale per row, aligned to 16
# bytes (matches V3.2 sparse MLA pattern; avoids torch inductor
# unaligned-access slowdowns). Written by `indexer_k_quant_and_cache`,
Expand Down Expand Up @@ -408,14 +425,19 @@ def compute_per_req_cache_bytes(self) -> int:
for every Compressor instance (CSA Main / CSA Indexer / HCA Main).
"""
elem_state = self._state_dtype.itemsize # fp32 = 4
elem_swa = self._swa_dtype.itemsize # bf16 = 2
elem_swa = self._swa_dtype.itemsize # fp8 = 1 or bf16 = 2
# Tail buffers (kv_state + score_state pair per Compressor instance).
csa_main = self._numel(self.csa_main_state_shape) * 2 * elem_state
csa_idx = self._numel(self.csa_idx_state_shape) * 2 * elem_state
hca_main = self._numel(self.hca_main_state_shape) * 2 * elem_state
# SWA window per layer. Cache holds `win_with_spec = win + mtp_k`
# slots so MTP draft tokens don't alias verified-token slots.
swa_per_layer = self.win_with_spec * self.head_dim * elem_swa
if self._kv_fp8:
# 2buff: parallel bf16 rope pool [win_with_spec, rope_head_dim].
swa_per_layer += (
self.win_with_spec * self.rope_head_dim * self._rope_dtype.itemsize
)
return (
len(self.csa_layers) * (csa_main + csa_idx)
+ len(self.hca_layers) * hca_main
Expand All @@ -442,10 +464,15 @@ def compute_block_bytes(self) -> int:
read by `cp_gather_indexer_k_quant_cache`).
- HCA Main: k2=1 entry × head_dim BF16 (fp8 KV cache: head_dim fp8 entry plus a parallel rope_head_dim bf16 rope entry)
"""
elem_bf16 = self._classical_dtype.itemsize
csa_main_per_block = self.k1_csa * self.head_dim * elem_bf16
elem_nope = self._classical_dtype.itemsize # fp8 = 1 or bf16 = 2
csa_main_per_block = self.k1_csa * self.head_dim * elem_nope
csa_idx_per_block = self.k1_csa * self._aligned_index_dim # fp8 = 1B
hca_main_per_block = self.k2_hca * self.head_dim * elem_bf16
hca_main_per_block = self.k2_hca * self.head_dim * elem_nope
if self._kv_fp8:
# 2buff: parallel bf16 rope pool per compress entry.
elem_rope = self._rope_dtype.itemsize
csa_main_per_block += self.k1_csa * self.rope_head_dim * elem_rope
hca_main_per_block += self.k2_hca * self.rope_head_dim * elem_rope
return (
len(self.csa_layers) * (csa_main_per_block + csa_idx_per_block)
+ len(self.hca_layers) * hca_main_per_block
Expand Down Expand Up @@ -507,7 +534,9 @@ def allocate_per_req_cache(self, num_slots: int) -> dict[str, object]:
"""
assert self._swa_dtype == self._classical_dtype, (
"unified_kv requires SWA dtype == classical KV dtype "
f"(got SWA={self._swa_dtype}, classical={self._classical_dtype})"
f"(got SWA={self._swa_dtype}, classical={self._classical_dtype}). "
"fp8 path must set both to dtypes.fp8 (rope lives in a separate "
"bf16 pool); a genuine mismatch corrupts the unified layout."
)
device = self.model_runner.device
num_blocks = self.model_runner.num_physical_kvcache_blocks
Expand Down Expand Up @@ -536,6 +565,30 @@ def allocate_per_req_cache(self, num_slots: int) -> dict[str, object]:
)
)

# ---- 2buff fp8: parallel per-layer rope pool (bf16) ------------------
# Same [swa_pages + compress_pages] page count as unified_kv, but width
# = rope_head_dim (config default 64) and dtype bf16 (rope is never
# quantized). When the KV cache is bf16 there is no rope pool: the list
# holds None per layer and rope stays inline in unified_kv.
unified_kv_rope: list[Optional[torch.Tensor]] = []
if self._kv_fp8:
for layer_id in range(self.num_layers):
ratio = ratios[layer_id]
if ratio == 4:
compress_pages = num_blocks * self.k1_csa
elif ratio == 128:
compress_pages = num_blocks * self.k2_hca
else:
compress_pages = 0 # Dense
unified_kv_rope.append(
torch.zeros(
(swa_pages + compress_pages, self.rope_head_dim),
dtype=self._rope_dtype,
device=device,
)
)
else:
unified_kv_rope = [None] * self.num_layers

# ---- Compressor state tensors (compute-contiguous) ------------------
csa_main_kv = self._zero_state(
(n_csa, num_slots, *self.csa_main_state_shape), device
Expand Down Expand Up @@ -580,6 +633,7 @@ def allocate_per_req_cache(self, num_slots: int) -> dict[str, object]:

return {
"v4_unified_kv": unified_kv,
"v4_unified_kv_rope": unified_kv_rope,
"v4_csa_main_kv_state": csa_main_kv,
"v4_csa_main_score_state": csa_main_score,
"v4_csa_idx_kv_state": csa_idx_kv,
Expand Down Expand Up @@ -626,6 +680,16 @@ def build_kv_cache_tensor(self, layer_id: int, module):
module.swa_kv = unified[:swa_pages].view(
num_slots, self.win_with_spec, self.head_dim
)
module.kv_fp8 = self._kv_fp8
if self._kv_fp8:
rope = runner.v4_unified_kv_rope[module.layer_id]
module.unified_kv_rope = rope
module.swa_kv_rope = rope[:swa_pages].view(
num_slots, self.win_with_spec, self.rope_head_dim
)
else:
module.unified_kv_rope = None
module.swa_kv_rope = None
return None

if isinstance(module, _V4Indexer):
Expand Down Expand Up @@ -685,6 +749,10 @@ def build_kv_cache_tensor(self, layer_id: int, module):
storage_offset=scale_fp32_offset,
)
)
# Indexer-inner cache is always fp8 (independent of
# kv_cache_dtype); it has no separate rope pool.
module.write_mode = "indexer_fp8"
module.kv_cache_rope = None
elif ratio == 4:
pos = self.layer_id_to_csa_pos[layer_id_from_prefix]
module.kv_state = runner.v4_csa_main_kv_state[pos]
Expand All @@ -698,6 +766,15 @@ def build_kv_cache_tensor(self, layer_id: int, module):
module.kv_cache = unified[swa_pages:].view(
num_blocks, self.k1_csa, self.head_dim
)
if self._kv_fp8:
rope = runner.v4_unified_kv_rope[layer_id_from_prefix]
module.kv_cache_rope = rope[swa_pages:].view(
num_blocks, self.k1_csa, self.rope_head_dim
)
module.write_mode = "main_2buff_fp8"
else:
module.kv_cache_rope = None
module.write_mode = "bf16"
elif ratio == 128:
pos = self.layer_id_to_hca_pos[layer_id_from_prefix]
module.kv_state = runner.v4_hca_main_kv_state[pos]
Expand All @@ -707,6 +784,15 @@ def build_kv_cache_tensor(self, layer_id: int, module):
module.kv_cache = unified[swa_pages:].view(
num_blocks, self.k2_hca, self.head_dim
)
if self._kv_fp8:
rope = runner.v4_unified_kv_rope[layer_id_from_prefix]
module.kv_cache_rope = rope[swa_pages:].view(
num_blocks, self.k2_hca, self.rope_head_dim
)
module.write_mode = "main_2buff_fp8"
else:
module.kv_cache_rope = None
module.write_mode = "bf16"
else:
raise ValueError(
f"Unknown V4 compress_ratio={ratio} on Compressor at "
Expand All @@ -725,6 +811,11 @@ def get_kv_transfer_tensors(self):
runner = self.model_runner
if not hasattr(runner, "v4_unified_kv"):
return None
if self._kv_fp8:
# PD disaggregation with 2buff fp8 KV cache is not yet supported:
# the byte-region math below assumes a single bf16 unified pool and
# ignores the parallel rope pool. Disable KV transfer for fp8.
return None

num_slots = runner.max_per_req_cache_slots
swa_pages = num_slots * self.win_with_spec
Expand Down
9 changes: 8 additions & 1 deletion atom/model_ops/v4_kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,18 @@
qk_norm_rope_maybe_quant,
qk_norm_rope_maybe_quant_reference,
)
from atom.model_ops.v4_kernels.state_writes import update_compressor_states, swa_write
from atom.model_ops.v4_kernels.state_writes import (
qk_norm_rope_quant_2buff,
swa_write_2buff_prepacked,
update_compressor_states,
swa_write,
)

__all__ = [
"update_compressor_states",
"swa_write",
"swa_write_2buff_prepacked",
"qk_norm_rope_quant_2buff",
"fused_compress_attn",
"fused_compress_attn_reference",
"sparse_attn_v4_paged_decode",
Expand Down
22 changes: 19 additions & 3 deletions atom/model_ops/v4_kernels/fused_compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,13 @@ def fused_compress_attn(
] = None, # fp32 [NB, k_per_block]; required when quant=True
use_ue8m0: bool = True, # round scale to power-of-2 (UE8M0); only when quant=True
preshuffle: bool = True, # MFMA 16x16 preshuffled FP8 layout; only when quant=True
fp8_max: Optional[float] = None, # E4M3 max; required when quant=True
fp8_max: Optional[float] = None, # E4M3 max; required when quant or main_2buff_fp8
# V4-Main native fp8 2buff path (CSA/HCA Main under --kv_cache_dtype fp8).
# Distinct from `quant` (Indexer-inner per-row preshuffle): writes per-64-tile
# e8m0 nope-fp8 + inline dup-scale into `kv_cache` (fp8 [NB,k,512]) and bf16
# rope into `kv_cache_rope` (bf16 [NB,k,64]) via the flydsl group_fp8 scatter.
main_2buff_fp8: bool = False,
kv_cache_rope: Optional[torch.Tensor] = None, # bf16 [NB,k_per_block,64]
) -> None:
"""Batched fused per-source-position pool + RMSNorm + RoPE + cache scatter,
dispatched via SGLang-style packed plan.
Expand Down Expand Up @@ -462,6 +468,8 @@ def fused_compress_attn(
and _shape_key == (512, 64, 128, False)
)
if _hca_use:
# main_2buff_fp8: native group_fp8 2buff scatter (nope-fp8 into
# kv_cache, bf16 rope into k_rope_cache). Otherwise plain bf16.
flydsl_hca_compress_attn(
kv_in=kv_in,
score_in=score_in,
Expand All @@ -480,9 +488,15 @@ def fused_compress_attn(
ratio=ratio,
head_dim=head_dim,
rope_head_dim=rope_head_dim,
quant=main_2buff_fp8,
k_rope_cache=kv_cache_rope if main_2buff_fp8 else None,
quant_group_size=64,
)
return
if _flydsl_use:
# main_2buff_fp8: CSA Main native group_fp8 2buff (nope-fp8 + inline
# e8m0 into kv_cache, bf16 rope into k_rope_cache; scale carried inline
# so cache_scale stays None). Indexer-inner uses per_row_fp8 preshuffle.
flydsl_fused_compress_attn(
kv_in=kv_in,
score_in=score_in,
Expand All @@ -502,10 +516,12 @@ def fused_compress_attn(
ratio=ratio,
head_dim=head_dim,
rope_head_dim=rope_head_dim,
quant=quant,
quant=quant or main_2buff_fp8,
cache_scale=cache_scale,
use_ue8m0=use_ue8m0,
preshuffle=preshuffle,
preshuffle=preshuffle and not main_2buff_fp8,
quant_mode="group_fp8" if main_2buff_fp8 else "per_row_fp8",
k_rope_cache=kv_cache_rope if main_2buff_fp8 else None,
)
return

Expand Down
Loading
Loading