def _patch_sglang_dsv4_draft_backends() -> None:
    """Point SGLang's hard-coded DSV4 draft-backend factories at the ATOM shim.

    ``DraftBackendFactory`` constructs DeepSeek-V4 draft backends directly
    instead of going through the attention registry. SGLang's native backend
    asserts a native DeepSeekV4TokenToKVPool, while ATOM plugin mode uses a
    proxy KV pool, so both factory methods are replaced with versions that
    return ``ATOMDeepseekV4BackendForSgl``.

    The function is a no-op (logged at debug level) when either import fails,
    and it is idempotent: a marker attribute on the factory class records that
    the patch has already been applied.
    """
    try:
        from sglang.srt.speculative.draft_utils import DraftBackendFactory
        from atom.plugin.sglang.attention_backend.deepseek_v4_backend import (
            ATOMDeepseekV4BackendForSgl,
        )
    except Exception as exc:
        # Best effort: SGLang (or the ATOM backend) is not importable here.
        logger.debug("Skip patching SGLang DSV4 draft backends: %s", exc)
        return

    marker = "_atom_dsv4_draft_backend_patched"
    if getattr(DraftBackendFactory, marker, False):
        return

    def _create_atom_dsv4_decode_backend(self):
        return ATOMDeepseekV4BackendForSgl(
            self.draft_model_runner,
            topk=self.topk,
            speculative_num_steps=self.speculative_num_steps,
        )

    def _create_atom_dsv4_prefill_backend(self):
        return ATOMDeepseekV4BackendForSgl(
            self.draft_model_runner,
            skip_prefill=False,
        )

    # Install both factories first and the marker last, in that order.
    replacements = (
        ("_create_dsv4_decode_backend", _create_atom_dsv4_decode_backend),
        ("_create_dsv4_prefill_backend", _create_atom_dsv4_prefill_backend),
        (marker, True),
    )
    for attr_name, attr_value in replacements:
        setattr(DraftBackendFactory, attr_name, attr_value)
    logger.info("Patched SGLang DSV4 speculative draft backends to ATOM")
def _runner_arch_matches(runner, needle: str) -> bool:
    """Return True when any ``hf_config.architectures`` entry contains *needle*.

    Missing attributes are treated as "no architectures"; any exception raised
    while probing (for example a non-iterable ``architectures``) yields False.
    """
    try:
        model_config = getattr(runner, "model_config", None)
        hf_config = getattr(model_config, "hf_config", None)
        arches = getattr(hf_config, "architectures", None) or []
        return any(needle in str(arch) for arch in arches)
    except Exception:
        return False


def _is_dsv4_nextn_runner(runner) -> bool:
    """Return True when *runner* hosts a DeepseekV4ForCausalLMNextN model."""
    return _runner_arch_matches(runner, "DeepseekV4ForCausalLMNextN")


def _is_dsv4_runner(runner) -> bool:
    """Return True when *runner* hosts any DeepseekV4 architecture."""
    return _runner_arch_matches(runner, "DeepseekV4")


def _flatten_spec_hidden_states(forward_batch):
    """Flatten ``spec_info.hidden_states`` to 2-D in place for graph staging.

    Returns the original (un-flattened) tensor so the caller can restore it,
    or None when there is nothing to flatten (no hidden states, or already
    at most 2-D). In draft-extend mode the flattened rows are repeated so the
    row count matches ``input_ids``.

    Raises:
        RuntimeError: when the token count is not a multiple of the number of
            hidden-state rows, so the rows cannot be expanded evenly.
    """
    spec_info = getattr(forward_batch, "spec_info", None)
    hidden_states = getattr(spec_info, "hidden_states", None)
    if hidden_states is None or getattr(hidden_states, "dim", lambda: 0)() <= 2:
        return None
    flattened = hidden_states.reshape(hidden_states.shape[0], -1)
    input_ids = getattr(forward_batch, "input_ids", None)
    num_tokens = int(input_ids.shape[0]) if hasattr(input_ids, "shape") else 0
    mode = getattr(forward_batch, "forward_mode", None)
    is_draft_extend = bool(
        getattr(mode, "is_draft_extend", lambda **kwargs: False)(include_v2=True)
    )
    if is_draft_extend and num_tokens > 0 and flattened.shape[0] != num_tokens:
        if num_tokens % int(flattened.shape[0]) != 0:
            raise RuntimeError(
                "DSV4 speculative hidden layout cannot be expanded for graph "
                f"input: hidden={tuple(hidden_states.shape)} "
                f"flattened={tuple(flattened.shape)} num_tokens={num_tokens}"
            )
        flattened = flattened.repeat_interleave(
            num_tokens // int(flattened.shape[0]), dim=0
        )
    spec_info.hidden_states = flattened
    return hidden_states


def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* is 1/true/yes/on."""
    return os.environ.get(name, "0").lower() in ("1", "true", "yes", "on")


def _runner_model_path(runner) -> str:
    """Return ``server_args.model_path``, else ``model_config.path``, as str."""
    return str(
        getattr(getattr(runner, "server_args", None), "model_path", "")
        or getattr(getattr(runner, "model_config", None), "path", "")
    )


def _is_dsv4_flash_runner(runner) -> bool:
    """Return True when the runner's model path names DeepSeek-V4-Flash."""
    return "DeepSeek-V4-Flash" in _runner_model_path(runner)


def _is_dsv4_pro_runner(runner) -> bool:
    """Return True when the runner's model path names DeepSeek-V4-Pro."""
    return "DeepSeek-V4-Pro" in _runner_model_path(runner)


def _draft_extend_graph_enabled(runner) -> bool:
    """Decide whether the draft-extend CUDA graph should be used.

    The DISABLE flag always wins. Otherwise the graph is on when the ENABLE
    flag is set, or by default for a DSV4 NextN runner serving a Flash model.
    """
    if _env_flag("ATOM_SGLANG_V4_DISABLE_DRAFT_EXTEND_CG"):
        return False
    return _env_flag("ATOM_SGLANG_V4_ENABLE_DRAFT_EXTEND_CG") or (
        _is_dsv4_nextn_runner(runner) and _is_dsv4_flash_runner(runner)
    )


def _target_verify_graph_enabled() -> bool:
    """Return True only when target-verify graphs are opted in and not disabled."""
    return _env_flag("ATOM_SGLANG_V4_ENABLE_TARGET_VERIFY_CG") and not _env_flag(
        "ATOM_SGLANG_V4_DISABLE_TARGET_VERIFY_CG"
    )


def _safe_spec_graph_bs(original_bs, env_name: str):
    """Filter the capture batch sizes by an optional allow-list env variable.

    Args:
        original_bs: iterable of batch sizes SGLang would capture.
        env_name: environment variable holding a comma and/or whitespace
            separated list of allowed batch sizes.

    Returns:
        A new list. All of *original_bs* when the variable is unset or empty,
        otherwise the members of *original_bs* present in the allow-list
        (original order preserved). The result may be empty; that case is
        logged because it leaves nothing to capture.

    Raises:
        ValueError: when the variable contains a token that is not an integer.
            The message names the variable so the misconfiguration is obvious.
    """
    sizes = list(original_bs)
    configured = os.environ.get(env_name)
    if not configured:
        return sizes

    allowed = set()
    # str.split() with no separator handles any run of whitespace and drops
    # empty tokens, so "1, 2,,4" and "1 2\t4" are both accepted.
    for token in configured.replace(",", " ").split():
        try:
            allowed.add(int(token))
        except ValueError as err:
            raise ValueError(
                f"{env_name} must be a comma/space separated list of integers, "
                f"got {configured!r}"
            ) from err

    selected = [bs for bs in sizes if int(bs) in allowed]
    if not selected:
        logger.warning(
            "%s=%r matches none of the capture batch sizes %s; "
            "no cuda graph batch sizes remain",
            env_name,
            configured,
            sizes,
        )
    return selected
