Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions atom/plugin/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,120 @@ def create_dsv4_backend(runner):
return ATOMDeepseekV4BackendForSgl(runner)


def _patch_sglang_dsv4_draft_backends() -> None:
    """Point SGLang's DSV4 speculative draft-backend factories at ATOM.

    ``DraftBackendFactory`` builds DeepSeek-V4 draft backends directly and
    never consults the attention registry. SGLang's native backend asserts a
    native ``DeepSeekV4TokenToKVPool``, whereas ATOM plugin mode uses a proxy
    KV pool, so the factory's decode/prefill creators are replaced with ones
    that return the ATOM shim.

    The patch is applied at most once (tracked by a marker attribute on the
    factory class) and is skipped with a debug log when either import fails.
    """
    try:
        from sglang.srt.speculative.draft_utils import DraftBackendFactory
        from atom.plugin.sglang.attention_backend.deepseek_v4_backend import (
            ATOMDeepseekV4BackendForSgl,
        )
    except Exception as exc:
        logger.debug("Skip patching SGLang DSV4 draft backends: %s", exc)
        return

    factory_cls = DraftBackendFactory
    patched_marker = "_atom_dsv4_draft_backend_patched"

    # Idempotency guard: a second registration must not re-wrap the factory.
    already_patched = getattr(factory_cls, patched_marker, False)
    if already_patched:
        return

    def _create_atom_dsv4_decode_backend(self):
        # Decode needs the speculative shape (top-k and number of draft steps).
        runner = self.draft_model_runner
        decode_options = {
            "topk": self.topk,
            "speculative_num_steps": self.speculative_num_steps,
        }
        return ATOMDeepseekV4BackendForSgl(runner, **decode_options)

    def _create_atom_dsv4_prefill_backend(self):
        runner = self.draft_model_runner
        return ATOMDeepseekV4BackendForSgl(runner, skip_prefill=False)

    setattr(
        factory_cls, "_create_dsv4_decode_backend", _create_atom_dsv4_decode_backend
    )
    setattr(
        factory_cls, "_create_dsv4_prefill_backend", _create_atom_dsv4_prefill_backend
    )
    setattr(factory_cls, patched_marker, True)
    logger.info("Patched SGLang DSV4 speculative draft backends to ATOM")


def _patch_sglang_dsv4_spec_cuda_graph() -> None:
    """Avoid replaying generic SGLang graphs for DSV4 speculative extend modes.

    The target decode graph is still useful and remains enabled. Target verify
    and draft-extend need DSV4-specific per-token metadata; until that metadata
    is fully graph-safe, those forwards run eager to avoid replaying a graph
    captured with decode-shaped metadata.

    Two class-level patches are installed, each guarded by its own marker
    attribute so repeated calls are no-ops for whichever patch is already in
    place:

    * ``CudaGraphRunner.can_run`` returns ``False`` when the model's
      architectures mention ``DeepseekV4`` and the forward mode is
      target-verify or draft-extend; otherwise it defers to the original.
    * ``EagleDraftWorker.init_cuda_graphs`` skips capture when the draft
      model's first architecture is ``DeepseekV4ForCausalLMNextN``.

    If either SGLang import fails, nothing is patched and a debug message is
    logged.
    """

    try:
        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
        from sglang.srt.speculative.eagle_worker_v2 import EagleDraftWorker
    except Exception as exc:
        logger.debug("Skip patching SGLang DSV4 spec cuda graph: %s", exc)
        return

    patched_any = False

    if not getattr(CudaGraphRunner, "_atom_dsv4_spec_can_run_patched", False):
        original_can_run = CudaGraphRunner.can_run

        def _is_dsv4_spec_extend(graph_runner, forward_batch) -> bool:
            """Return True for a DSV4 model in target-verify/draft-extend mode."""
            model_runner = getattr(graph_runner, "model_runner", None)
            hf_config = getattr(
                getattr(model_runner, "model_config", None), "hf_config", None
            )
            arches = getattr(hf_config, "architectures", None) or []
            is_dsv4 = any("DeepseekV4" in str(arch) for arch in arches)
            mode = getattr(forward_batch, "forward_mode", None)
            # Missing predicates are treated as "not in that mode".
            is_spec_extend = bool(
                getattr(mode, "is_target_verify", lambda: False)()
            ) or bool(
                getattr(mode, "is_draft_extend", lambda **kwargs: False)(
                    include_v2=True
                )
            )
            return is_dsv4 and is_spec_extend

        def can_run(self, forward_batch):
            try:
                force_eager = _is_dsv4_spec_extend(self, forward_batch)
            except Exception:
                # Best effort: never let the probe break graph selection, but
                # leave a trace instead of failing silently.
                logger.debug(
                    "DSV4 spec cuda graph probe failed; deferring to SGLang",
                    exc_info=True,
                )
                force_eager = False
            if force_eager:
                return False
            return original_can_run(self, forward_batch)

        CudaGraphRunner.can_run = can_run
        CudaGraphRunner._atom_dsv4_spec_can_run_patched = True
        patched_any = True

    # Checked independently of the CudaGraphRunner marker so that one patch
    # being present can never prevent the other from being installed.
    if not getattr(EagleDraftWorker, "_atom_dsv4_init_cuda_graphs_patched", False):
        original_init_cuda_graphs = EagleDraftWorker.init_cuda_graphs

        def init_cuda_graphs(self):
            try:
                arch = self.draft_runner.model_config.hf_config.architectures[0]
                if arch == "DeepseekV4ForCausalLMNextN":
                    self.cuda_graph_runner = None
                    self.cuda_graph_runner_for_draft_extend = None
                    logger.info("Skip DSV4 draft cuda graph capture in ATOM plugin")
                    return
            except Exception:
                # Best effort: fall back to SGLang's own capture path.
                logger.debug(
                    "DSV4 draft architecture probe failed; using SGLang "
                    "init_cuda_graphs",
                    exc_info=True,
                )
            return original_init_cuda_graphs(self)

        EagleDraftWorker.init_cuda_graphs = init_cuda_graphs
        EagleDraftWorker._atom_dsv4_init_cuda_graphs_patched = True
        patched_any = True

    if patched_any:
        logger.info("Patched SGLang DSV4 speculative cuda graph handling")


