From 8659b84bea3f1efc2465d3205887f70b8a4f58bf Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 26 Jun 2026 03:17:07 +0000 Subject: [PATCH 1/4] [ATOM SGL] MTP Spec decode --- atom/plugin/register.py | 110 ++++++ .../attention_backend/deepseek_v4_backend.py | 116 ++++-- .../full_attention/full_attention_backend.py | 373 +++++++++++++++++- .../attention_backend/sparse_mla_indexer.py | 72 ++++ atom/plugin/sglang/deepseek_v4_bridge.py | 114 ++++-- .../sglang/models/base_model_wrapper.py | 24 +- .../sglang/models/deepseek_mla_attention.py | 66 ++++ .../sglang/models/deepseek_mla_forward.py | 11 +- atom/plugin/sglang/runtime/forward_context.py | 142 ++++++- 9 files changed, 953 insertions(+), 75 deletions(-) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 577d6588e8..074d50983f 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -99,11 +99,121 @@ def create_dsv4_backend(runner): return ATOMDeepseekV4BackendForSgl(runner) +def _patch_sglang_dsv4_draft_backends() -> None: + """Route SGLang's hard-coded DSV4 speculative factories to ATOM. + + DraftBackendFactory constructs DeepSeek-V4 draft backends directly instead + of going through the attention registry. SGLang's native backend asserts a + native DeepSeekV4TokenToKVPool, while ATOM plugin mode uses a proxy KV pool, + so patch the factory methods to return the ATOM shim. + """ + + try: + from sglang.srt.speculative.draft_utils import DraftBackendFactory + from atom.plugin.sglang.attention_backend.deepseek_v4_backend import ( + ATOMDeepseekV4BackendForSgl, + ) + except Exception as exc: + logger.debug("Skip patching SGLang DSV4 draft backends: %s", exc) + return + + if getattr(DraftBackendFactory, "_atom_dsv4_draft_backend_patched", False): + return + + def _create_atom_dsv4_decode_backend(self): + return ATOMDeepseekV4BackendForSgl( + self.draft_model_runner, + topk=self.topk, + speculative_num_steps=self.speculative_num_steps, + ) + + def _create_atom_dsv4_prefill_backend(self): + return ATOMDeepseekV4BackendForSgl( + self.draft_model_runner, + skip_prefill=False, + ) + + DraftBackendFactory._create_dsv4_decode_backend = _create_atom_dsv4_decode_backend + DraftBackendFactory._create_dsv4_prefill_backend = _create_atom_dsv4_prefill_backend + DraftBackendFactory._atom_dsv4_draft_backend_patched = True + logger.info("Patched SGLang DSV4 speculative draft backends to ATOM") + + +def _patch_sglang_dsv4_spec_cuda_graph() -> None: + """Avoid replaying generic SGLang graphs for DSV4 speculative extend modes. + + The target decode graph is still useful and remains enabled. Target verify + and draft-extend need DSV4-specific per-token metadata; until that metadata + is fully graph-safe, let those forwards run eager to avoid replaying a graph + captured with decode-shaped metadata. + """ + + try: + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + from sglang.srt.speculative.eagle_worker_v2 import EagleDraftWorker + except Exception as exc: + logger.debug("Skip patching SGLang DSV4 spec cuda graph: %s", exc) + return + + if getattr(CudaGraphRunner, "_atom_dsv4_spec_can_run_patched", False): + return + + original_can_run = CudaGraphRunner.can_run + + def can_run(self, forward_batch): + try: + model_runner = getattr(self, "model_runner", None) + hf_config = getattr(getattr(model_runner, "model_config", None), "hf_config", None) + arches = getattr(hf_config, "architectures", None) or [] + is_dsv4 = any("DeepseekV4" in str(arch) for arch in arches) + mode = getattr(forward_batch, "forward_mode", None) + is_spec_extend = ( + bool(getattr(mode, "is_target_verify", lambda: False)()) + or bool( + getattr(mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ) + ) + if is_dsv4 and is_spec_extend: + return False + except Exception: + pass + return original_can_run(self, forward_batch) + + CudaGraphRunner.can_run = can_run + CudaGraphRunner._atom_dsv4_spec_can_run_patched = True + + if not getattr(EagleDraftWorker, "_atom_dsv4_init_cuda_graphs_patched", False): + original_init_cuda_graphs = EagleDraftWorker.init_cuda_graphs + + def init_cuda_graphs(self): + try: + arch = ( + self.draft_runner.model_config.hf_config.architectures[0] + ) + if arch == "DeepseekV4ForCausalLMNextN": + self.cuda_graph_runner = None + self.cuda_graph_runner_for_draft_extend = None + logger.info("Skip DSV4 draft cuda graph capture in ATOM plugin") + return + except Exception: + pass + return original_init_cuda_graphs(self) + + EagleDraftWorker.init_cuda_graphs = init_cuda_graphs + EagleDraftWorker._atom_dsv4_init_cuda_graphs_patched = True + + logger.info("Patched SGLang DSV4 speculative cuda graph handling") + + def register_ops_to_sglang(atom_config: Config) -> None: """ Register custom ops to sglang, including attention """ _register_custom_attention_to_sglang() + _patch_sglang_dsv4_draft_backends() + _patch_sglang_dsv4_spec_cuda_graph() def set_attn_cls() -> None: diff --git a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py index 9bc772add7..098371b351 100644 --- a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py +++ b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py @@ -16,9 +16,10 @@ class ATOMDeepseekV4BackendForSgl(AttentionBackend): """ needs_cpu_seq_lens = True + _last_atom_v4_graph_metadata = None def __init__(self, model_runner, *args, **kwargs): - del args, kwargs + del args logger.info("Initializing ATOMDeepseekV4BackendForSgl") self.model_runner = model_runner self.device = torch.device(model_runner.device) @@ -26,6 +27,12 @@ def __init__(self, model_runner, *args, **kwargs): self.req_to_token_pool = model_runner.req_to_token_pool self.forward_metadata = None self.atom_v4_graph_metadata = None + speculative_num_steps = int(kwargs.pop("speculative_num_steps", 0) or 0) + # SGLang EAGLE multi-step draft code expects decode backends to expose + # one attention backend per draft step. ATOM DSV4 owns the real + # per-layer state in the model/bridge, so all draft steps can share this + # shim instance. + self.attn_backends = [self] * max(1, speculative_num_steps) @staticmethod def get_name() -> str: @@ -37,14 +44,17 @@ def init_forward_metadata(self, forward_batch): def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = False): self.forward_metadata = forward_batch + logger.info( + "ATOM DSV4 init_forward_metadata_out_graph: in_capture=%s mode=%s bs=%s", + in_capture, + getattr(getattr(forward_batch, "forward_mode", None), "name", None), + getattr(forward_batch, "batch_size", None), + ) if not (in_capture or hasattr(forward_batch, "actual_forward_mode")): self.atom_v4_graph_metadata = None return - if not forward_batch.forward_mode.is_decode_or_idle(): - self.atom_v4_graph_metadata = None - return - from atom.plugin.sglang.deepseek_v4_bridge import ( + build_atom_v4_attention_metadata_from_sglang, build_atom_v4_decode_graph_metadata_from_sglang, ) @@ -55,15 +65,36 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals positions = getattr(buffers, "positions", None) if positions is None: self.atom_v4_graph_metadata = None + logger.info( + "Skip ATOM DeepSeek-V4 graph metadata init: positions unavailable" + ) return atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) - self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang( - forward_batch, - positions, - proxy_pool=self.token_to_kv_pool, - req_to_token_pool=self.req_to_token_pool, - model=atom_model, + if forward_batch.forward_mode.is_decode_or_idle(): + self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) + else: + self.atom_v4_graph_metadata = build_atom_v4_attention_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + ) + forward_batch.atom_v4_graph_metadata = self.atom_v4_graph_metadata + ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = ( + self.atom_v4_graph_metadata + ) + logger.info( + "ATOM DSV4 graph metadata initialized: mode=%s bs=%s metadata=%s", + getattr(getattr(forward_batch, "forward_mode", None), "name", None), + getattr(forward_batch, "batch_size", None), + type(self.atom_v4_graph_metadata).__name__, ) def _init_decode_cuda_graph_metadata( @@ -114,18 +145,24 @@ def _init_decode_cuda_graph_metadata( req_to_token_pool=self.req_to_token_pool, model=atom_model, ) + forward_batch.atom_v4_graph_metadata = self.atom_v4_graph_metadata + ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = ( + self.atom_v4_graph_metadata + ) - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - num_tokens: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens, - forward_mode, - spec_info, - ): - del num_tokens, encoder_lens, spec_info + def init_forward_metadata_capture_cuda_graph(self, *args, **kwargs): + # New SGLang graph API passes a ForwardBatch. Older call sites pass + # unpacked fields. Support both because speculative draft graph code + # still calls this legacy-named hook directly. + if len(args) == 1 and not kwargs and hasattr(args[0], "forward_mode"): + return self.init_forward_metadata_out_graph(args[0], in_capture=True) + + bs = kwargs.get("bs", args[0] if len(args) > 0 else None) + req_pool_indices = kwargs.get( + "req_pool_indices", args[2] if len(args) > 2 else None + ) + seq_lens = kwargs.get("seq_lens", args[3] if len(args) > 3 else None) + forward_mode = kwargs.get("forward_mode", args[5] if len(args) > 5 else None) self._init_decode_cuda_graph_metadata( bs=bs, req_pool_indices=req_pool_indices, @@ -158,7 +195,40 @@ def init_forward_metadata_replay_cuda_graph( ) def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): - del max_bs, max_num_tokens + from sglang.srt.model_executor.forward_batch_info import ForwardMode + + from atom.plugin.sglang.deepseek_v4_bridge import ( + build_atom_v4_decode_graph_metadata_from_sglang, + ) + + bs = int(max_bs) + tokens_per_req = max(1, int(max_num_tokens) // max(1, bs)) + seq_lens = torch.full( + (bs,), tokens_per_req, dtype=torch.int32, device=self.device + ) + req_pool_indices = torch.arange(bs, dtype=torch.int64, device=self.device) + positions = torch.arange(tokens_per_req, dtype=torch.int64, device=self.device) + positions = positions.repeat(bs) + forward_batch = SimpleNamespace( + forward_mode=ForwardMode.DECODE, + actual_forward_mode=ForwardMode.DECODE, + batch_size=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.detach().cpu(), + out_cache_loc=None, + ) + atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) + self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) + ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = ( + self.atom_v4_graph_metadata + ) return None def get_cuda_graph_seq_len_fill_value(self): diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index 596d350d4c..dc80850b69 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -173,6 +173,93 @@ def __init__( self.prefill_ps_num_kv_splits = cu_num // math.gcd(self.num_kv_head, cu_num) else: self.prefill_ps_num_kv_splits = None + self._mtp_debug_counts: dict[tuple[int, int, str], int] = {} + + def _debug_mtp_tensor(self, name: str, tensor: Optional[torch.Tensor]) -> str: + if tensor is None: + return f"{name}=None" + flat = tensor.detach().flatten() + if flat.numel() == 0: + return f"{name}=empty shape={tuple(tensor.shape)} dtype={tensor.dtype}" + sample = flat[: min(6, flat.numel())].to(torch.float32).cpu().tolist() + stats_tensor = flat.to(torch.float32) + return ( + f"{name}=shape={tuple(tensor.shape)} dtype={tensor.dtype} " + f"min={float(stats_tensor.min().item()):.6g} " + f"max={float(stats_tensor.max().item()):.6g} " + f"mean={float(stats_tensor.mean().item()):.6g} " + f"sample={sample}" + ) + + def _debug_mtp_target_verify( + self, + tag: str, + layer: "RadixAttention", + forward_batch: ForwardBatch, + q: Optional[torch.Tensor] = None, + o: Optional[torch.Tensor] = None, + ) -> None: + if os.getenv("ATOM_DEBUG_MTP_VERIFY", "0") != "1": + return + rank = -1 + try: + from sglang.srt.distributed import get_tp_group + + rank = int(get_tp_group().rank_in_group) + except Exception: + rank = int(os.getenv("RANK", "-1")) + debug_ranks = { + int(x) + for x in os.getenv("ATOM_DEBUG_MTP_VERIFY_RANKS", "0").split(",") + if x.strip() + } + if rank not in debug_ranks: + return + bs = int(forward_batch.batch_size) + debug_bs = { + int(x) + for x in os.getenv("ATOM_DEBUG_MTP_VERIFY_BS", "63,64").split(",") + if x.strip() + } + if bs not in debug_bs: + return + layer_id = int(getattr(layer, "layer_id", -1)) + debug_layers = { + int(x) + for x in os.getenv("ATOM_DEBUG_MTP_VERIFY_LAYERS", "0,1,60").split(",") + if x.strip() + } + if layer_id not in debug_layers: + return + key = (bs, layer_id, tag) + max_hits = int(os.getenv("ATOM_DEBUG_MTP_VERIFY_MAX_HITS", "4")) + hits = self._mtp_debug_counts.get(key, 0) + if hits >= max_hits: + return + self._mtp_debug_counts[key] = hits + 1 + + md = self.forward_metadata + spec_info = getattr(forward_batch, "spec_info", None) + draft_num = getattr(spec_info, "draft_token_num", None) + pieces = [ + f"[ATOM_MTP_DEBUG] tag={tag} hit={hits + 1} rank={rank} layer={layer_id} bs={bs}", + f"mode={forward_batch.forward_mode}", + f"draft_num={draft_num}", + f"max_q_len={getattr(md, 'max_q_len', None)}", + f"num_kv_splits={getattr(md, 'num_kv_splits', None)}", + f"kv_indices_len={None if md.kv_indices is None else md.kv_indices.numel()}", + f"work_metadata_shape={None if md.work_metadata is None else tuple(md.work_metadata.shape)}", + self._debug_mtp_tensor("kv_indices", md.kv_indices[: min(64, md.kv_indices.numel())] if md.kv_indices is not None else None), + self._debug_mtp_tensor("seq_lens", forward_batch.seq_lens[:bs]), + self._debug_mtp_tensor("req_pool", forward_batch.req_pool_indices[:bs]), + self._debug_mtp_tensor("out_cache_loc", getattr(forward_batch, "out_cache_loc", None)), + self._debug_mtp_tensor("qo_indptr", md.qo_indptr[: bs + 1] if md.qo_indptr is not None else None), + self._debug_mtp_tensor("kv_indptr", md.kv_indptr[: bs + 1] if md.kv_indptr is not None else None), + self._debug_mtp_tensor("kv_last_page_len", md.kv_last_page_len[:bs] if md.kv_last_page_len is not None else None), + self._debug_mtp_tensor("q", q), + self._debug_mtp_tensor("o", o), + ] + print(" | ".join(pieces), flush=True) def _cuda_graph_mla_max_seqlen_qo(self) -> int: """Largest q length used by MLA CUDA graph speculative paths.""" @@ -1728,6 +1815,83 @@ def forward_extend( ) elif self.use_mla: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + if os.getenv("ATOM_DEBUG_DUMP_KV_WRITE", "0") == "1": + try: + from sglang.srt.distributed import get_tp_group + + rank = int(get_tp_group().rank_in_group) + except Exception: + rank = int(os.getenv("RANK", "-1")) + dump_ranks = { + int(x) + for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_RANKS", "0").split(",") + if x.strip() + } + dump_layers = { + int(x) + for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_LAYERS", "0").split(",") + if x.strip() + } + dump_bs = { + int(x) + for x in os.getenv( + "ATOM_DEBUG_DUMP_KV_WRITE_BS", + str(int(forward_batch.batch_size)), + ).split(",") + if x.strip() + } + if ( + rank in dump_ranks + and int(getattr(layer, "layer_id", -1)) in dump_layers + and int(forward_batch.batch_size) in dump_bs + ): + if not hasattr(self, "_debug_kv_write_counts"): + self._debug_kv_write_counts = {} + key = ( + rank, + int(getattr(layer, "layer_id", -1)), + int(forward_batch.batch_size), + ) + hits = self._debug_kv_write_counts.get(key, 0) + max_hits = int(os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_MAX_HITS", "2")) + if hits < max_hits: + self._debug_kv_write_counts[key] = hits + 1 + dump_dir = os.getenv( + "ATOM_DEBUG_DUMP_KV_WRITE_DIR", + "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/rootcause_20260620_kv_write", + ) + os.makedirs(dump_dir, exist_ok=True) + k_buffer = forward_batch.token_to_kv_pool.get_key_buffer( + layer.layer_id + ) + loc = cache_loc.detach().long() + dump_path = os.path.join( + dump_dir, + f"rank{rank}_layer{int(getattr(layer, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_hit{hits + 1}.pt", + ) + torch.save( + { + "rank": rank, + "layer": int(getattr(layer, "layer_id", -1)), + "batch_size": int(forward_batch.batch_size), + "forward_mode": str(forward_batch.forward_mode), + "cache_loc": cache_loc.detach().cpu(), + "positions": None + if getattr(forward_batch, "positions", None) is None + else forward_batch.positions.detach().cpu(), + "seq_lens": None + if getattr(forward_batch, "seq_lens", None) is None + else forward_batch.seq_lens.detach().cpu(), + "req_pool_indices": None + if getattr(forward_batch, "req_pool_indices", None) is None + else forward_batch.req_pool_indices.detach().cpu(), + "k_input": k.detach().cpu(), + "v_input": v.detach().cpu(), + "k_after_write": k_buffer[loc].detach().cpu(), + }, + dump_path, + ) + print(f"[ATOM_KV_WRITE_DUMP] path={dump_path}", flush=True) else: k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id @@ -1946,6 +2110,8 @@ def _try_fused_mxfp4_kv_b_proj_fp8( qk_nope_head_dim, ): """Return FP8 k/v from MXFP4 kv_b_proj when the fused split-cat path fits.""" + if os.getenv("ATOM_DEBUG_DISABLE_MXFP4_KVB_FUSED", "0") == "1": + return None if fused_gemm_afp4wfp4_preshuffle_split_cat is None: return None weight = getattr(layer.kv_b_proj, "weight", None) @@ -2288,7 +2454,212 @@ def _forward_extend_mla_speculative( (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), dtype=self.input_dtype, ) - self._call_mla_decode_fwd(q, K_Buffer, o, layer) + dump_enabled = os.getenv("ATOM_DUMP_MLA_VERIFY_REPRO", "0") == "1" + dump_hit = False + dump_path = None + if dump_enabled: + try: + from sglang.srt.distributed import get_tp_group + + rank = int(get_tp_group().rank_in_group) + except Exception: + rank = int(os.getenv("RANK", "-1")) + dump_ranks = { + int(x) + for x in os.getenv("ATOM_DUMP_MLA_VERIFY_RANKS", "0").split(",") + if x.strip() + } + dump_layers = { + int(x) + for x in os.getenv("ATOM_DUMP_MLA_VERIFY_LAYERS", "0").split(",") + if x.strip() + } + key = ( + int(forward_batch.batch_size), + int(getattr(layer, "layer_id", -1)), + "dump_mla_verify", + ) + if not hasattr(self, "_dump_mla_verify_counts"): + self._dump_mla_verify_counts = {} + hits = self._dump_mla_verify_counts.get(key, 0) + max_hits = int(os.getenv("ATOM_DUMP_MLA_VERIFY_MAX_HITS", "1")) + dump_bs = { + int(x) + for x in os.getenv("ATOM_DUMP_MLA_VERIFY_BS", str(int(forward_batch.batch_size))).split(",") + if x.strip() + } + dump_hit = ( + rank in dump_ranks + and int(forward_batch.batch_size) in dump_bs + and int(getattr(layer, "layer_id", -1)) in dump_layers + and hits < max_hits + ) + if dump_hit: + self._dump_mla_verify_counts[key] = hits + 1 + dump_dir = os.getenv( + "ATOM_DUMP_MLA_VERIFY_DIR", + "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/fixed_prompt_rootcause_20260618/mla_kernel_repro", + ) + os.makedirs(dump_dir, exist_ok=True) + draft_num = getattr(forward_batch.spec_info, "draft_token_num", -1) + dump_path = os.path.join( + dump_dir, + f"rank{rank}_layer{int(getattr(layer, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_draft{int(draft_num)}_hit{hits + 1}.pt", + ) + md = self.forward_metadata + compact_k = K_Buffer[md.kv_indices.long()].contiguous() + unique_kv_indices = torch.unique(md.kv_indices.long()) + raw_k_slots = K_Buffer[unique_kv_indices].contiguous() + torch.save( + { + "q": q.detach().cpu(), + "k_compact": compact_k.detach().cpu(), + "unique_kv_indices": unique_kv_indices.detach().cpu(), + "raw_k_slots": raw_k_slots.detach().cpu(), + "qo_indptr": md.qo_indptr.detach().cpu(), + "kv_indptr": md.kv_indptr.detach().cpu(), + "compact_kv_indices": torch.arange( + compact_k.shape[0], dtype=torch.int32 + ), + "kv_indices_original": None + if md.kv_indices is None + else md.kv_indices.detach().cpu(), + "kv_last_page_len": md.kv_last_page_len.detach().cpu(), + "seq_lens": None + if getattr(forward_batch, "seq_lens", None) is None + else forward_batch.seq_lens.detach().cpu(), + "req_pool_indices": None + if getattr(forward_batch, "req_pool_indices", None) is None + else forward_batch.req_pool_indices.detach().cpu(), + "out_cache_loc": None + if getattr(forward_batch, "out_cache_loc", None) is None + else forward_batch.out_cache_loc.detach().cpu(), + "positions": None + if getattr(forward_batch, "positions", None) is None + else forward_batch.positions.detach().cpu(), + "max_q_len": int(md.max_q_len), + "num_kv_splits": int(md.num_kv_splits or 0), + "q_scale": None + if getattr(layer, "k_scale", None) is None + else layer.k_scale.detach().cpu(), + "kv_scale": None + if getattr(layer, "k_scale", None) is None + else layer.k_scale.detach().cpu(), + "tp_q_head_num": int(layer.tp_q_head_num), + "qk_head_dim": int(layer.qk_head_dim), + "v_head_dim": int(layer.v_head_dim), + "scaling": float(layer.scaling), + "logit_cap": float(layer.logit_cap), + "batch_size": int(forward_batch.batch_size), + "draft_num": int(draft_num), + }, + dump_path, + ) + print(f"[ATOM_MLA_REPRO_DUMP] before path={dump_path}", flush=True) + self._debug_mtp_target_verify("before_mla_decode", layer, forward_batch, q=q) + if ( + os.getenv("ATOM_DEBUG_SPLIT_TARGET_VERIFY_Q2", "0") == "1" + and int(getattr(md, "max_q_len", 0)) == 4 + and int(forward_batch.batch_size) > 0 + ): + bs = int(forward_batch.batch_size) + idx_first = torch.tensor( + [r * 4 + j for r in range(bs) for j in (0, 1)], + device=q.device, + dtype=torch.long, + ) + idx_second = torch.tensor( + [r * 4 + j for r in range(bs) for j in (2, 3)], + device=q.device, + dtype=torch.long, + ) + qo2 = torch.arange(0, (bs + 1) * 2, 2, dtype=torch.int32, device=q.device) + ( + work_metadata2, + work_indptr2, + work_info_set2, + reduce_indptr2, + reduce_final_map2, + reduce_partial_map2, + ) = self.make_mla_decode_meta_data_buffer(2, bs) + num_kv_splits2 = self.max_split_per_batch + self.make_mla_meta_data( + qo2, + md.kv_indptr, + md.kv_last_page_len, + work_metadata2, + work_info_set2, + work_indptr2, + reduce_indptr2, + reduce_final_map2, + reduce_partial_map2, + 2, + fast_mode=_sglang_aiter.fast_mode, + max_split_per_batch=num_kv_splits2, + intra_batch_mode=_sglang_aiter.intra_batch_mode, + ) + old_md = self.forward_metadata + try: + self.forward_metadata = ForwardMetadata( + md.kv_indptr, + md.kv_indices, + qo2, + md.kv_last_page_len, + 2, + None, + None, + None, + work_metadata=work_metadata2, + work_info_set=work_info_set2, + work_indptr=work_indptr2, + reduce_indptr=reduce_indptr2, + reduce_final_map=reduce_final_map2, + reduce_partial_map=reduce_partial_map2, + num_kv_splits=num_kv_splits2, + run_graph=False, + ) + o_first = o.new_empty((bs * 2, layer.tp_q_head_num, layer.v_head_dim)) + o_second = torch.empty_like(o_first) + self._call_mla_decode_fwd(q[idx_first].contiguous(), K_Buffer, o_first, layer) + self._call_mla_decode_fwd(q[idx_second].contiguous(), K_Buffer, o_second, layer) + o[idx_first] = o_first + o[idx_second] = o_second + finally: + self.forward_metadata = old_md + print( + f"[ATOM_MTP_DEBUG] tag=split_target_verify_q2 bs={bs} layer={int(getattr(layer, 'layer_id', -1))}", + flush=True, + ) + elif ( + os.getenv("ATOM_DEBUG_TARGET_VERIFY_PREFILL_FWD", "0") == "1" + and int(getattr(md, "max_q_len", 0)) == 4 + ): + o_prefill = self._extend_mla_absorbed_prefix( + q, + layer, + K_Buffer, + md.kv_indptr, + md.kv_indices, + md.qo_indptr, + ) + if o_prefill.ndim == 2: + o_prefill = o_prefill.view(-1, layer.tp_q_head_num, layer.v_head_dim) + o.copy_(o_prefill) + print( + f"[ATOM_MTP_DEBUG] tag=target_verify_prefill_fwd bs={int(forward_batch.batch_size)} " + f"layer={int(getattr(layer, 'layer_id', -1))}", + flush=True, + ) + else: + self._call_mla_decode_fwd(q, K_Buffer, o, layer) + if dump_hit and dump_path is not None: + saved = torch.load(dump_path, map_location="cpu") + saved["o"] = o.detach().cpu() + torch.save(saved, dump_path) + print(f"[ATOM_MLA_REPRO_DUMP] after path={dump_path}", flush=True) + self._debug_mtp_target_verify( + "after_mla_decode", layer, forward_batch, q=q, o=o + ) return o if forward_batch.forward_mode.is_draft_extend(include_v2=True): diff --git a/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py b/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py index ee9e46565d..a0c29b7431 100644 --- a/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py +++ b/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py @@ -6,6 +6,7 @@ from __future__ import annotations import re +import os from typing import Optional import torch @@ -239,6 +240,77 @@ def forward_sparse_mla_for_sglang( forward_batch.token_to_kv_pool.set_kv_buffer( layer, forward_batch.out_cache_loc, k, v ) + if os.getenv("ATOM_DEBUG_DUMP_KV_WRITE", "0") == "1": + try: + from sglang.srt.distributed import get_tp_group + + rank = int(get_tp_group().rank_in_group) + except Exception: + rank = int(os.getenv("RANK", "-1")) + dump_ranks = { + int(x) + for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_RANKS", "0").split(",") + if x.strip() + } + dump_layers = { + int(x) + for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_LAYERS", "0").split(",") + if x.strip() + } + dump_bs = { + int(x) + for x in os.getenv( + "ATOM_DEBUG_DUMP_KV_WRITE_BS", + str(int(forward_batch.batch_size)), + ).split(",") + if x.strip() + } + if ( + rank in dump_ranks + and int(getattr(layer, "layer_id", -1)) in dump_layers + and int(forward_batch.batch_size) in dump_bs + ): + if not hasattr(forward_sparse_mla_for_sglang, "_debug_kv_write_counts"): + forward_sparse_mla_for_sglang._debug_kv_write_counts = {} + key = (rank, int(getattr(layer, "layer_id", -1)), int(forward_batch.batch_size)) + hits = forward_sparse_mla_for_sglang._debug_kv_write_counts.get(key, 0) + max_hits = int(os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_MAX_HITS", "2")) + if hits < max_hits: + forward_sparse_mla_for_sglang._debug_kv_write_counts[key] = hits + 1 + dump_dir = os.getenv( + "ATOM_DEBUG_DUMP_KV_WRITE_DIR", + "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/rootcause_20260620_kv_write", + ) + os.makedirs(dump_dir, exist_ok=True) + k_buffer_after = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + loc = forward_batch.out_cache_loc.detach().long() + dump_path = os.path.join( + dump_dir, + f"sparse_rank{rank}_layer{int(getattr(layer, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_hit{hits + 1}.pt", + ) + torch.save( + { + "rank": rank, + "layer": int(getattr(layer, "layer_id", -1)), + "batch_size": int(forward_batch.batch_size), + "forward_mode": str(forward_batch.forward_mode), + "cache_loc": forward_batch.out_cache_loc.detach().cpu(), + "positions": None + if getattr(forward_batch, "positions", None) is None + else forward_batch.positions.detach().cpu(), + "seq_lens": None + if getattr(forward_batch, "seq_lens", None) is None + else forward_batch.seq_lens.detach().cpu(), + "req_pool_indices": None + if getattr(forward_batch, "req_pool_indices", None) is None + else forward_batch.req_pool_indices.detach().cpu(), + "k_input": k.detach().cpu(), + "v_input": v.detach().cpu(), + "k_after_write": k_buffer_after[loc].detach().cpu(), + }, + dump_path, + ) + print(f"[ATOM_KV_WRITE_DUMP] path={dump_path}", flush=True) q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) num_tokens = q.shape[0] diff --git a/atom/plugin/sglang/deepseek_v4_bridge.py b/atom/plugin/sglang/deepseek_v4_bridge.py index 2f3c4da3e8..3905308e9a 100644 --- a/atom/plugin/sglang/deepseek_v4_bridge.py +++ b/atom/plugin/sglang/deepseek_v4_bridge.py @@ -335,6 +335,19 @@ def _bind_compressor_state( compressor.cache_scale = None +def _iter_deepseek_v4_cache_blocks(model): + inner = getattr(model, "model", None) + if inner is None: + return [] + layers = getattr(inner, "layers", None) + if layers is not None: + return list(layers) + mtp = getattr(inner, "mtp", None) + if mtp is not None: + return list(mtp) + return [] + + def bind_deepseek_v4_proxy_cache_views(model, proxy_pool: Any) -> bool: if not getattr(proxy_pool, "is_atom_v4_proxy_pool", False): return False @@ -344,7 +357,7 @@ def bind_deepseek_v4_proxy_cache_views(model, proxy_pool: Any) -> bool: csa_i = 0 hca_i = 0 - for local_layer_id, block in enumerate(model.model.layers): + for local_layer_id, block in enumerate(_iter_deepseek_v4_cache_blocks(model)): attn = block.attn ratio = int(attn.compress_ratio) attn.unified_kv = proxy_pool.views["unified"][local_layer_id] @@ -572,7 +585,9 @@ def _get_seq_lens_cpu(forward_batch) -> np.ndarray: seq_lens_cpu = getattr(forward_batch, "seq_lens_cpu", None) if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.detach().cpu() - return seq_lens_cpu.numpy().astype(np.int32) + if torch.is_tensor(seq_lens_cpu): + seq_lens_cpu = seq_lens_cpu.detach().cpu().numpy() + return np.asarray(seq_lens_cpu, dtype=np.int32) def _build_block_tables( @@ -609,6 +624,7 @@ def build_atom_v4_decode_graph_metadata_from_sglang( ) is_idle = bool(getattr(actual_mode, "is_idle", lambda: False)()) out_cache_loc = getattr(forward_batch, "out_cache_loc", None) + positions_numel = int(positions.numel()) scheduled_bs = ( 0 if is_idle @@ -618,15 +634,20 @@ def build_atom_v4_decode_graph_metadata_from_sglang( else bs ) ) - total = scheduled_bs + total = max(scheduled_bs, positions_numel) t_pad = bs max_blocks = max(1, proxy_pool.num_blocks) bufs = getattr(proxy_pool, "_atom_v4_decode_graph_buffers", None) - if bufs is None or bufs.num_slots < bs or bufs.max_blocks < max_blocks: + if ( + bufs is None + or bufs.num_slots < bs + or bufs.max_blocks < max_blocks + or bufs.max_decode_tokens < total + ): bufs = proxy_pool._atom_v4_decode_graph_buffers = _V4SGLangDecodeGraphBuffers( num_slots=proxy_pool.num_slots, - max_decode_tokens=max(proxy_pool.num_slots, bs), + max_decode_tokens=max(proxy_pool.num_slots, bs, total), window=proxy_pool.window_size, index_topk=1024, max_committed_hca=max_blocks, @@ -667,12 +688,18 @@ def build_atom_v4_decode_graph_metadata_from_sglang( md.swa_pages = proxy_pool.num_slots * proxy_pool.window_size if total: - pos_np = (seq_np[:total] - 1).astype(np.int32) - batch_np = np.arange(total, dtype=np.int32) + if positions_numel > scheduled_bs: + pos_np = positions[:total].detach().cpu().numpy().astype(np.int32) + repeats = max(1, total // max(1, bs)) + batch_np = np.repeat(np.arange(bs, dtype=np.int64), repeats)[:total] + else: + pos_np = (seq_np[:total] - 1).astype(np.int32) + batch_np = np.arange(total, dtype=np.int64) else: pos_np = np.zeros(0, dtype=np.int32) - batch_np = np.zeros(0, dtype=np.int32) - batch_pad = np.full(t_pad, -1, dtype=np.int32) + batch_np = np.zeros(0, dtype=np.int64) + t_pad = max(t_pad, total) + batch_pad = np.full(t_pad, -1, dtype=np.int64) if total: batch_pad[:total] = batch_np @@ -684,11 +711,13 @@ def build_atom_v4_decode_graph_metadata_from_sglang( slot_arr = np.zeros(bs, dtype=np.int32) reset_slots: set[int] = set() - if total: - first_blocks = block_tables[:total, 0].detach().cpu().numpy().astype(np.int32) - fresh_mask = pos_np == 0 + if scheduled_bs: + first_blocks = ( + block_tables[:scheduled_bs, 0].detach().cpu().numpy().astype(np.int32) + ) + fresh_mask = seq_np[:scheduled_bs] <= 1 slot_real, reset_slots = allocator.assign(first_blocks, fresh_mask) - slot_arr[:total] = slot_real + slot_arr[:scheduled_bs] = slot_real if reset_slots and model is not None: reset_deepseek_v4_state_slots(model, reset_slots) @@ -717,9 +746,9 @@ def build_atom_v4_decode_graph_metadata_from_sglang( if total: actual_swa = np.minimum(pos_np + 1, win).astype(np.int32) csa_valid = np.minimum( - np.minimum((pos_np + 1) // 4, n_csa[:total]), index_topk + np.minimum((pos_np + 1) // 4, n_csa[batch_np]), index_topk ).astype(np.int32) - hca_valid = n_hca[:total].astype(np.int32) + hca_valid = n_hca[batch_np].astype(np.int32) else: actual_swa = csa_valid = hca_valid = np.zeros(0, dtype=np.int32) @@ -770,8 +799,30 @@ def indptr(counts): md.kv_indptr_swa = swa_indptr md.kv_indptr_csa = csa_indptr md.kv_indptr_hca = hca_indptr + cu_committed_cpu = np.concatenate( + [ + np.zeros(1, dtype=np.int32), + np.cumsum(md.n_committed_csa_per_seq_cpu, dtype=np.int32), + ] + ) + cu_committed_cpu[-1] = max(int(cu_committed_cpu[-1]), 1) + cu_committed_gpu = torch.from_numpy(cu_committed_cpu).to( + device=device, dtype=torch.int32 + ) + safe_batch_id = md.batch_id_per_token.clamp_min(0) + seq_base = cu_committed_gpu[safe_batch_id].to(torch.int32) + visible_end = seq_base + torch.minimum( + (positions_gpu.to(torch.int32) + 1) // 4, + md.n_committed_csa_per_seq[safe_batch_id], + ).to(torch.int32) md.indexer_meta = { + "total_committed": int(cu_committed_cpu[-1]), + "cu_committed_gpu": cu_committed_gpu, "n_committed_per_seq_gpu": md.n_committed_csa_per_seq, + "batch_id_per_token_gpu": md.batch_id_per_token, + "seq_base_per_token_gpu": seq_base, + "cu_starts_gpu": seq_base, + "cu_ends_gpu": visible_end, } return md @@ -803,15 +854,30 @@ def build_atom_v4_attention_metadata_from_sglang( if extend_lens_t is not None: extend_lens = extend_lens_t.detach().cpu().numpy().astype(np.int32) else: - extend_lens = np.diff( - torch.nn.functional.pad( - forward_batch.extend_start_loc, (0, 1), value=positions.numel() + extend_start_loc = getattr(forward_batch, "extend_start_loc", None) + if extend_start_loc is not None: + extend_lens = np.diff( + torch.nn.functional.pad( + extend_start_loc, (0, 1), value=positions.numel() + ) + .detach() + .cpu() + .numpy() + .astype(np.int32) + ) + else: + tokens_per_req = getattr( + getattr(forward_batch, "spec_info", None), + "num_tokens_per_req", + None, + ) + if tokens_per_req is None: + tokens_per_req = max(1, int(positions.numel()) // num_reqs) + extend_lens = np.full( + num_reqs, + int(tokens_per_req), + dtype=np.int32, ) - .detach() - .cpu() - .numpy() - .astype(np.int32) - ) else: extend_lens = np.asarray(extend_lens, dtype=np.int32) lens = extend_lens[:num_reqs].astype(np.int32) @@ -1113,7 +1179,7 @@ def reset_deepseek_v4_state_slots(model, slots) -> None: if not slots: return idx = None - for block in getattr(model.model, "layers", []): + for block in _iter_deepseek_v4_cache_blocks(model): attn = getattr(block, "attn", None) swa = getattr(attn, "swa_kv", None) if isinstance(swa, torch.Tensor): diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index c3251bd650..417b8be71b 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -147,6 +147,9 @@ def get_embed_and_head(self): if hasattr(self.model, "get_embed_and_head"): return self.model.get_embed_and_head() + if self.model_arch == "DeepseekV4ForCausalLM": + return self.model.model.embed.weight, self.model.model.head.weight + embed_owner = ( self.model.model if hasattr(self.model, "model") @@ -269,20 +272,13 @@ def forward( hidden_states = runtime.trim_output(hidden_states) if self.pp_group.is_last_rank: - if self.model_arch == "DeepseekV4ForCausalLM" and not getattr( - forward_batch, "return_logprob", False - ): - if forward_batch.forward_mode.is_decode_or_idle(): - pruned_states = hidden_states - elif forward_batch.forward_mode.is_extend(): - last_index = ( - torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1 - ) - pruned_states = hidden_states[last_index] - else: - pruned_states = hidden_states - return LogitsProcessorOutput( - next_token_logits=self.model.compute_logits(pruned_states) + if self.model_arch == "DeepseekV4ForCausalLM": + return self.logits_processor( + input_ids, + hidden_states, + self.logits_head, + forward_batch, + hidden_states_before_norm=hidden_states, ) return self.logits_processor( input_ids, diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index e9ffe79f2e..3b85d375f8 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -10,6 +10,7 @@ from __future__ import annotations +import os from typing import TYPE_CHECKING, Any import torch @@ -228,6 +229,71 @@ def _forward_absorbed( is_neox=attn.rotary_emb.is_neox_style, is_nope_first=True, ) + if os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE", "0") == "1": + try: + from sglang.srt.distributed import get_tp_group + + rank = int(get_tp_group().rank_in_group) + except Exception: + rank = int(os.getenv("RANK", "-1")) + ranks = { + int(x) + for x in os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE_RANKS", "0").split(",") + if x.strip() + } + layers = { + int(x) + for x in os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE_LAYERS", "0").split(",") + if x.strip() + } + bss = { + int(x) + for x in os.getenv( + "ATOM_DEBUG_DUMP_FUSED_QK_CACHE_BS", + str(int(forward_batch.batch_size)), + ).split(",") + if x.strip() + } + if ( + rank in ranks + and int(getattr(mla_attn, "layer_id", -1)) in layers + and int(forward_batch.batch_size) in bss + ): + if not hasattr(self, "_debug_fused_qk_cache_counts"): + self._debug_fused_qk_cache_counts = {} + key = (rank, int(getattr(mla_attn, "layer_id", -1)), int(forward_batch.batch_size)) + hits = self._debug_fused_qk_cache_counts.get(key, 0) + max_hits = int(os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE_MAX_HITS", "2")) + if hits < max_hits: + self._debug_fused_qk_cache_counts[key] = hits + 1 + dump_dir = os.getenv( + "ATOM_DEBUG_DUMP_FUSED_QK_CACHE_DIR", + "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/rootcause_20260620_fused_qk_cache", + ) + os.makedirs(dump_dir, exist_ok=True) + loc = forward_batch.out_cache_loc.detach().long() + dump_path = os.path.join( + dump_dir, + f"rank{rank}_layer{int(getattr(mla_attn, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_hit{hits + 1}.pt", + ) + torch.save( + { + "rank": rank, + "layer": int(getattr(mla_attn, "layer_id", -1)), + "batch_size": int(forward_batch.batch_size), + "mode": str(forward_batch.forward_mode), + "out_cache_loc": forward_batch.out_cache_loc.detach().cpu(), + "positions": positions.detach().cpu(), + "q_nope_out": q_nope_out.detach().cpu(), + "q_pe": q_pe.detach().cpu(), + "k_nope": k_nope.detach().cpu(), + "k_pe": k_pe.detach().cpu(), + "q": q.detach().cpu(), + "kv_cache_after": kv_cache[loc].detach().cpu(), + }, + dump_path, + ) + print(f"[ATOM_FUSED_QK_CACHE_DUMP] path={dump_path}", flush=True) k = None v = None save_kv_cache = False diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py index c235230a4d..d880c107cb 100644 --- a/atom/plugin/sglang/models/deepseek_mla_forward.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -12,6 +12,7 @@ from __future__ import annotations import logging +import os from typing import TYPE_CHECKING, Any, Optional import torch @@ -269,7 +270,15 @@ def init_sgl_attrs( attn.use_deep_gemm_bmm = False attn.alt_stream = None attn.kv_cache_dtype = kv_cache_dtype - attn.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 + attn.use_fused_qk_rope_concat_and_cache_mla = ( + _use_aiter_gfx95 + and os.getenv("ATOM_DEBUG_DISABLE_FUSED_QK_ROPE_CACHE_MLA", "0") != "1" + ) + if os.getenv("ATOM_DEBUG_LOG_FUSED_QK_FLAG", "0") == "1": + print( + f"[ATOM_DEBUG_FUSED_QK_FLAG] use_fused_qk_rope_concat_and_cache_mla={attn.use_fused_qk_rope_concat_and_cache_mla} kv_cache_dtype={kv_cache_dtype}", + flush=True, + ) attn.current_sgl_plugin_attn_path = None attn.w_kc, attn.w_vc = None, None attn.w_scale = None diff --git a/atom/plugin/sglang/runtime/forward_context.py b/atom/plugin/sglang/runtime/forward_context.py index 998f1746cd..520672e3d9 100644 --- a/atom/plugin/sglang/runtime/forward_context.py +++ b/atom/plugin/sglang/runtime/forward_context.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import logging from contextlib import ExitStack from dataclasses import dataclass, field from typing import Any, Optional @@ -12,6 +13,8 @@ from atom.plugin.sglang.runtime.context import bind_current_forward_batch +logger = logging.getLogger("atom.plugin.sglang.runtime.forward_context") + def _is_dummy_forward(forward_batch: ForwardBatch) -> bool: """Return whether an SGLang batch represents an empty/idle dummy run.""" @@ -122,6 +125,78 @@ def _resolve_num_tokens_across_dp( return num_tokens_across_dp +def _slice_v4_graph_metadata_for_capture(attn_metadata: Any, *, num_tokens: int, bs: int): + """Narrow reusable V4 graph metadata to this capture bucket. + + The DSV4 fallback metadata is initialized at max graph size. SGLang then + captures smaller buckets (e.g. bs=248, tokens=496), so per-token arrays must + be narrowed before model code reads them. + """ + + if attn_metadata is None: + return None + + md = copy.copy(attn_metadata) + + def _slice_attr(name: str, n: int): + value = getattr(md, name, None) + if torch.is_tensor(value): + setattr(md, name, value[:n]) + elif value is not None: + try: + setattr(md, name, value[:n]) + except Exception: + pass + + for name in ( + "batch_id_per_token", + "batch_id_per_token_cpu", + "kv_indices_swa", + "kv_indices_csa", + "kv_indices_hca", + "kv_indices_extend", + "kv_indices_prefix_swa", + "kv_indices_prefix_csa", + "kv_indices_prefix_hca", + "skip_prefix_len_csa", + ): + _slice_attr(name, num_tokens) + + for name in ( + "state_slot_mapping", + "state_slot_mapping_cpu", + "n_committed_csa_per_seq", + "n_committed_csa_per_seq_cpu", + "n_committed_hca_per_seq", + "n_committed_hca_per_seq_cpu", + "context_lens", + ): + _slice_attr(name, bs) + + block_tables = getattr(md, "block_tables", None) + if torch.is_tensor(block_tables): + md.block_tables = block_tables[:bs] + + indexer_meta = getattr(md, "indexer_meta", None) + if isinstance(indexer_meta, dict): + indexer_meta = dict(indexer_meta) + for key in ( + "batch_id_per_token_gpu", + "seq_base_per_token_gpu", + "cu_starts_gpu", + "cu_ends_gpu", + ): + value = indexer_meta.get(key) + if torch.is_tensor(value): + indexer_meta[key] = value[:num_tokens] + value = indexer_meta.get("n_committed_per_seq_gpu") + if torch.is_tensor(value): + indexer_meta["n_committed_per_seq_gpu"] = value[:bs] + md.indexer_meta = indexer_meta + + return md + + def _set_atom_forward_context( atom_config: Any, forward_batch: ForwardBatch, @@ -140,33 +215,76 @@ def _set_atom_forward_context( max_seqlen_q = 1 if forward_mode.is_decode_or_idle() else 0 attn_metadata = None try: + attn_metadata = getattr(forward_batch, "atom_v4_graph_metadata", None) from atom.plugin.sglang.deepseek_v4_bridge import ( build_atom_v4_attention_metadata_from_sglang, maybe_get_proxy_pool_from_sglang_backend, ) - try: - from sglang.srt.model_executor.forward_context import get_attn_backend + if attn_metadata is None: + try: + from sglang.srt.model_executor.forward_context import get_attn_backend - backend = get_attn_backend() - attn_metadata = getattr(backend, "atom_v4_graph_metadata", None) - except Exception: - attn_metadata = None + backend = get_attn_backend() + attn_metadata = getattr(backend, "atom_v4_graph_metadata", None) + except Exception: + attn_metadata = None if attn_metadata is None: backend = getattr(forward_batch, "attn_backend", None) attn_metadata = getattr(backend, "atom_v4_graph_metadata", None) + if attn_metadata is None and backend is not None: + backend_forward_batch = getattr(backend, "forward_metadata", None) + attn_metadata = getattr( + backend_forward_batch, "atom_v4_graph_metadata", None + ) + proxy_pool, req_to_token_pool = maybe_get_proxy_pool_from_sglang_backend() + try: + is_capture_batch = bool(torch.cuda.is_current_stream_capturing()) + except Exception: + is_capture_batch = False + if attn_metadata is None and is_capture_batch: + try: + from atom.plugin.sglang.attention_backend.deepseek_v4_backend import ( + ATOMDeepseekV4BackendForSgl, + ) + + attn_metadata = ( + ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata + ) + if attn_metadata is not None: + attn_metadata = _slice_v4_graph_metadata_for_capture( + attn_metadata, + num_tokens=int(positions.shape[0]), + bs=int(forward_batch.batch_size), + ) + except Exception: + attn_metadata = None if attn_metadata is None and getattr( proxy_pool, "is_atom_v4_proxy_pool", False ): - attn_metadata = build_atom_v4_attention_metadata_from_sglang( - forward_batch, - positions, - proxy_pool=proxy_pool, - req_to_token_pool=req_to_token_pool, - ) + if is_capture_batch: + logger.info( + "ATOM DSV4 capture metadata missing: backend=%s " + "fb_has=%s class_has=%s", + type(backend).__name__ if backend is not None else None, + hasattr(forward_batch, "atom_v4_graph_metadata"), + "ATOMDeepseekV4BackendForSgl" in locals() + and ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata + is not None, + ) + raise RuntimeError( + "ATOM DeepSeek-V4 CUDA graph metadata was not initialized before capture" + ) + else: + attn_metadata = build_atom_v4_attention_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=proxy_pool, + req_to_token_pool=req_to_token_pool, + ) except Exception as exc: raise RuntimeError( "Failed to build ATOM DeepSeek-V4 metadata for SGLang" From 4ceeda7251808c15955f9d2ea3123cfb01e0d7e4 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 26 Jun 2026 03:23:06 +0000 Subject: [PATCH 2/4] remove debug print --- .../attention_backend/deepseek_v4_backend.py | 15 - .../full_attention/full_attention_backend.py | 373 +----------------- .../attention_backend/sparse_mla_indexer.py | 72 ---- .../sglang/models/deepseek_mla_attention.py | 66 ---- .../sglang/models/deepseek_mla_forward.py | 11 +- atom/plugin/sglang/runtime/forward_context.py | 9 - 6 files changed, 2 insertions(+), 544 deletions(-) diff --git a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py index 098371b351..273dfe61a4 100644 --- a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py +++ b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py @@ -44,12 +44,6 @@ def init_forward_metadata(self, forward_batch): def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = False): self.forward_metadata = forward_batch - logger.info( - "ATOM DSV4 init_forward_metadata_out_graph: in_capture=%s mode=%s bs=%s", - in_capture, - getattr(getattr(forward_batch, "forward_mode", None), "name", None), - getattr(forward_batch, "batch_size", None), - ) if not (in_capture or hasattr(forward_batch, "actual_forward_mode")): self.atom_v4_graph_metadata = None return @@ -65,9 +59,6 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals positions = getattr(buffers, "positions", None) if positions is None: self.atom_v4_graph_metadata = None - logger.info( - "Skip ATOM DeepSeek-V4 graph metadata init: positions unavailable" - ) return atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) @@ -90,12 +81,6 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = ( self.atom_v4_graph_metadata ) - logger.info( - "ATOM DSV4 graph metadata initialized: mode=%s bs=%s metadata=%s", - getattr(getattr(forward_batch, "forward_mode", None), "name", None), - getattr(forward_batch, "batch_size", None), - type(self.atom_v4_graph_metadata).__name__, - ) def _init_decode_cuda_graph_metadata( self, diff --git a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py index dc80850b69..596d350d4c 100644 --- a/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py +++ b/atom/plugin/sglang/attention_backend/full_attention/full_attention_backend.py @@ -173,93 +173,6 @@ def __init__( self.prefill_ps_num_kv_splits = cu_num // math.gcd(self.num_kv_head, cu_num) else: self.prefill_ps_num_kv_splits = None - self._mtp_debug_counts: dict[tuple[int, int, str], int] = {} - - def _debug_mtp_tensor(self, name: str, tensor: Optional[torch.Tensor]) -> str: - if tensor is None: - return f"{name}=None" - flat = tensor.detach().flatten() - if flat.numel() == 0: - return f"{name}=empty shape={tuple(tensor.shape)} dtype={tensor.dtype}" - sample = flat[: min(6, flat.numel())].to(torch.float32).cpu().tolist() - stats_tensor = flat.to(torch.float32) - return ( - f"{name}=shape={tuple(tensor.shape)} dtype={tensor.dtype} " - f"min={float(stats_tensor.min().item()):.6g} " - f"max={float(stats_tensor.max().item()):.6g} " - f"mean={float(stats_tensor.mean().item()):.6g} " - f"sample={sample}" - ) - - def _debug_mtp_target_verify( - self, - tag: str, - layer: "RadixAttention", - forward_batch: ForwardBatch, - q: Optional[torch.Tensor] = None, - o: Optional[torch.Tensor] = None, - ) -> None: - if os.getenv("ATOM_DEBUG_MTP_VERIFY", "0") != "1": - return - rank = -1 - try: - from sglang.srt.distributed import get_tp_group - - rank = int(get_tp_group().rank_in_group) - except Exception: - rank = int(os.getenv("RANK", "-1")) - debug_ranks = { - int(x) - for x in os.getenv("ATOM_DEBUG_MTP_VERIFY_RANKS", "0").split(",") - if x.strip() - } - if rank not in debug_ranks: - return - bs = int(forward_batch.batch_size) - debug_bs = { - int(x) - for x in os.getenv("ATOM_DEBUG_MTP_VERIFY_BS", "63,64").split(",") - if x.strip() - } - if bs not in debug_bs: - return - layer_id = int(getattr(layer, "layer_id", -1)) - debug_layers = { - int(x) - for x in os.getenv("ATOM_DEBUG_MTP_VERIFY_LAYERS", "0,1,60").split(",") - if x.strip() - } - if layer_id not in debug_layers: - return - key = (bs, layer_id, tag) - max_hits = int(os.getenv("ATOM_DEBUG_MTP_VERIFY_MAX_HITS", "4")) - hits = self._mtp_debug_counts.get(key, 0) - if hits >= max_hits: - return - self._mtp_debug_counts[key] = hits + 1 - - md = self.forward_metadata - spec_info = getattr(forward_batch, "spec_info", None) - draft_num = getattr(spec_info, "draft_token_num", None) - pieces = [ - f"[ATOM_MTP_DEBUG] tag={tag} hit={hits + 1} rank={rank} layer={layer_id} bs={bs}", - f"mode={forward_batch.forward_mode}", - f"draft_num={draft_num}", - f"max_q_len={getattr(md, 'max_q_len', None)}", - f"num_kv_splits={getattr(md, 'num_kv_splits', None)}", - f"kv_indices_len={None if md.kv_indices is None else md.kv_indices.numel()}", - f"work_metadata_shape={None if md.work_metadata is None else tuple(md.work_metadata.shape)}", - self._debug_mtp_tensor("kv_indices", md.kv_indices[: min(64, md.kv_indices.numel())] if md.kv_indices is not None else None), - self._debug_mtp_tensor("seq_lens", forward_batch.seq_lens[:bs]), - self._debug_mtp_tensor("req_pool", forward_batch.req_pool_indices[:bs]), - self._debug_mtp_tensor("out_cache_loc", getattr(forward_batch, "out_cache_loc", None)), - self._debug_mtp_tensor("qo_indptr", md.qo_indptr[: bs + 1] if md.qo_indptr is not None else None), - self._debug_mtp_tensor("kv_indptr", md.kv_indptr[: bs + 1] if md.kv_indptr is not None else None), - self._debug_mtp_tensor("kv_last_page_len", md.kv_last_page_len[:bs] if md.kv_last_page_len is not None else None), - self._debug_mtp_tensor("q", q), - self._debug_mtp_tensor("o", o), - ] - print(" | ".join(pieces), flush=True) def _cuda_graph_mla_max_seqlen_qo(self) -> int: """Largest q length used by MLA CUDA graph speculative paths.""" @@ -1815,83 +1728,6 @@ def forward_extend( ) elif self.use_mla: forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) - if os.getenv("ATOM_DEBUG_DUMP_KV_WRITE", "0") == "1": - try: - from sglang.srt.distributed import get_tp_group - - rank = int(get_tp_group().rank_in_group) - except Exception: - rank = int(os.getenv("RANK", "-1")) - dump_ranks = { - int(x) - for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_RANKS", "0").split(",") - if x.strip() - } - dump_layers = { - int(x) - for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_LAYERS", "0").split(",") - if x.strip() - } - dump_bs = { - int(x) - for x in os.getenv( - "ATOM_DEBUG_DUMP_KV_WRITE_BS", - str(int(forward_batch.batch_size)), - ).split(",") - if x.strip() - } - if ( - rank in dump_ranks - and int(getattr(layer, "layer_id", -1)) in dump_layers - and int(forward_batch.batch_size) in dump_bs - ): - if not hasattr(self, "_debug_kv_write_counts"): - self._debug_kv_write_counts = {} - key = ( - rank, - int(getattr(layer, "layer_id", -1)), - int(forward_batch.batch_size), - ) - hits = self._debug_kv_write_counts.get(key, 0) - max_hits = int(os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_MAX_HITS", "2")) - if hits < max_hits: - self._debug_kv_write_counts[key] = hits + 1 - dump_dir = os.getenv( - "ATOM_DEBUG_DUMP_KV_WRITE_DIR", - "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/rootcause_20260620_kv_write", - ) - os.makedirs(dump_dir, exist_ok=True) - k_buffer = forward_batch.token_to_kv_pool.get_key_buffer( - layer.layer_id - ) - loc = cache_loc.detach().long() - dump_path = os.path.join( - dump_dir, - f"rank{rank}_layer{int(getattr(layer, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_hit{hits + 1}.pt", - ) - torch.save( - { - "rank": rank, - "layer": int(getattr(layer, "layer_id", -1)), - "batch_size": int(forward_batch.batch_size), - "forward_mode": str(forward_batch.forward_mode), - "cache_loc": cache_loc.detach().cpu(), - "positions": None - if getattr(forward_batch, "positions", None) is None - else forward_batch.positions.detach().cpu(), - "seq_lens": None - if getattr(forward_batch, "seq_lens", None) is None - else forward_batch.seq_lens.detach().cpu(), - "req_pool_indices": None - if getattr(forward_batch, "req_pool_indices", None) is None - else forward_batch.req_pool_indices.detach().cpu(), - "k_input": k.detach().cpu(), - "v_input": v.detach().cpu(), - "k_after_write": k_buffer[loc].detach().cpu(), - }, - dump_path, - ) - print(f"[ATOM_KV_WRITE_DUMP] path={dump_path}", flush=True) else: k_buffer, v_buffer = forward_batch.token_to_kv_pool.get_kv_buffer( layer.layer_id @@ -2110,8 +1946,6 @@ def _try_fused_mxfp4_kv_b_proj_fp8( qk_nope_head_dim, ): """Return FP8 k/v from MXFP4 kv_b_proj when the fused split-cat path fits.""" - if os.getenv("ATOM_DEBUG_DISABLE_MXFP4_KVB_FUSED", "0") == "1": - return None if fused_gemm_afp4wfp4_preshuffle_split_cat is None: return None weight = getattr(layer.kv_b_proj, "weight", None) @@ -2454,212 +2288,7 @@ def _forward_extend_mla_speculative( (q.shape[0], layer.tp_q_head_num, layer.v_head_dim), dtype=self.input_dtype, ) - dump_enabled = os.getenv("ATOM_DUMP_MLA_VERIFY_REPRO", "0") == "1" - dump_hit = False - dump_path = None - if dump_enabled: - try: - from sglang.srt.distributed import get_tp_group - - rank = int(get_tp_group().rank_in_group) - except Exception: - rank = int(os.getenv("RANK", "-1")) - dump_ranks = { - int(x) - for x in os.getenv("ATOM_DUMP_MLA_VERIFY_RANKS", "0").split(",") - if x.strip() - } - dump_layers = { - int(x) - for x in os.getenv("ATOM_DUMP_MLA_VERIFY_LAYERS", "0").split(",") - if x.strip() - } - key = ( - int(forward_batch.batch_size), - int(getattr(layer, "layer_id", -1)), - "dump_mla_verify", - ) - if not hasattr(self, "_dump_mla_verify_counts"): - self._dump_mla_verify_counts = {} - hits = self._dump_mla_verify_counts.get(key, 0) - max_hits = int(os.getenv("ATOM_DUMP_MLA_VERIFY_MAX_HITS", "1")) - dump_bs = { - int(x) - for x in os.getenv("ATOM_DUMP_MLA_VERIFY_BS", str(int(forward_batch.batch_size))).split(",") - if x.strip() - } - dump_hit = ( - rank in dump_ranks - and int(forward_batch.batch_size) in dump_bs - and int(getattr(layer, "layer_id", -1)) in dump_layers - and hits < max_hits - ) - if dump_hit: - self._dump_mla_verify_counts[key] = hits + 1 - dump_dir = os.getenv( - "ATOM_DUMP_MLA_VERIFY_DIR", - "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/fixed_prompt_rootcause_20260618/mla_kernel_repro", - ) - os.makedirs(dump_dir, exist_ok=True) - draft_num = getattr(forward_batch.spec_info, "draft_token_num", -1) - dump_path = os.path.join( - dump_dir, - f"rank{rank}_layer{int(getattr(layer, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_draft{int(draft_num)}_hit{hits + 1}.pt", - ) - md = self.forward_metadata - compact_k = K_Buffer[md.kv_indices.long()].contiguous() - unique_kv_indices = torch.unique(md.kv_indices.long()) - raw_k_slots = K_Buffer[unique_kv_indices].contiguous() - torch.save( - { - "q": q.detach().cpu(), - "k_compact": compact_k.detach().cpu(), - "unique_kv_indices": unique_kv_indices.detach().cpu(), - "raw_k_slots": raw_k_slots.detach().cpu(), - "qo_indptr": md.qo_indptr.detach().cpu(), - "kv_indptr": md.kv_indptr.detach().cpu(), - "compact_kv_indices": torch.arange( - compact_k.shape[0], dtype=torch.int32 - ), - "kv_indices_original": None - if md.kv_indices is None - else md.kv_indices.detach().cpu(), - "kv_last_page_len": md.kv_last_page_len.detach().cpu(), - "seq_lens": None - if getattr(forward_batch, "seq_lens", None) is None - else forward_batch.seq_lens.detach().cpu(), - "req_pool_indices": None - if getattr(forward_batch, "req_pool_indices", None) is None - else forward_batch.req_pool_indices.detach().cpu(), - "out_cache_loc": None - if getattr(forward_batch, "out_cache_loc", None) is None - else forward_batch.out_cache_loc.detach().cpu(), - "positions": None - if getattr(forward_batch, "positions", None) is None - else forward_batch.positions.detach().cpu(), - "max_q_len": int(md.max_q_len), - "num_kv_splits": int(md.num_kv_splits or 0), - "q_scale": None - if getattr(layer, "k_scale", None) is None - else layer.k_scale.detach().cpu(), - "kv_scale": None - if getattr(layer, "k_scale", None) is None - else layer.k_scale.detach().cpu(), - "tp_q_head_num": int(layer.tp_q_head_num), - "qk_head_dim": int(layer.qk_head_dim), - "v_head_dim": int(layer.v_head_dim), - "scaling": float(layer.scaling), - "logit_cap": float(layer.logit_cap), - "batch_size": int(forward_batch.batch_size), - "draft_num": int(draft_num), - }, - dump_path, - ) - print(f"[ATOM_MLA_REPRO_DUMP] before path={dump_path}", flush=True) - self._debug_mtp_target_verify("before_mla_decode", layer, forward_batch, q=q) - if ( - os.getenv("ATOM_DEBUG_SPLIT_TARGET_VERIFY_Q2", "0") == "1" - and int(getattr(md, "max_q_len", 0)) == 4 - and int(forward_batch.batch_size) > 0 - ): - bs = int(forward_batch.batch_size) - idx_first = torch.tensor( - [r * 4 + j for r in range(bs) for j in (0, 1)], - device=q.device, - dtype=torch.long, - ) - idx_second = torch.tensor( - [r * 4 + j for r in range(bs) for j in (2, 3)], - device=q.device, - dtype=torch.long, - ) - qo2 = torch.arange(0, (bs + 1) * 2, 2, dtype=torch.int32, device=q.device) - ( - work_metadata2, - work_indptr2, - work_info_set2, - reduce_indptr2, - reduce_final_map2, - reduce_partial_map2, - ) = self.make_mla_decode_meta_data_buffer(2, bs) - num_kv_splits2 = self.max_split_per_batch - self.make_mla_meta_data( - qo2, - md.kv_indptr, - md.kv_last_page_len, - work_metadata2, - work_info_set2, - work_indptr2, - reduce_indptr2, - reduce_final_map2, - reduce_partial_map2, - 2, - fast_mode=_sglang_aiter.fast_mode, - max_split_per_batch=num_kv_splits2, - intra_batch_mode=_sglang_aiter.intra_batch_mode, - ) - old_md = self.forward_metadata - try: - self.forward_metadata = ForwardMetadata( - md.kv_indptr, - md.kv_indices, - qo2, - md.kv_last_page_len, - 2, - None, - None, - None, - work_metadata=work_metadata2, - work_info_set=work_info_set2, - work_indptr=work_indptr2, - reduce_indptr=reduce_indptr2, - reduce_final_map=reduce_final_map2, - reduce_partial_map=reduce_partial_map2, - num_kv_splits=num_kv_splits2, - run_graph=False, - ) - o_first = o.new_empty((bs * 2, layer.tp_q_head_num, layer.v_head_dim)) - o_second = torch.empty_like(o_first) - self._call_mla_decode_fwd(q[idx_first].contiguous(), K_Buffer, o_first, layer) - self._call_mla_decode_fwd(q[idx_second].contiguous(), K_Buffer, o_second, layer) - o[idx_first] = o_first - o[idx_second] = o_second - finally: - self.forward_metadata = old_md - print( - f"[ATOM_MTP_DEBUG] tag=split_target_verify_q2 bs={bs} layer={int(getattr(layer, 'layer_id', -1))}", - flush=True, - ) - elif ( - os.getenv("ATOM_DEBUG_TARGET_VERIFY_PREFILL_FWD", "0") == "1" - and int(getattr(md, "max_q_len", 0)) == 4 - ): - o_prefill = self._extend_mla_absorbed_prefix( - q, - layer, - K_Buffer, - md.kv_indptr, - md.kv_indices, - md.qo_indptr, - ) - if o_prefill.ndim == 2: - o_prefill = o_prefill.view(-1, layer.tp_q_head_num, layer.v_head_dim) - o.copy_(o_prefill) - print( - f"[ATOM_MTP_DEBUG] tag=target_verify_prefill_fwd bs={int(forward_batch.batch_size)} " - f"layer={int(getattr(layer, 'layer_id', -1))}", - flush=True, - ) - else: - self._call_mla_decode_fwd(q, K_Buffer, o, layer) - if dump_hit and dump_path is not None: - saved = torch.load(dump_path, map_location="cpu") - saved["o"] = o.detach().cpu() - torch.save(saved, dump_path) - print(f"[ATOM_MLA_REPRO_DUMP] after path={dump_path}", flush=True) - self._debug_mtp_target_verify( - "after_mla_decode", layer, forward_batch, q=q, o=o - ) + self._call_mla_decode_fwd(q, K_Buffer, o, layer) return o if forward_batch.forward_mode.is_draft_extend(include_v2=True): diff --git a/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py b/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py index a0c29b7431..ee9e46565d 100644 --- a/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py +++ b/atom/plugin/sglang/attention_backend/sparse_mla_indexer.py @@ -6,7 +6,6 @@ from __future__ import annotations import re -import os from typing import Optional import torch @@ -240,77 +239,6 @@ def forward_sparse_mla_for_sglang( forward_batch.token_to_kv_pool.set_kv_buffer( layer, forward_batch.out_cache_loc, k, v ) - if os.getenv("ATOM_DEBUG_DUMP_KV_WRITE", "0") == "1": - try: - from sglang.srt.distributed import get_tp_group - - rank = int(get_tp_group().rank_in_group) - except Exception: - rank = int(os.getenv("RANK", "-1")) - dump_ranks = { - int(x) - for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_RANKS", "0").split(",") - if x.strip() - } - dump_layers = { - int(x) - for x in os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_LAYERS", "0").split(",") - if x.strip() - } - dump_bs = { - int(x) - for x in os.getenv( - "ATOM_DEBUG_DUMP_KV_WRITE_BS", - str(int(forward_batch.batch_size)), - ).split(",") - if x.strip() - } - if ( - rank in dump_ranks - and int(getattr(layer, "layer_id", -1)) in dump_layers - and int(forward_batch.batch_size) in dump_bs - ): - if not hasattr(forward_sparse_mla_for_sglang, "_debug_kv_write_counts"): - forward_sparse_mla_for_sglang._debug_kv_write_counts = {} - key = (rank, int(getattr(layer, "layer_id", -1)), int(forward_batch.batch_size)) - hits = forward_sparse_mla_for_sglang._debug_kv_write_counts.get(key, 0) - max_hits = int(os.getenv("ATOM_DEBUG_DUMP_KV_WRITE_MAX_HITS", "2")) - if hits < max_hits: - forward_sparse_mla_for_sglang._debug_kv_write_counts[key] = hits + 1 - dump_dir = os.getenv( - "ATOM_DEBUG_DUMP_KV_WRITE_DIR", - "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/rootcause_20260620_kv_write", - ) - os.makedirs(dump_dir, exist_ok=True) - k_buffer_after = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - loc = forward_batch.out_cache_loc.detach().long() - dump_path = os.path.join( - dump_dir, - f"sparse_rank{rank}_layer{int(getattr(layer, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_hit{hits + 1}.pt", - ) - torch.save( - { - "rank": rank, - "layer": int(getattr(layer, "layer_id", -1)), - "batch_size": int(forward_batch.batch_size), - "forward_mode": str(forward_batch.forward_mode), - "cache_loc": forward_batch.out_cache_loc.detach().cpu(), - "positions": None - if getattr(forward_batch, "positions", None) is None - else forward_batch.positions.detach().cpu(), - "seq_lens": None - if getattr(forward_batch, "seq_lens", None) is None - else forward_batch.seq_lens.detach().cpu(), - "req_pool_indices": None - if getattr(forward_batch, "req_pool_indices", None) is None - else forward_batch.req_pool_indices.detach().cpu(), - "k_input": k.detach().cpu(), - "v_input": v.detach().cpu(), - "k_after_write": k_buffer_after[loc].detach().cpu(), - }, - dump_path, - ) - print(f"[ATOM_KV_WRITE_DUMP] path={dump_path}", flush=True) q = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) num_tokens = q.shape[0] diff --git a/atom/plugin/sglang/models/deepseek_mla_attention.py b/atom/plugin/sglang/models/deepseek_mla_attention.py index 3b85d375f8..e9ffe79f2e 100644 --- a/atom/plugin/sglang/models/deepseek_mla_attention.py +++ b/atom/plugin/sglang/models/deepseek_mla_attention.py @@ -10,7 +10,6 @@ from __future__ import annotations -import os from typing import TYPE_CHECKING, Any import torch @@ -229,71 +228,6 @@ def _forward_absorbed( is_neox=attn.rotary_emb.is_neox_style, is_nope_first=True, ) - if os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE", "0") == "1": - try: - from sglang.srt.distributed import get_tp_group - - rank = int(get_tp_group().rank_in_group) - except Exception: - rank = int(os.getenv("RANK", "-1")) - ranks = { - int(x) - for x in os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE_RANKS", "0").split(",") - if x.strip() - } - layers = { - int(x) - for x in os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE_LAYERS", "0").split(",") - if x.strip() - } - bss = { - int(x) - for x in os.getenv( - "ATOM_DEBUG_DUMP_FUSED_QK_CACHE_BS", - str(int(forward_batch.batch_size)), - ).split(",") - if x.strip() - } - if ( - rank in ranks - and int(getattr(mla_attn, "layer_id", -1)) in layers - and int(forward_batch.batch_size) in bss - ): - if not hasattr(self, "_debug_fused_qk_cache_counts"): - self._debug_fused_qk_cache_counts = {} - key = (rank, int(getattr(mla_attn, "layer_id", -1)), int(forward_batch.batch_size)) - hits = self._debug_fused_qk_cache_counts.get(key, 0) - max_hits = int(os.getenv("ATOM_DEBUG_DUMP_FUSED_QK_CACHE_MAX_HITS", "2")) - if hits < max_hits: - self._debug_fused_qk_cache_counts[key] = hits + 1 - dump_dir = os.getenv( - "ATOM_DEBUG_DUMP_FUSED_QK_CACHE_DIR", - "/home/qichu_qle/zhiwei/dsv4/atom/work_logs/bs64_issue/rootcause_20260620_fused_qk_cache", - ) - os.makedirs(dump_dir, exist_ok=True) - loc = forward_batch.out_cache_loc.detach().long() - dump_path = os.path.join( - dump_dir, - f"rank{rank}_layer{int(getattr(mla_attn, 'layer_id', -1))}_bs{int(forward_batch.batch_size)}_hit{hits + 1}.pt", - ) - torch.save( - { - "rank": rank, - "layer": int(getattr(mla_attn, "layer_id", -1)), - "batch_size": int(forward_batch.batch_size), - "mode": str(forward_batch.forward_mode), - "out_cache_loc": forward_batch.out_cache_loc.detach().cpu(), - "positions": positions.detach().cpu(), - "q_nope_out": q_nope_out.detach().cpu(), - "q_pe": q_pe.detach().cpu(), - "k_nope": k_nope.detach().cpu(), - "k_pe": k_pe.detach().cpu(), - "q": q.detach().cpu(), - "kv_cache_after": kv_cache[loc].detach().cpu(), - }, - dump_path, - ) - print(f"[ATOM_FUSED_QK_CACHE_DUMP] path={dump_path}", flush=True) k = None v = None save_kv_cache = False diff --git a/atom/plugin/sglang/models/deepseek_mla_forward.py b/atom/plugin/sglang/models/deepseek_mla_forward.py index d880c107cb..c235230a4d 100644 --- a/atom/plugin/sglang/models/deepseek_mla_forward.py +++ b/atom/plugin/sglang/models/deepseek_mla_forward.py @@ -12,7 +12,6 @@ from __future__ import annotations import logging -import os from typing import TYPE_CHECKING, Any, Optional import torch @@ -270,15 +269,7 @@ def init_sgl_attrs( attn.use_deep_gemm_bmm = False attn.alt_stream = None attn.kv_cache_dtype = kv_cache_dtype - attn.use_fused_qk_rope_concat_and_cache_mla = ( - _use_aiter_gfx95 - and os.getenv("ATOM_DEBUG_DISABLE_FUSED_QK_ROPE_CACHE_MLA", "0") != "1" - ) - if os.getenv("ATOM_DEBUG_LOG_FUSED_QK_FLAG", "0") == "1": - print( - f"[ATOM_DEBUG_FUSED_QK_FLAG] use_fused_qk_rope_concat_and_cache_mla={attn.use_fused_qk_rope_concat_and_cache_mla} kv_cache_dtype={kv_cache_dtype}", - flush=True, - ) + attn.use_fused_qk_rope_concat_and_cache_mla = _use_aiter_gfx95 attn.current_sgl_plugin_attn_path = None attn.w_kc, attn.w_vc = None, None attn.w_scale = None diff --git a/atom/plugin/sglang/runtime/forward_context.py b/atom/plugin/sglang/runtime/forward_context.py index 520672e3d9..cda8d36f37 100644 --- a/atom/plugin/sglang/runtime/forward_context.py +++ b/atom/plugin/sglang/runtime/forward_context.py @@ -266,15 +266,6 @@ def _set_atom_forward_context( proxy_pool, "is_atom_v4_proxy_pool", False ): if is_capture_batch: - logger.info( - "ATOM DSV4 capture metadata missing: backend=%s " - "fb_has=%s class_has=%s", - type(backend).__name__ if backend is not None else None, - hasattr(forward_batch, "atom_v4_graph_metadata"), - "ATOMDeepseekV4BackendForSgl" in locals() - and ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata - is not None, - ) raise RuntimeError( "ATOM DeepSeek-V4 CUDA graph metadata was not initialized before capture" ) From 84f2e0dbeb71c2b4caeb56d0d7cd0909b83f39a0 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 26 Jun 2026 03:26:15 +0000 Subject: [PATCH 3/4] mtp layer wrapper --- .../models/deepseek_v4_nextn_wrapper.py | 263 ++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py diff --git a/atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py b/atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py new file mode 100644 index 0000000000..de38768c05 --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py @@ -0,0 +1,263 @@ +"""ATOM DeepSeek-V4 NextN wrapper for SGLang external loading. + +SGLang rewrites a DeepSeek-V4 draft runner to the architecture name +``DeepseekV4ForCausalLMNextN``. This wrapper keeps that public name while +delegating the actual MTP block to ATOM's ``DeepseekV4MTP`` implementation. +""" + +import copy +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args + +from atom.config import QuantizationConfig as AtomQuantizationConfig +from atom.config import SpeculativeConfig +from atom.model_ops.embed_head import VocabParallelEmbedding +from atom.models.deepseek_v4 import DeepseekV4Attention, ParallelHead +from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.plugin.sglang.runtime import ( + SGLangForwardBatchMetadata, + SGLangPluginRuntime, + plugin_runtime_scope, +) + +logger = logging.getLogger("atom.plugin.sglang.models") + + +def _sync_replaced_weights() -> None: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def _replace_weight(module: nn.Module, attr_name: str, weight) -> None: + if hasattr(module, attr_name): + delattr(module, attr_name) + setattr(module, attr_name, weight) + + +def _materialize_dummy_hidden_states( + hidden_states: torch.Tensor, + *, + length: int, +) -> torch.Tensor: + shape = (length, *hidden_states.shape[1:]) + return hidden_states.new_zeros(shape) + + +def _install_deepseek_v4_mtp_adapters(model: nn.Module) -> None: + from atom.plugin.sglang.models.deepseek_v4_attention import ( + patch_deepseek_v4_attention_for_sglang, + ) + + for module in model.modules(): + if isinstance(module, DeepseekV4Attention): + patch_deepseek_v4_attention_for_sglang(module) + + +class _DeepseekV4MTPLogitsHeadAdapter(nn.Module): + """Expose ``DeepseekV4MTP.compute_logits`` as an SGLang lm_head.""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + def set_lora(self, *args, **kwargs) -> None: + return None + + def apply_lora(self, *args, **kwargs) -> None: + return None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.model.compute_logits(hidden_states) + + +class DeepseekV4ForCausalLMNextN(nn.Module): + """SGLang-compatible draft wrapper backed by ATOM's ``DeepseekV4MTP``.""" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + del prefix + super().__init__() + + logger.info("Initializing ATOM backend for %s", self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + with plugin_runtime_scope(framework="sglang"): + self.atom_config = generate_atom_config_for_plugin_mode(config) + + server_args = get_global_server_args() + draft_model_path = ( + server_args.speculative_draft_model_path or server_args.model_path + ) + use_standalone_draft = ( + server_args.speculative_draft_model_path is not None + and server_args.speculative_draft_model_path != server_args.model_path + ) + self.use_standalone_draft = use_standalone_draft + self.atom_config.model = draft_model_path + if use_standalone_draft and hasattr(config, "quantization_config"): + self.atom_config.hf_config.quantization_config = copy.deepcopy( + config.quantization_config + ) + SpeculativeConfig.hf_config_override( + self.atom_config.hf_config, model_path=draft_model_path + ) + if use_standalone_draft: + self.atom_config.quant_config = AtomQuantizationConfig( + self.atom_config.hf_config, + self.atom_config.online_quant_config, + ) + + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + from atom.models.deepseek_v4_mtp import DeepseekV4MTP + from atom.plugin.register import ( + init_aiter_dist, + register_ops_to_sglang, + set_attn_cls, + ) + + register_ops_to_sglang(atom_config=self.atom_config) + set_attn_cls() + init_aiter_dist(config=self.atom_config) + + self.model = DeepseekV4MTP(config=self.atom_config) + self.model.atom_config = self.atom_config + _install_deepseek_v4_mtp_adapters(self.model) + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.shared_head = ParallelHead( + config.vocab_size, + config.hidden_size, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + hc_eps=getattr(config, "hc_eps", 1e-6), + ) + self._bind_shared_modules() + self.logits_head = _DeepseekV4MTPLogitsHeadAdapter(self.model) + self.logits_processor = LogitsProcessor(config, skip_all_gather=True) + + def _mtp_blocks(self): + return list(self.model.model.mtp) + + def _bind_shared_modules(self) -> None: + for block in self._mtp_blocks(): + block.embed = self.embed_tokens + block.head = self.shared_head + + def get_embed_and_head(self): + return self.embed_tokens.weight, self.shared_head.weight + + def set_embed_and_head(self, embed, head): + self.set_embed(embed) + _replace_weight(self.shared_head, "weight", head) + self._bind_shared_modules() + _sync_replaced_weights() + + def set_embed(self, embed): + _replace_weight(self.embed_tokens, "weight", embed) + self._bind_shared_modules() + _sync_replaced_weights() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs, + ): + del input_embeds, kwargs + if forward_batch.spec_info is None: + raise ValueError("DeepSeek-V4 MTP draft forward requires speculative info") + + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + with SGLangPluginRuntime( + atom_config=self.atom_config, + forward_batch=forward_batch, + positions=positions, + input_ids=input_ids, + ) as runtime: + from atom.plugin.sglang.deepseek_v4_bridge import ( + bind_deepseek_v4_proxy_cache_views, + maybe_get_proxy_pool_from_sglang_backend, + reset_deepseek_v4_state_slots, + ) + + proxy_pool, _ = maybe_get_proxy_pool_from_sglang_backend() + if not bind_deepseek_v4_proxy_cache_views(self.model, proxy_pool): + raise RuntimeError( + "DeepSeek-V4 MTP SGLang proxy KV pool is not initialized" + ) + from atom.utils.forward_context import get_forward_context + + reset_slots = getattr( + get_forward_context().attn_metadata, "reset_slots", None + ) + reset_deepseek_v4_state_slots(self.model, reset_slots) + + model_hidden_states = forward_batch.spec_info.hidden_states + if runtime.forward_batch is not forward_batch: + model_hidden_states = _materialize_dummy_hidden_states( + model_hidden_states, + length=int(runtime.positions.shape[0]), + ) + + metadata = SGLangForwardBatchMetadata.build(runtime.forward_batch) + with SGLangForwardBatchMetadata.bind(metadata): + hidden_states = self.model( + input_ids=runtime.input_ids, + positions=runtime.positions, + hidden_states=model_hidden_states, + ) + + if self.pp_group.is_last_rank: + hidden_states = runtime.trim_output(hidden_states) + return self.logits_processor( + input_ids, + hidden_states, + self.logits_head, + forward_batch, + hidden_states_before_norm=hidden_states, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + del weights + from atom.model_loader.loader import load_model + + server_args = get_global_server_args() + draft_model_path = ( + server_args.speculative_draft_model_path or server_args.model_path + ) + self.atom_config.model = draft_model_path + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + return load_model( + model=self.model, + model_name_or_path=draft_model_path, + hf_config=self.atom_config.hf_config, + load_dummy=self.atom_config.load_dummy, + spec_decode=True, + ) + + +EntryClass = [DeepseekV4ForCausalLMNextN] From a6ecff4b1b3c99d41c35dee6447d54060cca6f11 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Fri, 26 Jun 2026 07:01:10 +0000 Subject: [PATCH 4/4] format --- atom/plugin/register.py | 19 +++++++++---------- .../attention_backend/deepseek_v4_backend.py | 14 ++++++++------ atom/plugin/sglang/runtime/forward_context.py | 8 ++++---- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/atom/plugin/register.py b/atom/plugin/register.py index 074d50983f..566e42108e 100644 --- a/atom/plugin/register.py +++ b/atom/plugin/register.py @@ -163,16 +163,17 @@ def _patch_sglang_dsv4_spec_cuda_graph() -> None: def can_run(self, forward_batch): try: model_runner = getattr(self, "model_runner", None) - hf_config = getattr(getattr(model_runner, "model_config", None), "hf_config", None) + hf_config = getattr( + getattr(model_runner, "model_config", None), "hf_config", None + ) arches = getattr(hf_config, "architectures", None) or [] is_dsv4 = any("DeepseekV4" in str(arch) for arch in arches) mode = getattr(forward_batch, "forward_mode", None) - is_spec_extend = ( - bool(getattr(mode, "is_target_verify", lambda: False)()) - or bool( - getattr(mode, "is_draft_extend", lambda **kwargs: False)( - include_v2=True - ) + is_spec_extend = bool( + getattr(mode, "is_target_verify", lambda: False)() + ) or bool( + getattr(mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True ) ) if is_dsv4 and is_spec_extend: @@ -189,9 +190,7 @@ def can_run(self, forward_batch): def init_cuda_graphs(self): try: - arch = ( - self.draft_runner.model_config.hf_config.architectures[0] - ) + arch = self.draft_runner.model_config.hf_config.architectures[0] if arch == "DeepseekV4ForCausalLMNextN": self.cuda_graph_runner = None self.cuda_graph_runner_for_draft_extend = None diff --git a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py index 273dfe61a4..7b42ea23c6 100644 --- a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py +++ b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py @@ -63,12 +63,14 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) if forward_batch.forward_mode.is_decode_or_idle(): - self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang( - forward_batch, - positions, - proxy_pool=self.token_to_kv_pool, - req_to_token_pool=self.req_to_token_pool, - model=atom_model, + self.atom_v4_graph_metadata = ( + build_atom_v4_decode_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) ) else: self.atom_v4_graph_metadata = build_atom_v4_attention_metadata_from_sglang( diff --git a/atom/plugin/sglang/runtime/forward_context.py b/atom/plugin/sglang/runtime/forward_context.py index cda8d36f37..8bb17c14a9 100644 --- a/atom/plugin/sglang/runtime/forward_context.py +++ b/atom/plugin/sglang/runtime/forward_context.py @@ -125,7 +125,9 @@ def _resolve_num_tokens_across_dp( return num_tokens_across_dp -def _slice_v4_graph_metadata_for_capture(attn_metadata: Any, *, num_tokens: int, bs: int): +def _slice_v4_graph_metadata_for_capture( + attn_metadata: Any, *, num_tokens: int, bs: int +): """Narrow reusable V4 graph metadata to this capture bucket. The DSV4 fallback metadata is initialized at max graph size. SGLang then @@ -251,9 +253,7 @@ def _set_atom_forward_context( ATOMDeepseekV4BackendForSgl, ) - attn_metadata = ( - ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata - ) + attn_metadata = ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata if attn_metadata is not None: attn_metadata = _slice_v4_graph_metadata_for_capture( attn_metadata,