getattr(model_runner, "server_args", None) + original_cuda_graph_bs = ( + list(getattr(server_args, "cuda_graph_bs", [])) + if server_args is not None + else None + ) + try: + should_cap = _is_dsv4_runner(model_runner) and bool( + getattr( + getattr(model_runner, "spec_algorithm", None), + "is_speculative", + lambda: False, + )() + ) + should_cap = ( + should_cap + and not getattr(model_runner, "is_draft_worker", False) + and _target_verify_graph_enabled() + ) + except Exception: + should_cap = False + + try: + if should_cap and server_args is not None and original_cuda_graph_bs: + server_args.cuda_graph_bs = _safe_spec_graph_bs( + original_cuda_graph_bs, + "ATOM_SGLANG_V4_TARGET_VERIFY_CG_BS", + ) + original_target_init(self, model_runner, *args, **kwargs) + finally: + if ( + should_cap + and server_args is not None + and original_cuda_graph_bs is not None + ): + server_args.cuda_graph_bs = original_cuda_graph_bs + + CudaGraphRunner.__init__ = __init__ + CudaGraphRunner._atom_dsv4_init_patched = True + + if not getattr(CudaGraphRunner, "_atom_dsv4_spec_can_run_patched", False): + original_can_run = CudaGraphRunner.can_run + + def can_run(self, forward_batch): + try: + model_runner = getattr(self, "model_runner", None) + hf_config = getattr( + getattr(model_runner, "model_config", None), "hf_config", None + ) + arches = getattr(hf_config, "architectures", None) or [] + is_dsv4 = any("DeepseekV4" in str(arch) for arch in arches) + mode = getattr(forward_batch, "forward_mode", None) + is_target_verify = bool( + getattr(mode, "is_target_verify", lambda: False)() + ) + is_draft_extend = bool( + getattr(mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ) + if is_dsv4 and is_target_verify and not _target_verify_graph_enabled(): + return False + if is_dsv4 and is_draft_extend: + return False + except Exception: + pass + return original_can_run(self, forward_batch) + + CudaGraphRunner.can_run = can_run + 
CudaGraphRunner._atom_dsv4_spec_can_run_patched = True + + if not getattr(EAGLEDraftCudaGraphRunner, "_atom_dsv4_replay_patched", False): + original_draft_replay = EAGLEDraftCudaGraphRunner.replay + + def replay(self, forward_batch): + if not _is_dsv4_nextn_runner(getattr(self, "model_runner", None)): + return original_draft_replay(self, forward_batch) + if _env_flag("ATOM_SGLANG_V4_DISABLE_DRAFT_CG"): + raise RuntimeError( + "DSV4 draft cuda graph replay was disabled after capture; " + "disable it before graph initialization instead." + ) + original_hidden_states = _flatten_spec_hidden_states(forward_batch) + try: + return original_draft_replay(self, forward_batch) + finally: + if original_hidden_states is not None: + forward_batch.spec_info.hidden_states = original_hidden_states + + EAGLEDraftCudaGraphRunner.replay = replay + EAGLEDraftCudaGraphRunner._atom_dsv4_replay_patched = True + + if not getattr(EAGLEDraftExtendCudaGraphRunner, "_atom_dsv4_replay_patched", False): + original_extend_replay = EAGLEDraftExtendCudaGraphRunner.replay + original_extend_can_run = EAGLEDraftExtendCudaGraphRunner.can_run + + def _dsv4_draft_extend_graph_layout_ok(runner, forward_batch=None): + try: + num_draft_tokens = int(getattr(runner, "num_tokens_per_bs", 0) or 0) + if num_draft_tokens <= 0: + return False + raw_bs = int(getattr(forward_batch, "batch_size", 0) or 0) + if raw_bs <= 0: + raw_bs = min(getattr(runner, "capture_bs", [0]) or [0]) + if raw_bs <= 0: + return False + if forward_batch is not None and getattr( + runner, "require_mlp_tp_gather", False + ): + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = max_num_tokens // num_draft_tokens + else: + max_batch_size = raw_bs + import bisect + + index = bisect.bisect_left(runner.capture_bs, max_batch_size) + if index >= len(runner.capture_bs): + return False + bs = runner.capture_bs[index] + output = runner.output_buffers.get(bs) + logits = getattr(output, "next_token_logits", None) + expected = 
bs * num_draft_tokens + if logits is None or int(logits.shape[0]) < expected: + return False + return True + except Exception: + return False + + def can_run(self, forward_batch): + if not _is_dsv4_nextn_runner(getattr(self, "model_runner", None)): + return original_extend_can_run(self, forward_batch) + if not original_extend_can_run(self, forward_batch): + return False + return _dsv4_draft_extend_graph_layout_ok(self, forward_batch) + + def replay(self, forward_batch): + if not _is_dsv4_nextn_runner(getattr(self, "model_runner", None)): + return original_extend_replay(self, forward_batch) + if not _draft_extend_graph_enabled(getattr(self, "model_runner", None)): + raise RuntimeError( + "DSV4 draft-extend cuda graph replay was disabled after capture; " + "disable it before graph initialization instead." + ) + original_hidden_states = _flatten_spec_hidden_states(forward_batch) + backend = getattr(self, "draft_extend_attn_backend", None) + previous_runner = ( + getattr(backend, "_atom_dsv4_draft_extend_graph_runner", None) + if backend is not None + else None + ) + previous_replay_batch = ( + getattr(backend, "_replay_forward_batch", None) + if backend is not None + else None + ) + try: + if backend is not None: + backend._atom_dsv4_draft_extend_graph_runner = self + buffers = getattr(self, "buffers", None) + input_ids = getattr(forward_batch, "input_ids", None) + num_tokens = ( + int(input_ids.shape[0]) if hasattr(input_ids, "shape") else 0 + ) + if buffers is not None and num_tokens > 0: + from types import SimpleNamespace + + backend._replay_forward_batch = SimpleNamespace( + forward_mode=getattr(forward_batch, "forward_mode", None), + positions=getattr(buffers, "positions", None)[:num_tokens], + out_cache_loc=getattr(buffers, "out_cache_loc", None)[ + :num_tokens + ], + ) + out = original_extend_replay(self, forward_batch) + try: + # EAGLE V2 consumes draft-extend logits with a fixed + # `seq * speculative_num_draft_tokens + offset` layout. 
+ # SGLang's runner trims to the actual compact token count, + # which makes that indexing OOB when fewer than the padded + # graph tokens were materialized. Return the captured + # padded output buffer for DSV4 so downstream indexing stays + # within the fixed graph layout. + if bool( + getattr( + getattr(self, "forward_mode", None), + "is_draft_extend_v2", + lambda: False, + )() + ): + padded_out = getattr(self, "output_buffers", {}).get( + getattr(self, "bs", None) + ) + if padded_out is not None: + out = padded_out + except Exception: + logger.exception( + "Failed to restore padded DSV4 draft-extend graph output" + ) + return out + finally: + if backend is not None: + if previous_runner is None: + try: + delattr(backend, "_atom_dsv4_draft_extend_graph_runner") + except AttributeError: + pass + else: + backend._atom_dsv4_draft_extend_graph_runner = previous_runner + if previous_replay_batch is None: + try: + delattr(backend, "_replay_forward_batch") + except AttributeError: + pass + else: + backend._replay_forward_batch = previous_replay_batch + if original_hidden_states is not None: + forward_batch.spec_info.hidden_states = original_hidden_states + + EAGLEDraftExtendCudaGraphRunner.can_run = can_run + EAGLEDraftExtendCudaGraphRunner.replay = replay + EAGLEDraftExtendCudaGraphRunner._atom_dsv4_replay_patched = True + + if not getattr(EagleDraftWorker, "_atom_dsv4_draft_extend_accept_patched", False): + original_draft_extend_for_decode = EagleDraftWorker._draft_extend_for_decode + + def _draft_extend_for_decode(self, batch, batch_result): + try: + if ( + not _is_dsv4_nextn_runner(getattr(self, "draft_runner", None)) + or getattr(self, "cuda_graph_runner_for_draft_extend", None) is None + ): + return original_draft_extend_for_decode(self, batch, batch_result) + + import torch + from sglang.srt.speculative.eagle_info import EagleDraftInput + from sglang.srt.speculative.spec_utils import fast_topk + + num_draft_tokens = int( + getattr(self, 
def _draft_extend_for_decode(self, batch, batch_result):
    """Run the DSV4 draft-extend step using the fixed draft-token layout.

    Delegates to the unpatched ``EagleDraftWorker._draft_extend_for_decode``
    when the draft runner is not a DSV4 NextN model, when no draft-extend
    graph runner exists, when the draft-token count is unknown, when the
    captured graph layout check fails (the graph runner is hidden for the
    duration of that call), or when ``batch_result.accept_lens`` is not a
    tensor.

    Otherwise it fills ``batch_result.next_draft_input`` (``topk_p``,
    ``topk_index``, ``hidden_states``) from the draft logits and returns None.

    Raises:
        RuntimeError: when the selected logits index still falls outside the
            draft output after retrying without the CUDA graph.
    """
    # NOTE(review): `_is_dsv4_nextn_runner`, `original_draft_extend_for_decode`
    # and `_dsv4_draft_extend_graph_layout_ok` come from the enclosing patch
    # function's scope.
    if (
        not _is_dsv4_nextn_runner(getattr(self, "draft_runner", None))
        or getattr(self, "cuda_graph_runner_for_draft_extend", None) is None
    ):
        return original_draft_extend_for_decode(self, batch, batch_result)

    import torch
    from sglang.srt.speculative.eagle_info import EagleDraftInput
    from sglang.srt.speculative.spec_utils import fast_topk

    num_draft_tokens = int(
        getattr(self, "speculative_num_draft_tokens", 0)
        or getattr(self.server_args, "speculative_num_draft_tokens", 0)
        or 0
    )
    if num_draft_tokens <= 0:
        return original_draft_extend_for_decode(self, batch, batch_result)

    if not _dsv4_draft_extend_graph_layout_ok(
        self.cuda_graph_runner_for_draft_extend
    ):
        # Hide the graph runner so the original implementation runs eagerly,
        # then put it back no matter how that call ends.
        runner = self.cuda_graph_runner_for_draft_extend
        self.cuda_graph_runner_for_draft_extend = None
        try:
            return original_draft_extend_for_decode(self, batch, batch_result)
        finally:
            self.cuda_graph_runner_for_draft_extend = runner

    accept_lens = getattr(batch_result, "accept_lens", None)
    if not torch.is_tensor(accept_lens):
        return original_draft_extend_for_decode(self, batch, batch_result)

    # DRAFT_EXTEND_V2 materializes exactly `num_draft_tokens` slots
    # per sequence. `accept_lens` includes the target bonus token,
    # so the value can be `num_draft_tokens + 1`; using that directly
    # in the fixed-layout index points one slot past the graph output.
    graph_accept_lens = accept_lens.clamp(min=1, max=num_draft_tokens)

    draft_input = EagleDraftInput(
        hidden_states=batch_result.logits_output.hidden_states,
        num_tokens_per_req=self.speculative_num_steps + 1,
        num_tokens_for_logprob_per_req=self.speculative_num_steps + 1,
    )
    num_seqs = len(batch.seq_lens)
    select_index = (
        torch.arange(num_seqs, device=self.device) * num_draft_tokens
        + graph_accept_lens
        - 1
    )

    with self.plan_stream_ctx:
        forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
            batch,
            batch_result.next_token_ids,
            num_draft_tokens,
            self.draft_runner,
            self.cuda_graph_runner_for_draft_extend,
        )

    if self.plan_stream:
        torch.get_device_module(self.device).current_stream().wait_stream(
            self.plan_stream
        )

    # The graph only fills draft slots. Keep the scheduler-facing
    # `batch_result.accept_lens` untouched, but make the graph's
    # per-sequence counts match the fixed draft-token layout.
    forward_batch.spec_info.num_correct_drafts = graph_accept_lens - 1
    forward_batch.spec_info.num_accept_tokens = graph_accept_lens

    can_cuda_graph = (
        self.cuda_graph_runner_for_draft_extend
        and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
    )
    if can_cuda_graph:
        draft_logits_output = self.cuda_graph_runner_for_draft_extend.replay(
            forward_batch
        )
    else:
        draft_logits_output = self.draft_runner.forward(
            forward_batch, skip_attn_backend_init=True
        ).logits_output

    output_len = int(draft_logits_output.next_token_logits.shape[0])
    # Because graph_accept_lens is clamped to [1, num_draft_tokens], every
    # entry of select_index is < num_seqs * num_draft_tokens. When the output
    # already has that many rows the index is provably in range, so the
    # device-to-host sync needed to read the exact maximum is skipped.
    if output_len < num_seqs * num_draft_tokens:
        max_index = (
            int(select_index.max().detach().cpu()) if select_index.numel() else -1
        )
        if max_index >= output_len and can_cuda_graph:
            # The graph output is too short for this index; retry eagerly.
            draft_logits_output = self.draft_runner.forward(
                forward_batch, skip_attn_backend_init=True
            ).logits_output
            can_cuda_graph = False
            output_len = int(draft_logits_output.next_token_logits.shape[0])
        if max_index >= output_len:
            raise RuntimeError(
                "DSV4 DRAFT_EXTEND_V2 output/index layout mismatch: "
                f"max_index={max_index}, output_len={output_len}, "
                f"batch={num_seqs}, "
                f"num_draft_tokens={num_draft_tokens}, "
                f"can_cuda_graph={bool(can_cuda_graph)}"
            )

    selected_logits = draft_logits_output.next_token_logits.index_select(
        0, select_index
    )
    selected_hidden_states = draft_logits_output.hidden_states
    if draft_logits_output.hidden_states is not None:
        selected_hidden_states = draft_logits_output.hidden_states.index_select(
            0, select_index
        )

    probs = torch.softmax(selected_logits, dim=-1)
    ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)

    next_draft_input = batch_result.next_draft_input
    next_draft_input.topk_p = ret_topk_p
    next_draft_input.topk_index = ret_topk_index
    next_draft_input.hidden_states = selected_hidden_states
    return None