def register_ops_to_sglang(atom_config: Config) -> None:
    """Install ATOM's custom ops and DSV4 speculative patches into SGLang.

    Runs three setup steps in order: custom attention registration, the DSV4
    draft-backend factory patch, and the DSV4 speculative CUDA-graph patch.

    Args:
        atom_config: ATOM configuration; not read by this function.
    """
    setup_steps = (
        _register_custom_attention_to_sglang,
        _patch_sglang_dsv4_draft_backends,
        _patch_sglang_dsv4_spec_cuda_graph,
    )
    for run_step in setup_steps:
        run_step()


def set_attn_cls() -> None:
Expand Down
103 changes: 80 additions & 23 deletions atom/plugin/sglang/attention_backend/deepseek_v4_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,23 @@ class ATOMDeepseekV4BackendForSgl(AttentionBackend):
"""

needs_cpu_seq_lens = True
_last_atom_v4_graph_metadata = None

def __init__(self, model_runner, *args, **kwargs):
    """Bind the shim to a model runner and its KV/request pools.

    Args:
        model_runner: Runner providing ``device``, ``token_to_kv_pool`` and
            ``req_to_token_pool``.
        *args: Accepted for call-site compatibility and discarded.
        **kwargs: Only ``speculative_num_steps`` is read; a missing or
            falsy value counts as 0. Any other keyword is ignored.
    """
    # Only positional extras are dropped; kwargs must stay alive because
    # ``speculative_num_steps`` is read from it below.
    del args
    logger.info("Initializing ATOMDeepseekV4BackendForSgl")
    self.model_runner = model_runner
    self.device = torch.device(model_runner.device)
    self.token_to_kv_pool = model_runner.token_to_kv_pool
    self.req_to_token_pool = model_runner.req_to_token_pool
    self.forward_metadata = None
    self.atom_v4_graph_metadata = None
    speculative_num_steps = int(kwargs.pop("speculative_num_steps", 0) or 0)
    # SGLang EAGLE multi-step draft code expects decode backends to expose
    # one attention backend per draft step. ATOM DSV4 owns the real
    # per-layer state in the model/bridge, so all draft steps can share this
    # shim instance.
    self.attn_backends = [self] * max(1, speculative_num_steps)

@staticmethod
def get_name() -> str:
Expand All @@ -40,11 +47,8 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals
if not (in_capture or hasattr(forward_batch, "actual_forward_mode")):
self.atom_v4_graph_metadata = None
return
if not forward_batch.forward_mode.is_decode_or_idle():
self.atom_v4_graph_metadata = None
return

from atom.plugin.sglang.deepseek_v4_bridge import (
build_atom_v4_attention_metadata_from_sglang,
build_atom_v4_decode_graph_metadata_from_sglang,
)

Expand All @@ -58,12 +62,26 @@ def init_forward_metadata_out_graph(self, forward_batch, in_capture: bool = Fals
return

atom_model = getattr(getattr(self.model_runner, "model", None), "model", None)
self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang(
forward_batch,
positions,
proxy_pool=self.token_to_kv_pool,
req_to_token_pool=self.req_to_token_pool,
model=atom_model,
if forward_batch.forward_mode.is_decode_or_idle():
self.atom_v4_graph_metadata = (
build_atom_v4_decode_graph_metadata_from_sglang(
forward_batch,
positions,
proxy_pool=self.token_to_kv_pool,
req_to_token_pool=self.req_to_token_pool,
model=atom_model,
)
)
else:
self.atom_v4_graph_metadata = build_atom_v4_attention_metadata_from_sglang(
forward_batch,
positions,
proxy_pool=self.token_to_kv_pool,
req_to_token_pool=self.req_to_token_pool,
)
forward_batch.atom_v4_graph_metadata = self.atom_v4_graph_metadata
ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = (
self.atom_v4_graph_metadata
)

def _init_decode_cuda_graph_metadata(
Expand Down Expand Up @@ -114,18 +132,24 @@ def _init_decode_cuda_graph_metadata(
req_to_token_pool=self.req_to_token_pool,
model=atom_model,
)
forward_batch.atom_v4_graph_metadata = self.atom_v4_graph_metadata
ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = (
self.atom_v4_graph_metadata
)

def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens,
forward_mode,
spec_info,
):
del num_tokens, encoder_lens, spec_info
def init_forward_metadata_capture_cuda_graph(self, *args, **kwargs):
# New SGLang graph API passes a ForwardBatch. Older call sites pass
# unpacked fields. Support both because speculative draft graph code
# still calls this legacy-named hook directly.
if len(args) == 1 and not kwargs and hasattr(args[0], "forward_mode"):
return self.init_forward_metadata_out_graph(args[0], in_capture=True)

bs = kwargs.get("bs", args[0] if len(args) > 0 else None)
req_pool_indices = kwargs.get(
"req_pool_indices", args[2] if len(args) > 2 else None
)
seq_lens = kwargs.get("seq_lens", args[3] if len(args) > 3 else None)
forward_mode = kwargs.get("forward_mode", args[5] if len(args) > 5 else None)
self._init_decode_cuda_graph_metadata(
bs=bs,
req_pool_indices=req_pool_indices,
Expand Down Expand Up @@ -158,7 +182,40 @@ def init_forward_metadata_replay_cuda_graph(
)

def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Build placeholder decode metadata sized for the largest graph batch.

    A synthetic decode-mode batch of ``max_bs`` requests is assembled, each
    with ``max_num_tokens // max_bs`` tokens (at least one), and handed to
    the ATOM bridge. The result is stored on this instance and mirrored on
    the class-level ``_last_atom_v4_graph_metadata`` attribute.

    Args:
        max_bs: Largest batch size to size the placeholder tensors for.
        max_num_tokens: Largest total token count across that batch.

    Returns:
        None.
    """
    # Imported locally so this method does not depend on a module-level
    # import being present.
    from types import SimpleNamespace

    from sglang.srt.model_executor.forward_batch_info import ForwardMode

    from atom.plugin.sglang.deepseek_v4_bridge import (
        build_atom_v4_decode_graph_metadata_from_sglang,
    )

    # Both arguments are read below, so they must not be deleted up front.
    bs = int(max_bs)
    # Guard the divisor and the quotient so a zero batch size or a token
    # budget smaller than the batch still yields one token per request.
    tokens_per_req = max(1, int(max_num_tokens) // max(1, bs))
    seq_lens = torch.full(
        (bs,), tokens_per_req, dtype=torch.int32, device=self.device
    )
    req_pool_indices = torch.arange(bs, dtype=torch.int64, device=self.device)
    positions = torch.arange(tokens_per_req, dtype=torch.int64, device=self.device)
    positions = positions.repeat(bs)
    # Minimal stand-in for a forward batch; it carries only the fields set
    # here. NOTE(review): presumably these are all the bridge reads for
    # decode metadata; confirm against the bridge implementation.
    forward_batch = SimpleNamespace(
        forward_mode=ForwardMode.DECODE,
        actual_forward_mode=ForwardMode.DECODE,
        batch_size=bs,
        req_pool_indices=req_pool_indices,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens.detach().cpu(),
        out_cache_loc=None,
    )
    atom_model = getattr(getattr(self.model_runner, "model", None), "model", None)
    self.atom_v4_graph_metadata = build_atom_v4_decode_graph_metadata_from_sglang(
        forward_batch,
        positions,
        proxy_pool=self.token_to_kv_pool,
        req_to_token_pool=self.req_to_token_pool,
        model=atom_model,
    )
    ATOMDeepseekV4BackendForSgl._last_atom_v4_graph_metadata = (
        self.atom_v4_graph_metadata
    )
    return None

def get_cuda_graph_seq_len_fill_value(self):
Expand Down
Loading
Loading