EagleDraftWorker._atom_dsv4_draft_extend_accept_patched = True + + if not getattr(EagleDraftWorker, "_atom_dsv4_init_cuda_graphs_patched", False): + original_init_cuda_graphs = EagleDraftWorker.init_cuda_graphs + + def init_cuda_graphs(self): + ret = original_init_cuda_graphs(self) + try: + if _env_flag( + "ATOM_SGLANG_V4_DISABLE_DRAFT_CG" + ) and _is_dsv4_nextn_runner(getattr(self, "draft_runner", None)): + self.cuda_graph_runner = None + if ( + self.cuda_graph_runner_for_draft_extend is None + and _is_dsv4_nextn_runner(getattr(self, "draft_runner", None)) + and not self.server_args.disable_cuda_graph + and _draft_extend_graph_enabled(getattr(self, "draft_runner", None)) + and self.draft_extend_attn_backend is not None + ): + seq_len_fill = max( + 1024, + int( + getattr(self.server_args, "speculative_num_draft_tokens", 1) + or 1 + ), + ) + for backend in ( + getattr( + getattr(self, "draft_runner", None), "attn_backend", None + ), + getattr(self, "draft_extend_attn_backend", None), + ): + if backend is not None and hasattr( + backend, "_cuda_graph_seq_len_fill_value" + ): + backend._cuda_graph_seq_len_fill_value = seq_len_fill + draft_runner = getattr(self, "draft_runner", None) + server_args = getattr(draft_runner, "server_args", None) + original_cuda_graph_bs = ( + list(getattr(server_args, "cuda_graph_bs", [])) + if server_args is not None + else None + ) + try: + if server_args is not None and original_cuda_graph_bs: + server_args.cuda_graph_bs = _safe_spec_graph_bs( + original_cuda_graph_bs, + "ATOM_SGLANG_V4_DRAFT_EXTEND_CG_BS", + ) + self.cuda_graph_runner_for_draft_extend = ( + EAGLEDraftExtendCudaGraphRunner(self) + ) + finally: + if ( + server_args is not None + and original_cuda_graph_bs is not None + ): + server_args.cuda_graph_bs = original_cuda_graph_bs + elif _is_dsv4_nextn_runner(getattr(self, "draft_runner", None)): + self.cuda_graph_runner_for_draft_extend = None + except Exception as exc: + logger.warning( + "Failed to enable DSV4 draft-extend 
cuda graph in ATOM plugin: %s", + exc, + ) + return ret + + EagleDraftWorker.init_cuda_graphs = init_cuda_graphs + EagleDraftWorker._atom_dsv4_init_cuda_graphs_patched = True + + def register_ops_to_sglang(atom_config: Config) -> None: """ Register custom ops to sglang, including attention """ _register_custom_attention_to_sglang() + _patch_sglang_dsv4_draft_backends() + _patch_sglang_dsv4_spec_cuda_graph() def set_attn_cls() -> None: diff --git a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py index 9bc772add7..db775a4ef5 100644 --- a/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py +++ b/atom/plugin/sglang/attention_backend/deepseek_v4_backend.py @@ -16,9 +16,10 @@ class ATOMDeepseekV4BackendForSgl(AttentionBackend): """ needs_cpu_seq_lens = True + _last_atom_v4_graph_metadata = None def __init__(self, model_runner, *args, **kwargs): - del args, kwargs + del args logger.info("Initializing ATOMDeepseekV4BackendForSgl") self.model_runner = model_runner self.device = torch.device(model_runner.device) @@ -26,6 +27,13 @@ def __init__(self, model_runner, *args, **kwargs): self.req_to_token_pool = model_runner.req_to_token_pool self.forward_metadata = None self.atom_v4_graph_metadata = None + self._cuda_graph_seq_len_fill_value = 1 + speculative_num_steps = int(kwargs.pop("speculative_num_steps", 0) or 0) + # SGLang EAGLE multi-step draft code expects decode backends to expose + # one attention backend per draft step. ATOM DSV4 owns the real + # per-layer state in the model/bridge, so all draft steps can share this + # shim instance. 
+ self.attn_backends = [self] * max(1, speculative_num_steps) @staticmethod def get_name() -> str: @@ -37,15 +45,31 @@ def init_forward_metadata(self, forward_batch): def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = False): self.forward_metadata = forward_batch + is_draft_extend = bool( + getattr( + forward_batch.forward_mode, "is_draft_extend", lambda **kwargs: False + )(include_v2=True) + ) + draft_extend_runner = getattr( + self, "_atom_dsv4_draft_extend_graph_runner", None + ) + if ( + is_draft_extend + and draft_extend_runner is not None + and not hasattr(forward_batch, "actual_forward_mode") + ): + forward_batch = self._build_draft_extend_replay_metadata_view( + forward_batch, draft_extend_runner + ) + self.forward_metadata = forward_batch + if not (in_capture or hasattr(forward_batch, "actual_forward_mode")): self.atom_v4_graph_metadata = None return - if not forward_batch.forward_mode.is_decode_or_idle(): - self.atom_v4_graph_metadata = None - return - from atom.plugin.sglang.deepseek_v4_bridge import ( + build_atom_v4_attention_metadata_from_sglang, build_atom_v4_decode_graph_metadata_from_sglang, + build_atom_v4_verify_graph_metadata_from_sglang, ) positions = getattr(forward_batch, "positions", None) @@ -58,13 +82,82 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals return atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) - self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang( - forward_batch, - positions, - proxy_pool=self.token_to_kv_pool, - req_to_token_pool=self.req_to_token_pool, - model=atom_model, + if forward_batch.forward_mode.is_decode_or_idle(): + self.atom_v4_graph_metadata = ( + build_atom_v4_decode_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) + ) + elif forward_batch.forward_mode.is_target_verify() or bool( + 
def _build_draft_extend_replay_metadata_view(self, forward_batch, runner):
    """Fill missing replay fields from EAGLE draft-extend graph buffers.

    SGLang's standalone draft-extend runner builds a lightweight replay
    view that lacks `positions`/`actual_forward_mode` and keeps
    `out_cache_loc` on the raw batch. ATOM's DSV4 graph metadata must
    reference the padded CUDA graph buffers because those are the tensors
    captured by the graph.

    Args:
        forward_batch: the replay batch; only its attributes are read.
        runner: the draft-extend graph runner; ``runner.buffers`` supplies
            the graph tensors and ``runner.num_tokens_per_bs`` is the
            fallback tokens-per-request count.

    Returns:
        ``forward_batch`` unchanged when the runner has no buffers, otherwise
        a new ``SimpleNamespace`` carrying every attribute of
        ``forward_batch`` with the tensor fields replaced by buffer slices.
        A field the buffers do not provide keeps the batch's own value
        instead of being overwritten with ``None``.
    """
    # Local import keeps this method self-contained regardless of the
    # module's top-of-file imports.
    from types import SimpleNamespace

    buffers = getattr(runner, "buffers", None)
    if buffers is None:
        return forward_batch

    bs = int(getattr(forward_batch, "batch_size", 0) or 0)
    spec_info = getattr(forward_batch, "spec_info", None)
    tokens_per_req = int(
        getattr(spec_info, "num_tokens_per_req", None)
        or getattr(runner, "num_tokens_per_bs", 1)
        or 1
    )
    total = max(0, bs * max(1, tokens_per_req))

    def _buffer_or_batch(name, stop):
        # Prefer the graph buffer; fall back to whatever the batch already
        # holds so an absent buffer never erases real data.
        value = getattr(buffers, name, None)
        if value is None:
            return getattr(forward_batch, name, None)
        return value[:stop]

    values = dict(getattr(forward_batch, "__dict__", {}))
    values.update(
        actual_forward_mode=getattr(
            forward_batch, "actual_forward_mode", forward_batch.forward_mode
        ),
        input_ids=_buffer_or_batch("input_ids", total),
        positions=_buffer_or_batch("positions", total),
        req_pool_indices=_buffer_or_batch("req_pool_indices", bs),
        seq_lens=_buffer_or_batch("seq_lens", bs),
        seq_lens_cpu=_buffer_or_batch("seq_lens_cpu", bs),
        out_cache_loc=_buffer_or_batch("out_cache_loc", total),
        spec_info=spec_info,
    )
    return SimpleNamespace(**values)
= int(value) + except (TypeError, ValueError): + return None + return value if value > 0 else None + + tokens_per_req = _positive_int(getattr(spec_info, "num_tokens_per_req", None)) + if tokens_per_req is None: + tokens_per_req = _positive_int(getattr(spec_info, "draft_token_num", None)) + if tokens_per_req is None: + tokens_per_req = _positive_int( + getattr(spec_info, "speculative_num_draft_tokens", None) + ) + if tokens_per_req is None: + tokens_per_req = ( + max(1, int(positions.numel()) // max(1, int(bs))) + if positions is not None + else 1 + ) + tokens_per_req = int(tokens_per_req) + if positions is None: + base = (seq_lens[:bs].to(torch.int64) - tokens_per_req).clamp_min_(0) + offsets = torch.arange( + tokens_per_req, dtype=torch.int64, device=self.device + ) + positions = (base[:, None] + offsets[None, :]).reshape(-1) + elif positions.shape[0] < bs * tokens_per_req: + padded_positions = torch.zeros( + (bs * tokens_per_req,), dtype=torch.int64, device=self.device + ) + padded_positions[: positions.shape[0]].copy_(positions) + positions = padded_positions + if seq_lens_cpu is None: + seq_lens_cpu = seq_lens.detach().cpu() + + if spec_info is None: + spec_info = SimpleNamespace(num_tokens_per_req=tokens_per_req) + elif _positive_int(getattr(spec_info, "num_tokens_per_req", None)) is None: + spec_info_dict = dict(getattr(spec_info, "__dict__", {})) + spec_info_dict.pop("num_tokens_per_req", None) + spec_info = SimpleNamespace( + **spec_info_dict, + num_tokens_per_req=tokens_per_req, + ) + + forward_batch = SimpleNamespace( + forward_mode=forward_mode, + actual_forward_mode=actual_forward_mode or forward_mode, + batch_size=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + spec_info=spec_info, + ) + + from atom.plugin.sglang.deepseek_v4_bridge import ( + build_atom_v4_verify_graph_metadata_from_sglang, + ) + + atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) + 
self.forward_metadata = forward_batch + self.atom_v4_graph_metadata = build_atom_v4_verify_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) + forward_batch.atom_v4_graph_metadata = self.atom_v4_graph_metadata + ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = ( + self.atom_v4_graph_metadata + ) + + def init_forward_metadata_capture_cuda_graph(self, *args, **kwargs): + # New SGLang graph API passes a ForwardBatch. Older call sites pass + # unpacked fields. Support both because speculative draft graph code + # still calls this legacy-named hook directly. + if len(args) == 1 and not kwargs and hasattr(args[0], "forward_mode"): + return self.init_forward_metadata_out_graph(args[0], in_capture=True) + + bs = kwargs.get("bs", args[0] if len(args) > 0 else None) + req_pool_indices = kwargs.get( + "req_pool_indices", args[2] if len(args) > 2 else None + ) + seq_lens = kwargs.get("seq_lens", args[3] if len(args) > 3 else None) + forward_mode = kwargs.get("forward_mode", args[5] if len(args) > 5 else None) + spec_info = kwargs.get("spec_info", args[6] if len(args) > 6 else None) + if forward_mode is not None and ( + forward_mode.is_target_verify() + or bool( + getattr(forward_mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ) + ): + return self._init_verify_cuda_graph_metadata( + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + forward_mode=forward_mode, + spec_info=spec_info, + ) self._init_decode_cuda_graph_metadata( bs=bs, req_pool_indices=req_pool_indices, @@ -133,19 +345,69 @@ def init_forward_metadata_capture_cuda_graph( forward_mode=forward_mode, ) - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens, - forward_mode, - spec_info, - seq_lens_cpu, - ): - del seq_lens_sum, encoder_lens, spec_info + 
def init_forward_metadata_replay_cuda_graph(self, *args, **kwargs): + # Older SGLang draft graph runners call this hook as + # ``init_forward_metadata_replay_cuda_graph(forward_batch, bs)``. + # Newer runners pass unpacked fields. Support both so the ATOM plugin + # owns DSV4 compatibility without patching SGLang source. + if len(args) == 2 and hasattr(args[0], "forward_mode"): + forward_batch, bs = args + forward_mode = forward_batch.forward_mode + if forward_mode.is_target_verify() or bool( + getattr(forward_mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ): + return self._init_verify_cuda_graph_metadata( + bs=bs, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + seq_lens_cpu=getattr(forward_batch, "seq_lens_cpu", None), + forward_mode=forward_mode, + out_cache_loc=getattr(forward_batch, "out_cache_loc", None), + positions=getattr(forward_batch, "positions", None), + spec_info=getattr(forward_batch, "spec_info", None), + actual_forward_mode=forward_mode, + ) + return self._init_decode_cuda_graph_metadata( + bs=bs, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + seq_lens_cpu=getattr(forward_batch, "seq_lens_cpu", None), + forward_mode=forward_mode, + out_cache_loc=getattr(forward_batch, "out_cache_loc", None), + positions=getattr(forward_batch, "positions", None), + actual_forward_mode=forward_mode, + ) + + bs = kwargs.get("bs", args[0] if len(args) > 0 else None) + req_pool_indices = kwargs.get( + "req_pool_indices", args[1] if len(args) > 1 else None + ) + seq_lens = kwargs.get("seq_lens", args[2] if len(args) > 2 else None) + seq_lens_sum = kwargs.get("seq_lens_sum", args[3] if len(args) > 3 else None) + encoder_lens = kwargs.get("encoder_lens", args[4] if len(args) > 4 else None) + forward_mode = kwargs.get("forward_mode", args[5] if len(args) > 5 else None) + spec_info = kwargs.get("spec_info", args[6] if len(args) > 6 else None) + seq_lens_cpu = 
kwargs.get("seq_lens_cpu", args[7] if len(args) > 7 else None) + del seq_lens_sum, encoder_lens replay_batch = getattr(self, "_replay_forward_batch", None) + if forward_mode.is_target_verify() or bool( + getattr(forward_mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ): + return self._init_verify_cuda_graph_metadata( + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + forward_mode=forward_mode, + out_cache_loc=getattr(replay_batch, "out_cache_loc", None), + positions=getattr(replay_batch, "positions", None), + spec_info=spec_info, + actual_forward_mode=getattr(replay_batch, "forward_mode", forward_mode), + ) self._init_decode_cuda_graph_metadata( bs=bs, req_pool_indices=req_pool_indices, @@ -158,11 +420,163 @@ def init_forward_metadata_replay_cuda_graph( ) def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): - del max_bs, max_num_tokens + from sglang.srt.model_executor.forward_batch_info import ForwardMode + + from atom.plugin.sglang.deepseek_v4_bridge import ( + build_atom_v4_decode_graph_metadata_from_sglang, + build_atom_v4_verify_graph_metadata_from_sglang, + ) + + bs = int(max_bs) + tokens_per_req = max(1, int(max_num_tokens) // max(1, bs)) + seq_lens = torch.full( + (bs,), tokens_per_req, dtype=torch.int32, device=self.device + ) + req_pool_indices = torch.arange(bs, dtype=torch.int64, device=self.device) + positions = torch.arange(tokens_per_req, dtype=torch.int64, device=self.device) + positions = positions.repeat(bs) + is_target_verify_graph = bool( + getattr( + getattr(self.model_runner, "spec_algorithm", None), + "is_speculative", + lambda: False, + )() + and not getattr(self.model_runner, "is_draft_worker", False) + ) + is_draft_extend_graph = bool( + getattr( + getattr(self.model_runner, "spec_algorithm", None), + "is_speculative", + lambda: False, + )() + and getattr(self.model_runner, "is_draft_worker", False) + and tokens_per_req > 1 + ) + is_graph_extend = 
is_target_verify_graph or is_draft_extend_graph + forward_mode = ( + ForwardMode.TARGET_VERIFY + if is_target_verify_graph + else ( + ForwardMode.DRAFT_EXTEND_V2 + if is_draft_extend_graph + else ForwardMode.DECODE + ) + ) + self._cuda_graph_seq_len_fill_value = ( + max(tokens_per_req, 1024) if is_graph_extend else 1 + ) + if is_graph_extend: + seq_lens.fill_(self._cuda_graph_seq_len_fill_value) + positions = ( + torch.arange(tokens_per_req, dtype=torch.int64, device=self.device) + + (self._cuda_graph_seq_len_fill_value - tokens_per_req) + ).repeat(bs) + forward_batch = SimpleNamespace( + forward_mode=forward_mode, + actual_forward_mode=forward_mode, + batch_size=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.detach().cpu(), + out_cache_loc=None, + spec_info=SimpleNamespace(num_tokens_per_req=tokens_per_req), + ) + atom_model = getattr(getattr(self.model_runner, "model", None), "model", None) + if is_graph_extend: + self.atom_v4_graph_metadata = ( + build_atom_v4_verify_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) + ) + else: + self.atom_v4_graph_metadata = ( + build_atom_v4_decode_graph_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=self.token_to_kv_pool, + req_to_token_pool=self.req_to_token_pool, + model=atom_model, + ) + ) + ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = ( + self.atom_v4_graph_metadata + ) return None def get_cuda_graph_seq_len_fill_value(self): - return 1 + return int(self._cuda_graph_seq_len_fill_value) + + def get_verify_buffers_to_fill_after_draft(self): + graph_runner = getattr(self.model_runner, "graph_runner", None) + buffers = getattr(graph_runner, "buffers", None) + if buffers is None: + return [None, None] + # Let SGLang's tree builder fill the captured mask buffer in-place. 
+ # Keep positions allocated by the builder: it returns the full provided + # buffer, while replay_prepare expects an exact raw-token-length tensor. + return [getattr(buffers, "custom_mask", None), None] + + def update_verify_buffers_to_fill_after_draft(self, spec_info, cuda_graph_bs): + if cuda_graph_bs is None: + return + graph_runner = getattr(self.model_runner, "graph_runner", None) + buffers = getattr(graph_runner, "buffers", None) + if buffers is None: + return + + tokens_per_req = int( + getattr( + spec_info, + "num_tokens_per_req", + getattr(spec_info, "draft_token_num", 1), + ) + or 1 + ) + total = int(cuda_graph_bs) * tokens_per_req + + positions = getattr(spec_info, "positions", None) + if torch.is_tensor(positions): + copy_n = min(int(positions.numel()), total) + if copy_n: + buffers.positions[:copy_n].copy_(positions[:copy_n]) + if total > copy_n: + buffers.positions[copy_n:total].zero_() + positions = buffers.positions[:total] + else: + positions = buffers.positions[:total] + + custom_mask = getattr(spec_info, "custom_mask", None) + graph_custom_mask = getattr(buffers, "custom_mask", None) + if ( + torch.is_tensor(custom_mask) + and torch.is_tensor(graph_custom_mask) + and custom_mask.data_ptr() != graph_custom_mask.data_ptr() + ): + graph_custom_mask[: custom_mask.numel()].copy_(custom_mask) + + forward_mode = getattr( + getattr(self, "forward_metadata", None), "forward_mode", None + ) + if forward_mode is None: + return + seq_lens_cpu = getattr(buffers, "seq_lens_cpu", None) + self._init_verify_cuda_graph_metadata( + bs=int(cuda_graph_bs), + req_pool_indices=buffers.req_pool_indices[: int(cuda_graph_bs)], + seq_lens=buffers.seq_lens[: int(cuda_graph_bs)], + seq_lens_cpu=( + seq_lens_cpu[: int(cuda_graph_bs)] if seq_lens_cpu is not None else None + ), + forward_mode=forward_mode, + out_cache_loc=buffers.out_cache_loc[:total], + positions=positions, + spec_info=spec_info, + actual_forward_mode=forward_mode, + ) def forward_decode(self, *args, 
**kwargs): raise RuntimeError("ATOM DeepSeek-V4 SGLang bridge should use ATOM attention") diff --git a/atom/plugin/sglang/deepseek_v4_bridge.py b/atom/plugin/sglang/deepseek_v4_bridge.py index 33e3ca4e5c..5b69502a87 100644 --- a/atom/plugin/sglang/deepseek_v4_bridge.py +++ b/atom/plugin/sglang/deepseek_v4_bridge.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import os from types import SimpleNamespace from typing import Any, Optional @@ -8,15 +7,9 @@ import numpy as np import torch -logger = logging.getLogger("atom.plugin.sglang.deepseek_v4_bridge") - ATOM_DEEPSEEK_V4_BLOCK_SIZE = 128 -def _debug_enabled() -> bool: - return os.environ.get("ATOM_SGLANG_V4_DEBUG") == "1" - - def _aligned_index_dim(index_head_dim: int) -> int: # extra 4 bytes for scale, then 16-byte alignment. return ((int(index_head_dim) + 4 + 15) // 16) * 16 @@ -30,6 +23,26 @@ def _layer_counts(compress_ratios) -> tuple[list[int], int, int, int]: return ratios, dense, csa, hca +def _resolve_sglang_spec_steps() -> int: + try: + from sglang.srt.server_args import get_global_server_args + + server_args = get_global_server_args() + value = getattr(server_args, "speculative_num_steps", None) + if value is not None: + return max(0, int(value)) + except Exception: + pass + for name in ("ATOM_SGLANG_V4_MAX_SPEC_STEPS", "MTP_STEPS"): + try: + value = os.environ.get(name) + if value: + return max(0, int(value)) + except Exception: + pass + return 0 + + try: from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool except Exception: # pragma: no cover - SGLang import-time fallback @@ -103,9 +116,12 @@ def __init__( self.num_slots = max(1, self.max_num_reqs) # SGLang's DSV4 allocator is initialized with page_size/swa_page_size=256 # for paged-SWA bookkeeping, but ATOM V4-Pro attention uses a 128-token - # SWA ring/window. Keep the SGLang-facing size above intact and size all - # ATOM cache views + metadata with the native V4 window. + # attention window. 
Native ATOM sizes the SWA ring as + # ``window + max_spec_steps`` so MTP draft slots do not alias the + # verified-token window during speculative rounds. self.window_size = ATOM_DEEPSEEK_V4_BLOCK_SIZE + self.max_spec_steps = _resolve_sglang_spec_steps() + self.swa_cache_size = self.window_size + self.max_spec_steps # In the ATOM bridge layout one original-token block contributes one # HCA entry, so the HCA compressed-entry count is the physical block # count for the unified tails. @@ -121,18 +137,9 @@ def __init__( self.views = self._slice_views() self.is_atom_v4_proxy_pool = True - logger.info( - "Initialized ATOM DeepSeek-V4 SGLang proxy KV pool: " - "slots=%s blocks=%s layers=%s raw=%.2f MiB", - self.num_slots, - self.num_blocks, - len(self.stage_ratios), - total_bytes / (1 << 20), - ) - def _compute_raw_bytes(self) -> int: total = 0 - swa_bytes = self.num_slots * self.window_size * self.head_dim * 2 + swa_bytes = self.num_slots * self.swa_cache_size * self.head_dim * 2 for ratio in self.stage_ratios: total += swa_bytes if ratio == 4: @@ -169,11 +176,11 @@ def _slice_views(self) -> dict[str, list[torch.Tensor]]: for ratio in self.stage_ratios: layer_start = offset - swa_bytes = self.num_slots * self.window_size * self.head_dim * 2 + swa_bytes = self.num_slots * self.swa_cache_size * self.head_dim * 2 swa_view = ( self._take(offset, swa_bytes) .view(torch.bfloat16) - .view(self.num_slots, self.window_size, self.head_dim) + .view(self.num_slots, self.swa_cache_size, self.head_dim) ) offset += swa_bytes swa.append(swa_view) @@ -194,7 +201,7 @@ def _slice_views(self) -> dict[str, list[torch.Tensor]]: self.raw_arena[layer_start:offset] .view(torch.bfloat16) .view( - self.num_slots * self.window_size + self.num_blocks * k, + self.num_slots * self.swa_cache_size + self.num_blocks * k, self.head_dim, ) ) @@ -226,14 +233,14 @@ def _slice_views(self) -> dict[str, list[torch.Tensor]]: self.raw_arena[layer_start:offset] .view(torch.bfloat16) .view( - self.num_slots * 
self.window_size + self.num_blocks * k, + self.num_slots * self.swa_cache_size + self.num_blocks * k, self.head_dim, ) ) hca_main.append(main) else: unified.append( - swa_view.view(self.num_slots * self.window_size, self.head_dim) + swa_view.view(self.num_slots * self.swa_cache_size, self.head_dim) ) return { @@ -358,12 +365,6 @@ def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor) -> None: if allocator is not None: allocator.remap_blocks(block_pairs[:, 1], block_pairs[:, 0]) - if _debug_enabled(): - logger.info( - "ATOM V4 proxy relocated %d KV blocks for SGLang radix cache", - block_pairs.shape[0], - ) - def install_deepseek_v4_proxy_pool_patch() -> None: """Patch SGLang's DSV4 pool constructor before ModelRunner._init_pools(). @@ -381,7 +382,6 @@ def install_deepseek_v4_proxy_pool_patch() -> None: return mixin.DeepSeekV4TokenToKVPool = ATOMDeepSeekV4ProxyKVPool dsv4_pool.ATOMDeepSeekV4ProxyKVPool = ATOMDeepSeekV4ProxyKVPool - logger.info("Installed ATOM DeepSeek-V4 proxy KV pool patch for SGLang") def _bind_compressor_state( @@ -423,6 +423,19 @@ def _bind_compressor_state( compressor.cache_scale = None +def _iter_deepseek_v4_cache_blocks(model): + inner = getattr(model, "model", None) + if inner is None: + return [] + layers = getattr(inner, "layers", None) + if layers is not None: + return list(layers) + mtp = getattr(inner, "mtp", None) + if mtp is not None: + return list(mtp) + return [] + + def bind_deepseek_v4_proxy_cache_views(model, proxy_pool: Any) -> bool: """Bind the SGLang-visible proxy arena to ATOM V4 attention modules. 
@@ -439,7 +452,7 @@ def bind_deepseek_v4_proxy_cache_views(model, proxy_pool: Any) -> bool: csa_i = 0 hca_i = 0 - for local_layer_id, block in enumerate(model.model.layers): + for local_layer_id, block in enumerate(_iter_deepseek_v4_cache_blocks(model)): attn = block.attn ratio = int(attn.compress_ratio) attn.unified_kv = proxy_pool.views["unified"][local_layer_id] @@ -474,10 +487,9 @@ def bind_deepseek_v4_proxy_cache_views(model, proxy_pool: Any) -> bool: model._atom_v4_meta_params = SimpleNamespace( num_slots=proxy_pool.num_slots, window_size=proxy_pool.window_size, - cs=proxy_pool.window_size, + cs=proxy_pool.swa_cache_size, index_topk=int(getattr(model.args, "index_topk", 1024)), ) - logger.info("Bound ATOM DeepSeek-V4 proxy cache views to model") return True @@ -695,6 +707,89 @@ def stage(self, buf, arr_np, n: Optional[int] = None): return buf.copy_to_gpu(n) +class _V4SGLangVerifyGraphBuffers: + """Persistent fixed-address target-verify metadata buffers. + + Target verify is extend-shaped (``bs * draft_tokens`` query tokens), but + CUDA graph replay has the same pointer-stability requirement as decode. + These buffers mirror the eager prefill metadata fields while keeping every + tensor address stable across capture and replay. 
+ """ + + def __init__( + self, + *, + num_slots: int, + max_verify_tokens: int, + window: int, + index_topk: int, + max_committed_hca: int, + max_blocks: int, + device: torch.device, + ) -> None: + from atom.utils import CpuGpuBuffer + + self.device = device + self.num_slots = max(1, int(num_slots)) + self.max_verify_tokens = max(1, int(max_verify_tokens)) + self.window = int(window) + self.index_topk = int(index_topk) + self.max_committed_hca = max(1, int(max_committed_hca)) + self.max_blocks = max(1, int(max_blocks)) + + def i32(*shape): + return CpuGpuBuffer(*shape, dtype=torch.int32, device=device) + + t = self.max_verify_tokens + s = self.num_slots + win = self.window + topk = self.index_topk + hca = self.max_committed_hca + + self.cu_q = i32(s + 1) + self.state_slot = i32(s) + self.n_csa = i32(s) + self.n_hca = i32(s) + self.batch_id = i32(t) + self.block_tables = i32(s, self.max_blocks) + + self.indptr_extend = i32(t + 1) + self.indptr_prefix_swa = i32(t + 1) + self.indptr_prefix_csa = i32(t + 1) + self.indptr_prefix_hca = i32(t + 1) + self.idx_extend = i32(t * max(1, win)) + self.idx_prefix_swa = i32(t * max(1, win)) + self.idx_prefix_csa = i32(t * max(1, win + topk)) + self.idx_prefix_hca = i32(t * max(1, win + hca)) + self.skip_prefix_len_csa = i32(t) + self.chunk_start_per_seq = i32(s) + + self.indexer_cu_committed = i32(s + 1) + self.indexer_seq_base = i32(t) + self.indexer_cu_ends = i32(t) + + self.plan_buffers = { + 4: { + "compress": i32(t, 4), + "write": i32(t * 4, 4), + }, + 128: { + "compress": i32(t, 4), + "write": i32(t * 128, 4), + }, + } + self.verify_compress_cap = {4: t, 128: t} + + def stage(self, buf, arr_np, n: Optional[int] = None): + n = int(arr_np.shape[0]) if n is None else int(n) + assert ( + n <= buf.np.shape[0] + ), f"V4 verify graph buffer too small: need {n}, have {buf.np.shape[0]}" + if n: + buf.np[:n] = arr_np[:n] + return buf.copy_to_gpu(n) + + def _make_decode_graph_compress_plans(extend_lens_cpu, context_lens_cpu, bufs): 
from atom.model_ops.v4_kernels.compress_plan import make_compress_plans @@ -738,6 +833,91 @@ def _get_extend_lens_cpu( ) +def _make_verify_graph_compress_plans(extend_lens_cpu, context_lens_cpu, bufs): + from atom.model_ops.v4_kernels.compress_plan import make_compress_plans + + return make_compress_plans( + np.ascontiguousarray(extend_lens_cpu, dtype=np.int32), + np.ascontiguousarray(context_lens_cpu, dtype=np.int32), + [(4, True), (128, False)], + plan_buffers=bufs.plan_buffers, + decode_capacity_per_ratio=bufs.verify_compress_cap, + ) + + +def _make_verify_graph_compress_plans_from_positions(pos_np, batch_np, bs: int, bufs): + from atom.model_ops.v4_kernels.compress_plan import CompressPlan + + pos_np = np.ascontiguousarray(pos_np, dtype=np.int32) + batch_np = np.ascontiguousarray(batch_np, dtype=np.int32) + total = int(pos_np.shape[0]) + out = {} + if total == 0 or bs == 0: + return _make_verify_graph_compress_plans( + np.zeros(bs, dtype=np.int32), + np.zeros(bs, dtype=np.int32), + bufs, + ) + + chunk_start = np.zeros(bs, dtype=np.int32) + context_after = np.zeros(bs, dtype=np.int32) + for b in range(bs): + mask = batch_np == b + if np.any(mask): + bpos = pos_np[mask] + chunk_start[b] = int(bpos[0]) + context_after[b] = int(bpos.max()) + 1 + ragged_ids = np.arange(total, dtype=np.int32) + + for ratio, overlap in ((4, True), (128, False)): + K = ratio * (2 if overlap else 1) + token_pos_in_chunk = pos_np - chunk_start[batch_np] + window_lens = np.maximum(0, K - np.minimum(token_pos_in_chunk + 1, K)).astype( + np.int32 + ) + plan_rows = np.stack( + [ragged_ids, batch_np, pos_np, window_lens], axis=1 + ).astype(np.int32) + compress_plan = plan_rows[(pos_np + 1) % ratio == 0] + compress_counts = ( + np.bincount(compress_plan[:, 1], minlength=bs).astype(np.int32) + if compress_plan.size + else np.zeros(bs, dtype=np.int32) + ) + cu_compress = np.empty(bs + 1, dtype=np.int32) + cu_compress[0] = 0 + np.cumsum(compress_counts, out=cu_compress[1:]) + + write_starts = 
np.maximum(0, context_after - K).astype(np.int32) + write_plan = plan_rows[pos_np >= write_starts[batch_np]] + n_compress = int(compress_plan.shape[0]) + n_write = int(write_plan.shape[0]) + + cbuf = bufs.plan_buffers[ratio]["compress"] + wbuf = bufs.plan_buffers[ratio]["write"] + slice_cap = int(bufs.verify_compress_cap[ratio]) + assert n_compress <= slice_cap <= cbuf.np.shape[0] + assert n_write <= wbuf.np.shape[0] + + if n_compress: + cbuf.np[:n_compress] = compress_plan + if slice_cap > n_compress: + cbuf.np[n_compress:slice_cap].fill(-1) + if n_write: + wbuf.np[:n_write] = write_plan + wbuf.np[n_write:].fill(-1) + + out[ratio] = CompressPlan( + compress_plan_gpu=cbuf.copy_to_gpu(slice_cap), + write_plan_gpu=wbuf.copy_to_gpu(), + num_compress=n_compress, + num_write=n_write, + cu_compress_cpu=cu_compress, + compress_plan_cpu=compress_plan if n_compress else None, + ) + return out + + def _infer_atom_attn_state(forward_batch) -> Any: """Map SGLang forward mode to the ATOM V4 attention state. 
@@ -772,7 +952,9 @@ def _get_seq_lens_cpu(forward_batch) -> np.ndarray: seq_lens_cpu = getattr(forward_batch, "seq_lens_cpu", None) if seq_lens_cpu is None: seq_lens_cpu = forward_batch.seq_lens.detach().cpu() - return seq_lens_cpu.numpy().astype(np.int32) + if torch.is_tensor(seq_lens_cpu): + seq_lens_cpu = seq_lens_cpu.detach().cpu().numpy() + return np.asarray(seq_lens_cpu, dtype=np.int32) def _build_block_tables( @@ -816,6 +998,7 @@ def build_atom_v4_decode_graph_metadata_from_sglang( ) is_idle = bool(getattr(actual_mode, "is_idle", lambda: False)()) out_cache_loc = getattr(forward_batch, "out_cache_loc", None) + positions_numel = int(positions.numel()) scheduled_bs = ( 0 if is_idle @@ -825,15 +1008,20 @@ def build_atom_v4_decode_graph_metadata_from_sglang( else bs ) ) - total = scheduled_bs + total = max(scheduled_bs, positions_numel) t_pad = bs max_blocks = max(1, proxy_pool.num_blocks) bufs = getattr(proxy_pool, "_atom_v4_decode_graph_buffers", None) - if bufs is None or bufs.num_slots < bs or bufs.max_blocks < max_blocks: + if ( + bufs is None + or bufs.num_slots < bs + or bufs.max_blocks < max_blocks + or bufs.max_decode_tokens < total + ): bufs = proxy_pool._atom_v4_decode_graph_buffers = _V4SGLangDecodeGraphBuffers( num_slots=proxy_pool.num_slots, - max_decode_tokens=max(proxy_pool.num_slots, bs), + max_decode_tokens=max(proxy_pool.num_slots, bs, total), window=proxy_pool.window_size, index_topk=1024, max_committed_hca=max_blocks, @@ -869,17 +1057,23 @@ def build_atom_v4_decode_graph_metadata_from_sglang( ) md.swa_num_slots = proxy_pool.num_slots md.swa_window = proxy_pool.window_size - md.swa_cs = proxy_pool.window_size + md.swa_cs = proxy_pool.swa_cache_size md.index_topk = 1024 - md.swa_pages = proxy_pool.num_slots * proxy_pool.window_size + md.swa_pages = proxy_pool.num_slots * proxy_pool.swa_cache_size if total: - pos_np = (seq_np[:total] - 1).astype(np.int32) - batch_np = np.arange(total, dtype=np.int32) + if positions_numel > scheduled_bs: + 
pos_np = positions[:total].detach().cpu().numpy().astype(np.int32) + repeats = max(1, total // max(1, bs)) + batch_np = np.repeat(np.arange(bs, dtype=np.int64), repeats)[:total] + else: + pos_np = (seq_np[:total] - 1).astype(np.int32) + batch_np = np.arange(total, dtype=np.int64) else: pos_np = np.zeros(0, dtype=np.int32) - batch_np = np.zeros(0, dtype=np.int32) - batch_pad = np.full(t_pad, -1, dtype=np.int32) + batch_np = np.zeros(0, dtype=np.int64) + t_pad = max(t_pad, total) + batch_pad = np.full(t_pad, -1, dtype=np.int64) if total: batch_pad[:total] = batch_np @@ -891,11 +1085,13 @@ def build_atom_v4_decode_graph_metadata_from_sglang( slot_arr = np.zeros(bs, dtype=np.int32) reset_slots: set[int] = set() - if total: - first_blocks = block_tables[:total, 0].detach().cpu().numpy().astype(np.int32) - fresh_mask = pos_np == 0 + if scheduled_bs: + first_blocks = ( + block_tables[:scheduled_bs, 0].detach().cpu().numpy().astype(np.int32) + ) + fresh_mask = seq_np[:scheduled_bs] <= 1 slot_real, reset_slots = allocator.assign(first_blocks, fresh_mask) - slot_arr[:total] = slot_real + slot_arr[:scheduled_bs] = slot_real if reset_slots and model is not None: reset_deepseek_v4_state_slots(model, reset_slots) @@ -910,9 +1106,6 @@ def build_atom_v4_decode_graph_metadata_from_sglang( md.batch_id_per_token = bufs.stage(bufs.batch_id, batch_pad, t_pad) n_csa = (seq_np // 4).astype(np.int32) n_hca = (seq_np // 128).astype(np.int32) - if os.environ.get("ATOM_SGLANG_V4_DISABLE_COMPRESS_READ") == "1": - n_csa = np.zeros_like(n_csa) - n_hca = np.zeros_like(n_hca) md.n_committed_csa_per_seq_cpu = n_csa md.n_committed_hca_per_seq_cpu = n_hca md.n_committed_csa_per_seq = bufs.stage(bufs.n_csa, n_csa, bs) @@ -924,9 +1117,9 @@ def build_atom_v4_decode_graph_metadata_from_sglang( if total: actual_swa = np.minimum(pos_np + 1, win).astype(np.int32) csa_valid = np.minimum( - np.minimum((pos_np + 1) // 4, n_csa[:total]), index_topk + np.minimum((pos_np + 1) // 4, n_csa[batch_np]), index_topk 
).astype(np.int32) - hca_valid = n_hca[:total].astype(np.int32) + hca_valid = n_hca[batch_np].astype(np.int32) else: actual_swa = csa_valid = hca_valid = np.zeros(0, dtype=np.int32) @@ -977,8 +1170,289 @@ def indptr(counts): md.kv_indptr_swa = swa_indptr md.kv_indptr_csa = csa_indptr md.kv_indptr_hca = hca_indptr + cu_committed_cpu = np.concatenate( + [ + np.zeros(1, dtype=np.int32), + np.cumsum(md.n_committed_csa_per_seq_cpu, dtype=np.int32), + ] + ) + cu_committed_cpu[-1] = max(int(cu_committed_cpu[-1]), 1) + cu_committed_gpu = torch.from_numpy(cu_committed_cpu).to( + device=device, dtype=torch.int32 + ) + safe_batch_id = md.batch_id_per_token.clamp_min(0) + seq_base = cu_committed_gpu[safe_batch_id].to(torch.int32) + visible_end = seq_base + torch.minimum( + (positions_gpu.to(torch.int32) + 1) // 4, + md.n_committed_csa_per_seq[safe_batch_id], + ).to(torch.int32) md.indexer_meta = { + "total_committed": int(cu_committed_cpu[-1]), + "cu_committed_gpu": cu_committed_gpu, "n_committed_per_seq_gpu": md.n_committed_csa_per_seq, + "batch_id_per_token_gpu": md.batch_id_per_token, + "seq_base_per_token_gpu": seq_base, + "cu_starts_gpu": seq_base, + "cu_ends_gpu": visible_end, + } + return md + + +def build_atom_v4_verify_graph_metadata_from_sglang( + forward_batch, + positions: torch.Tensor, + *, + proxy_pool: ATOMDeepSeekV4ProxyKVPool, + req_to_token_pool, + model: Any = None, +): + from atom.model_ops.v4_kernels import write_v4_paged_prefill_indices + from atom.utils.forward_context import AttentionMetaData, AttnState + + device = positions.device + bs = int(forward_batch.batch_size) + seq_np = _get_seq_lens_cpu(forward_batch)[:bs] + if seq_np.size < bs: + seq_np = np.pad(seq_np, (0, bs - seq_np.size), constant_values=1).astype( + np.int32 + ) + + positions_numel = int(positions.numel()) + tokens_per_req = getattr( + getattr(forward_batch, "spec_info", None), "num_tokens_per_req", None + ) + if tokens_per_req is None: + tokens_per_req = max(1, positions_numel // 
max(1, bs)) + tokens_per_req = max(1, int(tokens_per_req)) + total = bs * tokens_per_req + if positions_numel < total: + padded_positions = torch.zeros(total, dtype=torch.int64, device=device) + if positions_numel: + padded_positions[:positions_numel].copy_(positions) + positions = padded_positions + else: + positions = positions[:total] + is_draft_extend = bool( + getattr(forward_batch.forward_mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ) + use_replay_input_positions = ( + is_draft_extend + and hasattr(forward_batch, "actual_forward_mode") + and os.environ.get("ATOM_SGLANG_V4_DRAFT_EXTEND_USE_INPUT_POSITIONS", "0") + in ("1", "true", "True", "yes", "on") + ) + if is_draft_extend and not use_replay_input_positions: + # Draft-extend graph capture can be invoked with dummy seq_lens from a + # decode-shaped draft backend. Keep capture metadata structurally valid: + # an extend chunk cannot be longer than the context length it is appended to. + seq_np = np.maximum(seq_np, tokens_per_req).astype(np.int32) + prefix_np = np.maximum(seq_np[:bs].astype(np.int32) - tokens_per_req, 0) + offsets_np = np.arange(tokens_per_req, dtype=np.int32) + pos_np = (prefix_np[:, None] + offsets_np[None, :]).reshape(-1) + else: + # Target-verify and experimental draft-extend replay positions come from + # the tensor copied into the CUDA graph input buffer. Deriving them from + # seq_lens can diverge after padding/acceptance updates. 
+ pos_np = positions[:total].detach().cpu().numpy().astype(np.int32) + buffer_attr = ( + "_atom_v4_draft_extend_graph_buffers" + if is_draft_extend + else "_atom_v4_verify_graph_buffers" + ) + max_blocks = max(1, proxy_pool.num_blocks) + bufs = getattr(proxy_pool, buffer_attr, None) + if ( + bufs is None + or bufs.num_slots < bs + or bufs.max_blocks < max_blocks + or bufs.max_verify_tokens < total + ): + bufs = _V4SGLangVerifyGraphBuffers( + num_slots=proxy_pool.num_slots, + max_verify_tokens=max(proxy_pool.num_slots, total), + window=proxy_pool.window_size, + index_topk=1024, + max_committed_hca=max_blocks, + max_blocks=max_blocks, + device=device, + ) + setattr(proxy_pool, buffer_attr, bufs) + + lens = np.full(bs, tokens_per_req, dtype=np.int32) + q_np = np.zeros(bs + 1, dtype=np.int32) + q_np[1:] = np.cumsum(lens, dtype=np.int32) + batch_np = np.repeat(np.arange(bs, dtype=np.int32), lens) + cu_q = bufs.stage(bufs.cu_q, q_np, bs + 1) + + block_tables_live = _build_block_tables( + req_to_token_pool, + forward_batch.req_pool_indices[:bs], + max_blocks * ATOM_DEEPSEEK_V4_BLOCK_SIZE, + ATOM_DEEPSEEK_V4_BLOCK_SIZE, + ) + bufs.block_tables.gpu[:bs, : block_tables_live.shape[1]].copy_(block_tables_live) + block_tables = bufs.block_tables.gpu[:bs] + + md = AttentionMetaData( + cu_seqlens_q=cu_q, + cu_seqlens_k=cu_q, + max_seqlen_q=tokens_per_req, + max_seqlen_k=int(seq_np.max()) if len(seq_np) else 1, + slot_mapping=getattr(forward_batch, "out_cache_loc", None), + context_lens=forward_batch.seq_lens[:bs], + block_tables=block_tables, + state=AttnState.PREFILL_NATIVE, + ) + md.swa_num_slots = proxy_pool.num_slots + md.swa_window = proxy_pool.window_size + md.swa_cs = proxy_pool.swa_cache_size + md.index_topk = 1024 + md.swa_pages = proxy_pool.num_slots * proxy_pool.swa_cache_size + # Target verify is extend-shaped for attention/compressor state, but the + # indexer needs the fixed-shape decode scorer to be graph-safe. 
+ md.use_decode_indexer_for_verify_graph = True + md.is_dsv4_draft_extend_graph = is_draft_extend + + out_cache_loc = getattr(forward_batch, "out_cache_loc", None) + scheduled_bs = ( + min(bs, int(out_cache_loc.numel()) // tokens_per_req) + if torch.is_tensor(out_cache_loc) + else bs + ) + slot_arr = np.zeros(bs, dtype=np.int32) + reset_slots: set[int] = set() + if is_draft_extend and os.environ.get( + "ATOM_SGLANG_V4_DRAFT_EXTEND_USE_ALLOCATOR_SLOT", "0" + ) not in ("1", "true", "True", "yes", "on"): + # Draft-extend graph replay must not synchronize GPU metadata back to + # CPU. It runs after request slots already exist, so use SGLang's + # req_pool_indices as stable per-request ATOM state slots and update the + # persistent GPU buffer directly. + if scheduled_bs: + bufs.state_slot.gpu[:scheduled_bs].copy_( + forward_batch.req_pool_indices[:scheduled_bs] + ) + slot_arr[:scheduled_bs] = -1 + if bs > scheduled_bs: + bufs.state_slot.gpu[scheduled_bs:bs].zero_() + else: + allocator = getattr(proxy_pool, "_atom_v4_slot_allocator", None) + if allocator is None: + allocator = proxy_pool._atom_v4_slot_allocator = _V4StateSlotAllocator( + proxy_pool.num_slots + ) + if scheduled_bs: + first_blocks = ( + block_tables[:scheduled_bs, 0].detach().cpu().numpy().astype(np.int32) + ) + chunk_start_per_seq = pos_np[q_np[:-1]] + fresh_mask = chunk_start_per_seq[:scheduled_bs] == 0 + slot_real, reset_slots = allocator.assign(first_blocks, fresh_mask) + slot_arr[:scheduled_bs] = slot_real + if reset_slots and model is not None: + reset_deepseek_v4_state_slots(model, reset_slots) + bufs.stage(bufs.state_slot, slot_arr, bs) + md.reset_slots = set() + md.state_slot_mapping_cpu = slot_arr + md.state_slot_mapping = bufs.state_slot.gpu[:bs] + md.batch_id_per_token_cpu = batch_np + md.batch_id_per_token = bufs.stage(bufs.batch_id, batch_np, total) + + n_csa = (seq_np // 4).astype(np.int32) + n_hca = (seq_np // 128).astype(np.int32) + md.n_committed_csa_per_seq_cpu = n_csa + 
md.n_committed_hca_per_seq_cpu = n_hca + md.n_committed_csa_per_seq = bufs.stage(bufs.n_csa, n_csa, bs) + md.n_committed_hca_per_seq = bufs.stage(bufs.n_hca, n_hca, bs) + md.compress_plans = ( + _make_verify_graph_compress_plans(lens, seq_np, bufs) + if is_draft_extend + else _make_verify_graph_compress_plans_from_positions( + pos_np, batch_np, bs, bufs + ) + ) + + win = int(md.swa_window) + cs = int(md.swa_cs) + chunk_start_per_seq = pos_np[q_np[:-1]] + chunk_start_pt = chunk_start_per_seq[batch_np] + token_pos_in_chunk = pos_np - chunk_start_pt + swa_low = np.maximum(pos_np - win + 1, 0) + extend_count = np.minimum(token_pos_in_chunk + 1, win).astype(np.int32) + prefix_swa_count = np.maximum(chunk_start_pt - swa_low, 0).astype(np.int32) + csa_valid_k = np.minimum( + np.minimum((pos_np + 1) // 4, md.n_committed_csa_per_seq_cpu[batch_np]), + int(md.index_topk), + ).astype(np.int32) + hca_count = md.n_committed_hca_per_seq_cpu[batch_np].astype(np.int32) + + ext_indptr_np = _counts_to_indptr(extend_count) + swa_indptr_np = _counts_to_indptr(prefix_swa_count) + csa_indptr_np = _counts_to_indptr(prefix_swa_count + csa_valid_k) + hca_indptr_np = _counts_to_indptr(prefix_swa_count + hca_count) + + ext_indptr = bufs.stage(bufs.indptr_extend, ext_indptr_np, total + 1) + swa_indptr = bufs.stage(bufs.indptr_prefix_swa, swa_indptr_np, total + 1) + csa_indptr = bufs.stage(bufs.indptr_prefix_csa, csa_indptr_np, total + 1) + hca_indptr = bufs.stage(bufs.indptr_prefix_hca, hca_indptr_np, total + 1) + chunk_start_gpu = bufs.stage(bufs.chunk_start_per_seq, chunk_start_per_seq, bs) + skip_prefix_len_csa = bufs.stage( + bufs.skip_prefix_len_csa, prefix_swa_count.astype(np.int32), total + ) + + write_v4_paged_prefill_indices( + positions=positions[:total].to(torch.int32), + bid_per_token=md.batch_id_per_token.to(torch.int64), + chunk_start_per_seq=chunk_start_gpu, + cu_seqlens_q_per_seq=cu_q[:-1], + state_slot_per_seq=md.state_slot_mapping, + 
n_committed_hca_per_seq=md.n_committed_hca_per_seq, + block_tables=block_tables, + extend_indptr=ext_indptr, + prefix_swa_indptr=swa_indptr, + prefix_csa_indptr=csa_indptr, + prefix_hca_indptr=hca_indptr, + extend_indices=bufs.idx_extend.gpu, + prefix_swa_indices=bufs.idx_prefix_swa.gpu, + prefix_csa_indices=bufs.idx_prefix_csa.gpu, + prefix_hca_indices=bufs.idx_prefix_hca.gpu, + T=total, + win=win, + cs=cs, + swa_pages=int(md.swa_pages), + ) + md.kv_indices_extend = bufs.idx_extend.gpu + md.kv_indices_prefix_swa = bufs.idx_prefix_swa.gpu + md.kv_indices_prefix_csa = bufs.idx_prefix_csa.gpu + md.kv_indices_prefix_hca = bufs.idx_prefix_hca.gpu + md.kv_indptr_extend = ext_indptr + md.kv_indptr_prefix_swa = swa_indptr + md.kv_indptr_prefix_csa = csa_indptr + md.kv_indptr_prefix_hca = hca_indptr + md.skip_prefix_len_csa = skip_prefix_len_csa + md.chunk_start_per_seq_cpu = chunk_start_per_seq.astype(np.int32) + + cu_committed_cpu = np.concatenate( + [np.zeros(1, dtype=np.int32), np.cumsum(n_csa, dtype=np.int32)] + ) + cu_committed_cpu[-1] = max(int(cu_committed_cpu[-1]), 1) + cu_committed_gpu = bufs.stage(bufs.indexer_cu_committed, cu_committed_cpu, bs + 1) + seq_base_cpu = cu_committed_cpu[batch_np].astype(np.int32) + visible_end_cpu = seq_base_cpu + np.minimum( + (pos_np + 1) // 4, n_csa[batch_np] + ).astype(np.int32) + seq_base_gpu = bufs.stage(bufs.indexer_seq_base, seq_base_cpu, total) + visible_end_gpu = bufs.stage(bufs.indexer_cu_ends, visible_end_cpu, total) + md.indexer_meta = { + "total_committed": int(cu_committed_cpu[-1]), + "cu_committed_gpu": cu_committed_gpu, + "n_committed_per_seq_gpu": md.n_committed_csa_per_seq, + "batch_id_per_token_gpu": md.batch_id_per_token, + "seq_base_per_token_gpu": seq_base_gpu, + "cu_starts_gpu": seq_base_gpu, + "cu_ends_gpu": visible_end_gpu, } return md @@ -1005,6 +1479,11 @@ def build_atom_v4_attention_metadata_from_sglang( num_reqs = int(forward_batch.batch_size) seq_np = _get_seq_lens_cpu(forward_batch)[:num_reqs] 
is_decode = forward_batch.forward_mode.is_decode_or_idle() + is_draft_extend = bool( + getattr(forward_batch.forward_mode, "is_draft_extend", lambda **kwargs: False)( + include_v2=True + ) + ) if is_decode: lens = np.ones(num_reqs, dtype=np.int32) @@ -1012,14 +1491,52 @@ def build_atom_v4_attention_metadata_from_sglang( batch_np = np.arange(num_reqs, dtype=np.int32) pos_np = positions[:num_reqs].detach().cpu().numpy().astype(np.int32) else: - extend_lens = _get_extend_lens_cpu(forward_batch, positions) - if extend_lens is None: - raise RuntimeError("SGLang DeepSeek-V4 prefill metadata lacks extend lens") + if is_draft_extend: + tokens_per_req = getattr( + getattr(forward_batch, "spec_info", None), + "num_tokens_per_req", + None, + ) + if tokens_per_req is None: + tokens_per_req = max(1, int(positions.numel()) // max(1, num_reqs)) + extend_lens = np.full( + num_reqs, + int(tokens_per_req), + dtype=np.int32, + ) + else: + extend_lens = _get_extend_lens_cpu(forward_batch, positions) + if extend_lens is None: + tokens_per_req = getattr( + getattr(forward_batch, "spec_info", None), + "num_tokens_per_req", + None, + ) + if tokens_per_req is None: + tokens_per_req = max(1, int(positions.numel()) // max(1, num_reqs)) + extend_lens = np.full( + num_reqs, + int(tokens_per_req), + dtype=np.int32, + ) + else: + extend_lens = np.asarray(extend_lens, dtype=np.int32) lens = extend_lens[:num_reqs].astype(np.int32) q_np = np.zeros(num_reqs + 1, dtype=np.int32) q_np[1:] = np.cumsum(lens, dtype=np.int32) batch_np = np.repeat(np.arange(num_reqs, dtype=np.int32), lens) - pos_np = positions[: int(lens.sum())].detach().cpu().numpy().astype(np.int32) + if is_draft_extend: + prefix_np = np.maximum(seq_np[:num_reqs].astype(np.int32) - lens, 0) + pos_np = np.concatenate( + [ + prefix_np[i] + np.arange(int(lens[i]), dtype=np.int32) + for i in range(num_reqs) + ] + ).astype(np.int32) + else: + pos_np = ( + positions[: int(lens.sum())].detach().cpu().numpy().astype(np.int32) + ) total = 
int(lens.sum()) max_seq_len = int(seq_np.max()) if len(seq_np) else 1 @@ -1043,34 +1560,39 @@ def build_atom_v4_attention_metadata_from_sglang( ) md.swa_num_slots = proxy_pool.num_slots md.swa_window = proxy_pool.window_size - md.swa_cs = proxy_pool.window_size + md.swa_cs = proxy_pool.swa_cache_size md.index_topk = 1024 - md.swa_pages = proxy_pool.num_slots * proxy_pool.window_size + md.swa_pages = proxy_pool.num_slots * proxy_pool.swa_cache_size - allocator = getattr(proxy_pool, "_atom_v4_slot_allocator", None) - if allocator is None: - allocator = proxy_pool._atom_v4_slot_allocator = _V4StateSlotAllocator( - proxy_pool.num_slots + if is_draft_extend: + slot_arr = np.full(num_reqs, -1, dtype=np.int32) + md.reset_slots = set() + md.state_slot_mapping_cpu = slot_arr + md.state_slot_mapping = forward_batch.req_pool_indices[:num_reqs].to( + device=device, dtype=torch.int32 + ) + else: + allocator = getattr(proxy_pool, "_atom_v4_slot_allocator", None) + if allocator is None: + allocator = proxy_pool._atom_v4_slot_allocator = _V4StateSlotAllocator( + proxy_pool.num_slots + ) + first_block_ids = block_tables[:num_reqs, 0].detach().cpu().numpy() + fresh_mask = ( + pos_np[q_np[:-1]] == 0 + if total and len(q_np) > 1 + else np.zeros(num_reqs, dtype=bool) + ) + slot_arr, reset_slots = allocator.assign(first_block_ids, fresh_mask) + md.reset_slots = reset_slots + md.state_slot_mapping_cpu = slot_arr + md.state_slot_mapping = torch.from_numpy(slot_arr).to( + device=device, dtype=torch.int32 ) - first_block_ids = block_tables[:num_reqs, 0].detach().cpu().numpy() - fresh_mask = ( - pos_np[q_np[:-1]] == 0 - if total and len(q_np) > 1 - else np.zeros(num_reqs, dtype=bool) - ) - slot_arr, reset_slots = allocator.assign(first_block_ids, fresh_mask) - md.reset_slots = reset_slots - md.state_slot_mapping_cpu = slot_arr - md.state_slot_mapping = torch.from_numpy(slot_arr).to( - device=device, dtype=torch.int32 - ) md.batch_id_per_token_cpu = batch_np md.batch_id_per_token = 
torch.from_numpy(batch_np).to(device=device) md.n_committed_csa_per_seq_cpu = (seq_np // 4).astype(np.int32) md.n_committed_hca_per_seq_cpu = (seq_np // 128).astype(np.int32) - if os.environ.get("ATOM_SGLANG_V4_DISABLE_COMPRESS_READ") == "1": - md.n_committed_csa_per_seq_cpu = np.zeros_like(md.n_committed_csa_per_seq_cpu) - md.n_committed_hca_per_seq_cpu = np.zeros_like(md.n_committed_hca_per_seq_cpu) md.n_committed_csa_per_seq = torch.from_numpy(md.n_committed_csa_per_seq_cpu).to( device=device ) @@ -1084,19 +1606,6 @@ def build_atom_v4_attention_metadata_from_sglang( else: _populate_prefill_indices(md, block_tables, batch_np, pos_np, q_np, device) _populate_indexer(md, batch_np, positions[:total], device) - if _debug_enabled(): - logger.info( - "ATOM SGLang V4 metadata: mode=%s batch=%s total=%s positions=%s " - "lens=%s seq=%s state_slots=%s padded_static_len=%s", - getattr(forward_batch.forward_mode, "name", forward_batch.forward_mode), - num_reqs, - total, - int(positions.numel()), - lens.tolist(), - seq_np.tolist(), - slot_arr.tolist(), - getattr(forward_batch, "padded_static_len", None), - ) return md @@ -1339,7 +1848,7 @@ def reset_deepseek_v4_state_slots(model, slots) -> None: if not slots: return idx = None - for block in getattr(model.model, "layers", []): + for block in _iter_deepseek_v4_cache_blocks(model): attn = getattr(block, "attn", None) swa = getattr(attn, "swa_kv", None) if isinstance(swa, torch.Tensor): diff --git a/atom/plugin/sglang/models/base_model_wrapper.py b/atom/plugin/sglang/models/base_model_wrapper.py index c3251bd650..417b8be71b 100644 --- a/atom/plugin/sglang/models/base_model_wrapper.py +++ b/atom/plugin/sglang/models/base_model_wrapper.py @@ -147,6 +147,9 @@ def get_embed_and_head(self): if hasattr(self.model, "get_embed_and_head"): return self.model.get_embed_and_head() + if self.model_arch == "DeepseekV4ForCausalLM": + return self.model.model.embed.weight, self.model.model.head.weight + embed_owner = ( self.model.model if 
hasattr(self.model, "model") @@ -269,20 +272,13 @@ def forward( hidden_states = runtime.trim_output(hidden_states) if self.pp_group.is_last_rank: - if self.model_arch == "DeepseekV4ForCausalLM" and not getattr( - forward_batch, "return_logprob", False - ): - if forward_batch.forward_mode.is_decode_or_idle(): - pruned_states = hidden_states - elif forward_batch.forward_mode.is_extend(): - last_index = ( - torch.cumsum(forward_batch.extend_seq_lens, dim=0) - 1 - ) - pruned_states = hidden_states[last_index] - else: - pruned_states = hidden_states - return LogitsProcessorOutput( - next_token_logits=self.model.compute_logits(pruned_states) + if self.model_arch == "DeepseekV4ForCausalLM": + return self.logits_processor( + input_ids, + hidden_states, + self.logits_head, + forward_batch, + hidden_states_before_norm=hidden_states, ) return self.logits_processor( input_ids, diff --git a/atom/plugin/sglang/models/deepseek_v4_attention.py b/atom/plugin/sglang/models/deepseek_v4_attention.py index 2c28c6aa62..772ffa5365 100644 --- a/atom/plugin/sglang/models/deepseek_v4_attention.py +++ b/atom/plugin/sglang/models/deepseek_v4_attention.py @@ -5,12 +5,132 @@ from __future__ import annotations +import contextvars +import copy import types -import os import torch from torch import nn +_draft_extend_fused_swa_ctx = contextvars.ContextVar( + "atom_sglang_dsv4_draft_extend_fused_swa_ctx", + default=None, +) + + +def _install_draft_extend_fused_swa_patch() -> None: + """Patch ATOM DSV4 symbols only while SGLang graph integration needs them.""" + + import atom.models.deepseek_v4 as dsv4 + + if getattr(dsv4, "_atom_sglang_draft_extend_fused_swa_patched", False): + return + + original_qk_norm_rope_maybe_quant = dsv4.qk_norm_rope_maybe_quant + original_swa_write = dsv4.swa_write + original_indexer_score_topk = dsv4.Indexer.indexer_score_topk + original_score_topk_decode = dsv4.Indexer._score_topk_decode + + def qk_norm_rope_maybe_quant(*args, **kwargs): + ctx = 
_draft_extend_fused_swa_ctx.get() + if ctx is not None and kwargs.get("swa_kv") is None: + attn = ctx["attn"] + attn_md = ctx["attn_md"] + cache_size = int(attn.swa_kv.shape[1]) + kwargs.update( + swa_kv=attn.swa_kv, + state_slot_mapping=attn_md.state_slot_mapping, + batch_id_per_token=attn_md.batch_id_per_token, + swa_cu_seqlens_q=attn_md.cu_seqlens_q, + swa_cache_size=cache_size, + swa_write_per_batch=min(int(attn_md.max_seqlen_q), cache_size), + ) + return original_qk_norm_rope_maybe_quant(*args, **kwargs) + + def swa_write(*args, **kwargs): + if _draft_extend_fused_swa_ctx.get() is not None: + return None + return original_swa_write(*args, **kwargs) + + def indexer_score_topk(self, q_fp8, weights, topk): + fc = dsv4.get_forward_context() + if bool( + getattr(fc.attn_metadata, "use_decode_indexer_for_verify_graph", False) + ): + indexer_meta = fc.attn_metadata.indexer_meta + block_tables = fc.attn_metadata.block_tables + return self._score_topk_decode( + q_fp8, weights, block_tables, indexer_meta, topk + ) + return original_indexer_score_topk(self, q_fp8, weights, topk) + + def _score_topk_decode(self, q_fp8, weights, block_tables, indexer_meta, topk): + fc = dsv4.get_forward_context() + if not bool( + getattr(fc.attn_metadata, "use_decode_indexer_for_verify_graph", False) + ): + return original_score_topk_decode( + self, q_fp8, weights, block_tables, indexer_meta, topk + ) + + total_tokens = q_fp8.size(0) + n_committed_per_seq_gpu = indexer_meta["n_committed_per_seq_gpu"] + next_n = max(1, int(fc.attn_metadata.max_seqlen_q)) + bs = total_tokens // next_n + q_4d = q_fp8.view(bs, next_n, self.n_heads, self.head_dim) + kv_cache_4d = self.kv_cache.unsqueeze(-2) + logits = torch.empty( + total_tokens, + self._max_model_len_idx, + dtype=torch.float32, + device=q_fp8.device, + ) + dsv4.deepgemm_fp8_paged_mqa_logits( + q_4d, + kv_cache_4d, + weights, + logits, + n_committed_per_seq_gpu, + block_tables, + self._max_model_len_idx, + KVBlockSize=self.kv_cache.size(1), + 
Preshuffle=True, + ) + + cu_starts = indexer_meta.get("cu_starts_gpu") + cu_ends = indexer_meta.get("cu_ends_gpu") + if cu_starts is None or cu_ends is None: + return original_score_topk_decode( + self, q_fp8, weights, block_tables, indexer_meta, topk + ) + + topk_local = torch.empty( + total_tokens, + self.index_topk, + dtype=torch.int32, + device=q_fp8.device, + ) + local_starts = torch.zeros_like(cu_starts) + local_ends = (cu_ends - cu_starts).clamp_min_(0) + dsv4.top_k_per_row_prefill( + logits, + local_starts, + local_ends, + topk_local, + None, + total_tokens, + logits.stride(0), + logits.stride(1), + k=topk, + ) + return topk_local + + dsv4.qk_norm_rope_maybe_quant = qk_norm_rope_maybe_quant + dsv4.swa_write = swa_write + dsv4.Indexer.indexer_score_topk = indexer_score_topk + dsv4.Indexer._score_topk_decode = _score_topk_decode + dsv4._atom_sglang_draft_extend_fused_swa_patched = True + def patch_deepseek_v4_attention_for_sglang(attn: nn.Module) -> None: """Patch ATOM V4 attention for SGLang's padded prefill execution. 
@@ -23,6 +143,7 @@ def patch_deepseek_v4_attention_for_sglang(attn: nn.Module) -> None: if hasattr(attn, "_sglang_v4_forward_impl"): return + _install_draft_extend_fused_swa_patch() original_forward_impl = attn.forward_impl attn._sglang_v4_forward_impl = original_forward_impl @@ -34,25 +155,95 @@ def _forward_impl(self, x: torch.Tensor, positions: torch.Tensor) -> torch.Tenso return self._sglang_v4_forward_impl(x, positions) attn_md = fc.attn_metadata + is_draft_extend_graph = bool( + getattr(attn_md, "is_dsv4_draft_extend_graph", False) + ) + + def call_original( + x_arg: torch.Tensor, positions_arg: torch.Tensor + ) -> torch.Tensor: + if not is_draft_extend_graph: + return self._sglang_v4_forward_impl(x_arg, positions_arg) + token = _draft_extend_fused_swa_ctx.set( + {"attn": self, "attn_md": fc.attn_metadata} + ) + try: + return self._sglang_v4_forward_impl(x_arg, positions_arg) + finally: + _draft_extend_fused_swa_ctx.reset(token) + if attn_md is not None and attn_md.state is not AttnState.DECODE: batch_id_per_token = getattr(attn_md, "batch_id_per_token", None) - num_real = ( - int(batch_id_per_token.shape[0]) - if torch.is_tensor(batch_id_per_token) - else x.shape[0] + is_verify_graph = bool( + getattr(attn_md, "use_decode_indexer_for_verify_graph", False) ) + indptr = getattr(attn_md, "kv_indptr_extend", None) + if torch.is_tensor(indptr) and indptr.dim() > 0: + # Avoid GPU->CPU reads such as cu_seqlens_q[-1].item() under + # CUDA graph capture. Prefill/target-verify metadata carries + # true per-token extend indptrs; its shape is the safest source + # of real token count when SGLang presents padded graph tensors. 
+ num_real = int(indptr.shape[0]) - 1 + elif is_verify_graph: + state_slots = getattr(attn_md, "state_slot_mapping", None) + num_reqs = ( + int(state_slots.shape[0]) + if torch.is_tensor(state_slots) + else int(getattr(fc.context, "batch_size", 1)) + ) + num_real = int(getattr(attn_md, "max_seqlen_q", 1)) * num_reqs + else: + num_real = ( + int(batch_id_per_token.shape[0]) + if torch.is_tensor(batch_id_per_token) + else x.shape[0] + ) if 0 <= num_real < x.shape[0]: - if os.environ.get("ATOM_SGLANG_V4_DEBUG") == "1": - import logging - - logging.getLogger("atom.plugin.sglang.deepseek_v4_attention").info( - "Slice padded V4 prefill attention: layer=%s real=%s padded=%s", - getattr(self, "layer_id", None), - num_real, - x.shape[0], - ) - out = self._sglang_v4_forward_impl(x[:num_real], positions[:num_real]) + sliced_md = copy.copy(attn_md) + + def slice_attr(name: str, n: int) -> None: + value = getattr(sliced_md, name, None) + if torch.is_tensor(value): + setattr(sliced_md, name, value[:n]) + elif value is not None: + try: + setattr(sliced_md, name, value[:n]) + except Exception: + pass + + for name in ( + "batch_id_per_token", + "batch_id_per_token_cpu", + "skip_prefix_len_csa", + ): + slice_attr(name, num_real) + for name in ( + "kv_indptr_extend", + "kv_indptr_prefix_swa", + "kv_indptr_prefix_csa", + "kv_indptr_prefix_hca", + ): + slice_attr(name, num_real + 1) + indexer_meta = getattr(sliced_md, "indexer_meta", None) + if isinstance(indexer_meta, dict): + indexer_meta = dict(indexer_meta) + for key in ( + "batch_id_per_token_gpu", + "seq_base_per_token_gpu", + "cu_starts_gpu", + "cu_ends_gpu", + ): + value = indexer_meta.get(key) + if torch.is_tensor(value): + indexer_meta[key] = value[:num_real] + sliced_md.indexer_meta = indexer_meta + original_md = fc.attn_metadata + fc.attn_metadata = sliced_md + try: + out = call_original(x[:num_real], positions[:num_real]) + finally: + fc.attn_metadata = original_md return torch.nn.functional.pad(out, (0, 0, 0, x.shape[0] - 
num_real)) - return self._sglang_v4_forward_impl(x, positions) + return call_original(x, positions) attn.forward_impl = types.MethodType(_forward_impl, attn) diff --git a/atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py b/atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py new file mode 100644 index 0000000000..666651f7fd --- /dev/null +++ b/atom/plugin/sglang/models/deepseek_v4_nextn_wrapper.py @@ -0,0 +1,322 @@ +"""ATOM DeepSeek-V4 NextN wrapper for SGLang external loading. + +SGLang rewrites a DeepSeek-V4 draft runner to the architecture name +``DeepseekV4ForCausalLMNextN``. This wrapper keeps that public name while +delegating the actual MTP block to ATOM's ``DeepseekV4MTP`` implementation. +""" + +import copy +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.distributed import get_pp_group +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.server_args import get_global_server_args + +from atom.config import QuantizationConfig as AtomQuantizationConfig +from atom.config import SpeculativeConfig +from atom.model_ops.embed_head import VocabParallelEmbedding +from atom.models.deepseek_v4 import DeepseekV4Attention, ParallelHead +from atom.plugin.config import generate_atom_config_for_plugin_mode +from atom.plugin.sglang.runtime import ( + SGLangForwardBatchMetadata, + SGLangPluginRuntime, + plugin_runtime_scope, +) + +logger = logging.getLogger("atom.plugin.sglang.models") + + +def _sync_replaced_weights() -> None: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + +def _replace_weight(module: nn.Module, attr_name: str, weight) -> None: + if hasattr(module, attr_name): + delattr(module, attr_name) + setattr(module, attr_name, weight) + + +def _materialize_dummy_hidden_states( 
+ hidden_states: torch.Tensor, + *, + length: int, +) -> torch.Tensor: + shape = (length, *hidden_states.shape[1:]) + return hidden_states.new_zeros(shape) + + +def _reshape_mtp_hidden_states(hidden_states: torch.Tensor, *, hidden_size: int): + if hidden_states is None: + return None + if hidden_states.dim() == 3: + return hidden_states + if hidden_states.dim() != 2: + raise ValueError( + "DeepSeek-V4 MTP hidden_states must be rank-2 flattened or rank-3 " + f"mHC, got shape={tuple(hidden_states.shape)}" + ) + width = int(hidden_states.shape[-1]) + if width == hidden_size: + return hidden_states.unsqueeze(1) + if width % hidden_size != 0: + raise ValueError( + "DeepSeek-V4 MTP flattened hidden width must be divisible by " + f"hidden_size={hidden_size}, got width={width}" + ) + return hidden_states.view(hidden_states.shape[0], width // hidden_size, hidden_size) + + +def _install_deepseek_v4_mtp_adapters(model: nn.Module) -> None: + from atom.plugin.sglang.models.deepseek_v4_attention import ( + patch_deepseek_v4_attention_for_sglang, + ) + + for module in model.modules(): + if isinstance(module, DeepseekV4Attention): + patch_deepseek_v4_attention_for_sglang(module) + + +class _DeepseekV4MTPLogitsHeadAdapter(nn.Module): + """Expose ``DeepseekV4MTP.compute_logits`` as an SGLang lm_head.""" + + def __init__(self, model: nn.Module) -> None: + super().__init__() + self.model = model + + def set_lora(self, *args, **kwargs) -> None: + return None + + def apply_lora(self, *args, **kwargs) -> None: + return None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.model.compute_logits(hidden_states) + + +class DeepseekV4ForCausalLMNextN(nn.Module): + """SGLang-compatible draft wrapper backed by ATOM's ``DeepseekV4MTP``.""" + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + del prefix + super().__init__() + + logger.info("Initializing ATOM backend for %s", 
self.__class__.__name__) + + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + with plugin_runtime_scope(framework="sglang"): + self.atom_config = generate_atom_config_for_plugin_mode(config) + + server_args = get_global_server_args() + draft_model_path = ( + server_args.speculative_draft_model_path or server_args.model_path + ) + use_standalone_draft = ( + server_args.speculative_draft_model_path is not None + and server_args.speculative_draft_model_path != server_args.model_path + ) + self.use_standalone_draft = use_standalone_draft + self.atom_config.model = draft_model_path + if use_standalone_draft and hasattr(config, "quantization_config"): + self.atom_config.hf_config.quantization_config = copy.deepcopy( + config.quantization_config + ) + SpeculativeConfig.hf_config_override( + self.atom_config.hf_config, model_path=draft_model_path + ) + if use_standalone_draft: + self.atom_config.quant_config = AtomQuantizationConfig( + self.atom_config.hf_config, + self.atom_config.online_quant_config, + ) + + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + from atom.models.deepseek_v4_mtp import DeepseekV4MTP + from atom.plugin.register import ( + init_aiter_dist, + register_ops_to_sglang, + set_attn_cls, + ) + + register_ops_to_sglang(atom_config=self.atom_config) + set_attn_cls() + init_aiter_dist(config=self.atom_config) + + self.model = DeepseekV4MTP(config=self.atom_config) + self.model.atom_config = self.atom_config + _install_deepseek_v4_mtp_adapters(self.model) + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.shared_head = ParallelHead( + config.vocab_size, + config.hidden_size, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + hc_eps=getattr(config, "hc_eps", 1e-6), + ) + self._bind_shared_modules() + self.logits_head = 
_DeepseekV4MTPLogitsHeadAdapter(self.model) + self.logits_processor = LogitsProcessor(config, skip_all_gather=True) + + def _mtp_blocks(self): + return list(self.model.model.mtp) + + def _bind_shared_modules(self) -> None: + for block in self._mtp_blocks(): + block.embed = self.embed_tokens + block.head = self.shared_head + + def get_embed_and_head(self): + return self.embed_tokens.weight, self.shared_head.weight + + def set_embed_and_head(self, embed, head): + self.set_embed(embed) + _replace_weight(self.shared_head, "weight", head) + self._bind_shared_modules() + _sync_replaced_weights() + + def set_embed(self, embed): + _replace_weight(self.embed_tokens, "weight", embed) + self._bind_shared_modules() + _sync_replaced_weights() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + **kwargs, + ): + del input_embeds, kwargs + if forward_batch.spec_info is None: + raise ValueError("DeepSeek-V4 MTP draft forward requires speculative info") + + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + with SGLangPluginRuntime( + atom_config=self.atom_config, + forward_batch=forward_batch, + positions=positions, + input_ids=input_ids, + ) as runtime: + from atom.plugin.sglang.deepseek_v4_bridge import ( + bind_deepseek_v4_proxy_cache_views, + maybe_get_proxy_pool_from_sglang_backend, + reset_deepseek_v4_state_slots, + ) + + proxy_pool, _ = maybe_get_proxy_pool_from_sglang_backend() + if not bind_deepseek_v4_proxy_cache_views(self.model, proxy_pool): + raise RuntimeError( + "DeepSeek-V4 MTP SGLang proxy KV pool is not initialized" + ) + from atom.utils.forward_context import get_forward_context + + reset_slots = getattr( + get_forward_context().attn_metadata, "reset_slots", None + ) + reset_deepseek_v4_state_slots(self.model, reset_slots) + + model_hidden_states = forward_batch.spec_info.hidden_states + if runtime.forward_batch is not 
forward_batch: + model_hidden_states = _materialize_dummy_hidden_states( + model_hidden_states, + length=int(runtime.positions.shape[0]), + ) + elif ( + torch.is_tensor(model_hidden_states) + and model_hidden_states.shape[0] != runtime.input_ids.shape[0] + and bool( + getattr( + runtime.forward_batch.forward_mode, + "is_draft_extend", + lambda **kwargs: False, + )(include_v2=True) + ) + ): + tokens_per_req = int( + getattr( + getattr(runtime.forward_batch, "spec_info", None), + "num_tokens_per_req", + 0, + ) + or 0 + ) + if ( + tokens_per_req > 0 + and model_hidden_states.shape[0] * tokens_per_req + == runtime.input_ids.shape[0] + ): + model_hidden_states = model_hidden_states.repeat_interleave( + tokens_per_req, dim=0 + ) + else: + raise RuntimeError( + "DeepSeek-V4 MTP draft-extend hidden layout mismatch: " + f"hidden={tuple(model_hidden_states.shape)}, " + f"input_tokens={int(runtime.input_ids.shape[0])}, " + f"tokens_per_req={tokens_per_req}" + ) + model_hidden_states = _reshape_mtp_hidden_states( + model_hidden_states, + hidden_size=int(self.config.hidden_size), + ) + + metadata = SGLangForwardBatchMetadata.build(runtime.forward_batch) + with SGLangForwardBatchMetadata.bind(metadata): + hidden_states = self.model( + input_ids=runtime.input_ids, + positions=runtime.positions, + hidden_states=model_hidden_states, + ) + + if self.pp_group.is_last_rank: + hidden_states = runtime.trim_output(hidden_states) + return self.logits_processor( + input_ids, + hidden_states, + self.logits_head, + forward_batch, + hidden_states_before_norm=hidden_states, + ) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + del weights + from atom.model_loader.loader import load_model + + server_args = get_global_server_args() + draft_model_path = ( + server_args.speculative_draft_model_path or server_args.model_path + ) + self.atom_config.model = draft_model_path + with plugin_runtime_scope(framework="sglang", atom_config=self.atom_config): + 
return load_model( + model=self.model, + model_name_or_path=draft_model_path, + hf_config=self.atom_config.hf_config, + load_dummy=self.atom_config.load_dummy, + spec_decode=True, + ) + + +EntryClass = [DeepseekV4ForCausalLMNextN] diff --git a/atom/plugin/sglang/runtime/forward_context.py b/atom/plugin/sglang/runtime/forward_context.py index 998f1746cd..432a62b657 100644 --- a/atom/plugin/sglang/runtime/forward_context.py +++ b/atom/plugin/sglang/runtime/forward_context.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import logging from contextlib import ExitStack from dataclasses import dataclass, field from typing import Any, Optional @@ -12,6 +13,8 @@ from atom.plugin.sglang.runtime.context import bind_current_forward_batch +logger = logging.getLogger("atom.plugin.sglang.runtime.forward_context") + def _is_dummy_forward(forward_batch: ForwardBatch) -> bool: """Return whether an SGLang batch represents an empty/idle dummy run.""" @@ -122,6 +125,95 @@ def _resolve_num_tokens_across_dp( return num_tokens_across_dp +def _slice_v4_graph_metadata_for_capture( + attn_metadata: Any, *, num_tokens: int, bs: int +): + """Narrow reusable V4 graph metadata to this capture bucket. + + The DSV4 fallback metadata is initialized at max graph size. SGLang then + captures smaller buckets (e.g. bs=248, tokens=496), so per-token arrays must + be narrowed before model code reads them. 
+ """ + + if attn_metadata is None: + return None + + md = copy.copy(attn_metadata) + + def _slice_attr(name: str, n: int): + value = getattr(md, name, None) + if torch.is_tensor(value): + setattr(md, name, value[:n]) + elif value is not None: + try: + setattr(md, name, value[:n]) + except Exception: + pass + + for name in ( + "batch_id_per_token", + "batch_id_per_token_cpu", + "slot_mapping", + "kv_indices_swa", + "kv_indices_csa", + "kv_indices_hca", + "kv_indices_extend", + "kv_indices_prefix_swa", + "kv_indices_prefix_csa", + "kv_indices_prefix_hca", + "skip_prefix_len_csa", + ): + _slice_attr(name, num_tokens) + + for name in ( + "kv_indptr_swa", + "kv_indptr_csa", + "kv_indptr_hca", + "kv_indptr_extend", + "kv_indptr_prefix_swa", + "kv_indptr_prefix_csa", + "kv_indptr_prefix_hca", + ): + _slice_attr(name, num_tokens + 1) + + for name in ( + "state_slot_mapping", + "state_slot_mapping_cpu", + "n_committed_csa_per_seq", + "n_committed_csa_per_seq_cpu", + "n_committed_hca_per_seq", + "n_committed_hca_per_seq_cpu", + "context_lens", + ): + _slice_attr(name, bs) + + block_tables = getattr(md, "block_tables", None) + if torch.is_tensor(block_tables): + md.block_tables = block_tables[:bs] + + for name in ("cu_seqlens_q", "cu_seqlens_k"): + _slice_attr(name, bs + 1) + + indexer_meta = getattr(md, "indexer_meta", None) + if isinstance(indexer_meta, dict): + indexer_meta = dict(indexer_meta) + for key in ( + "batch_id_per_token_gpu", + "seq_base_per_token_gpu", + "cu_starts_gpu", + "cu_ends_gpu", + ): + value = indexer_meta.get(key) + if torch.is_tensor(value): + indexer_meta[key] = value[:num_tokens] + value = indexer_meta.get("n_committed_per_seq_gpu") + if torch.is_tensor(value): + indexer_meta["n_committed_per_seq_gpu"] = value[:bs] + md.indexer_meta = indexer_meta + + return md + + def _set_atom_forward_context( atom_config: Any, forward_batch: ForwardBatch, @@ -140,33 +232,65 @@ def _set_atom_forward_context( max_seqlen_q = 1 if forward_mode.is_decode_or_idle() 
else 0 attn_metadata = None try: + attn_metadata = getattr(forward_batch, "atom_v4_graph_metadata", None) from atom.plugin.sglang.deepseek_v4_bridge import ( build_atom_v4_attention_metadata_from_sglang, maybe_get_proxy_pool_from_sglang_backend, ) - try: - from sglang.srt.model_executor.forward_context import get_attn_backend + if attn_metadata is None: + try: + from sglang.srt.model_executor.forward_context import get_attn_backend - backend = get_attn_backend() - attn_metadata = getattr(backend, "atom_v4_graph_metadata", None) - except Exception: - attn_metadata = None + backend = get_attn_backend() + attn_metadata = getattr(backend, "atom_v4_graph_metadata", None) + except Exception: + attn_metadata = None if attn_metadata is None: backend = getattr(forward_batch, "attn_backend", None) attn_metadata = getattr(backend, "atom_v4_graph_metadata", None) + if attn_metadata is None and backend is not None: + backend_forward_batch = getattr(backend, "forward_metadata", None) + attn_metadata = getattr( + backend_forward_batch, "atom_v4_graph_metadata", None + ) + proxy_pool, req_to_token_pool = maybe_get_proxy_pool_from_sglang_backend() + try: + is_capture_batch = bool(torch.cuda.is_current_stream_capturing()) + except Exception: + is_capture_batch = False + if attn_metadata is None and is_capture_batch: + try: + from atom.plugin.sglang.attention_backend.deepseek_v4_backend import ( + ATOMDeepseekV4BackendForSgl, + ) + + attn_metadata = ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata + if attn_metadata is not None: + attn_metadata = _slice_v4_graph_metadata_for_capture( + attn_metadata, + num_tokens=int(positions.shape[0]), + bs=int(forward_batch.batch_size), + ) + except Exception: + attn_metadata = None if attn_metadata is None and getattr( proxy_pool, "is_atom_v4_proxy_pool", False ): - attn_metadata = build_atom_v4_attention_metadata_from_sglang( - forward_batch, - positions, - proxy_pool=proxy_pool, - req_to_token_pool=req_to_token_pool, - ) + if 
is_capture_batch: + raise RuntimeError( + "ATOM DeepSeek-V4 CUDA graph metadata was not initialized before capture" + ) + else: + attn_metadata = build_atom_v4_attention_metadata_from_sglang( + forward_batch, + positions, + proxy_pool=proxy_pool, + req_to_token_pool=req_to_token_pool, + ) except Exception as exc: raise RuntimeError( "Failed to build ATOM DeepSeek-V4 metadata for SGLang"