Skip to content
42 changes: 26 additions & 16 deletions atom/model_engine/prefill_delayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

Mechanism (per scheduler tick):
1. Each DP rank reports its local state via cpu all_gather:
(local_prefillable, watermark_force_allow)
(local_prefillable, local_alignment_ready, watermark_force_allow)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your changes seem to require >=2 bs for TBO to be ready; does this approach has performance improvement? Or whether it will affect old performance.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’ve made this behavior configurable via an environment variable (disabled by default to avoid affecting existing performance).

Below is a Kimi K2.5 performance comparison (conc=128, isl=8k, osl=1k) with ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=0/2, about 25% throughput gain observed.

ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=0

============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  235.88    
Total input tokens:                      3775394   
Total generated tokens:                  473911    
Request throughput (req/s):              2.17      
Output token throughput (tok/s):         2009.10   
Total Token throughput (tok/s):          18014.52  
---------------Time to First Token----------------
Mean TTFT (ms):                          2994.52   
Median TTFT (ms):                        786.19    
P99 TTFT (ms):                           17830.07  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          59.09     
Median TPOT (ms):                        61.84     
P99 TPOT (ms):                           77.20     
---------------Inter-token Latency----------------
Mean ITL (ms):                           59.33     
Median ITL (ms):                         27.85     
P99 ITL (ms):                            626.02    
----------------End-to-end Latency----------------
Mean E2EL (ms):                          57912.88  
Median E2EL (ms):                        58740.95  
P99 E2EL (ms):                           81474.65  
==================================================

ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=2

============ Serving Benchmark Result ============
Successful requests:                     512       
Benchmark duration (s):                  178.86    
Total input tokens:                      3775394   
Total generated tokens:                  473911    
Request throughput (req/s):              2.86      
Output token throughput (tok/s):         2649.66   
Total Token throughput (tok/s):          23758.06  
---------------Time to First Token----------------
Mean TTFT (ms):                          3435.93   
Median TTFT (ms):                        1459.46   
P99 TTFT (ms):                           16975.80  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          43.29     
Median TPOT (ms):                        44.66     
P99 TPOT (ms):                           54.32     
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.39     
Median ITL (ms):                         27.81     
P99 ITL (ms):                            742.68    
----------------End-to-end Latency----------------
Mean E2EL (ms):                          43596.09  
Median E2EL (ms):                        43810.70  
P99 E2EL (ms):                           60233.29  
==================================================

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, that's good news. We will test ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=2 on deepseek v4 and other models if it's indeed effective

2. Compute `prefillable_status` ∈ {all, none, mixed}:
- "all" → every rank has a new prefill ready → allow (8-way aligned)
- "all" → every rank has enough prefill ready → allow (8-way aligned)
- "none" → no rank has any prefill → allow (vacuous)
- "mixed" → only some ranks have prefill ready → DELAY
- "mixed" → at least one rank has prefill, but
not all ranks are alignment-ready → DELAY
3. In "mixed", refuse the prefill (return False from `should_allow_prefill`)
for up to `max_delay_passes` consecutive ticks (default 30, ≈ 255ms at
8.5ms/decode-tick) OR `max_delay_ms` wall-clock (default 5000ms),
Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(
self.max_delay_ms = max_delay_ms
self.token_usage_low_watermark = token_usage_low_watermark

# 3-slot MAX-reduce buffer (gloo-friendly; mirrors the proven
# 4-slot MAX-reduce buffer (gloo-friendly; mirrors the proven
# `_sync_dp_state` all_reduce path in engine_core.py rather than
# relying on all_gather_into_tensor — the stateless gloo group's
# docstring warns broadcast-like ops are unreliable, and the
Expand All @@ -97,13 +98,14 @@ def __init__(
# Encoding:
# slot 0 = local_prefillable (MAX → "any rank prefillable")
# slot 1 = local_force (MAX → "any rank forces allow")
# slot 2 = NOT local_prefillable (MAX → "any rank lacks prefill")
# slot 2 = local_alignment_ready (MAX → "any rank ready")
# slot 3 = NOT local_alignment_ready (MAX → "any rank not ready")
# Then prefillable_status:
# any_prefillable AND any_not_prefillable → "mixed"
# any_prefillable AND NOT any_not_prefillable → "all"
# any_prefillable AND any_not_ready → "mixed"
# any_prefillable AND any_ready AND NOT any_not_ready → "all"
# NOT any_prefillable → "none"
# Single all_reduce, 3 int64s on cpu — negligible overhead.
self._reduce_buf = torch.zeros(3, dtype=torch.int64, device="cpu")
# Single all_reduce, 4 int64s on cpu — negligible overhead.
self._reduce_buf = torch.zeros(4, dtype=torch.int64, device="cpu")

self._delayed_count: int = 0
self._delay_start_ts: float = 0.0
Expand Down Expand Up @@ -132,6 +134,7 @@ def should_allow_prefill(
self,
local_prefillable: bool,
token_usage: float,
local_alignment_ready: Optional[bool] = None,
) -> bool:
"""
Returns True iff this rank is allowed to admit new prefills this tick.
Expand All @@ -142,7 +145,13 @@ def should_allow_prefill(
token_usage: fraction of KV cache blocks currently in use
(used_blocks / total_blocks ∈ [0, 1]). Used by the
low-watermark safety valve.
local_alignment_ready: this rank meets the current alignment
threshold. Defaults to ``local_prefillable`` for the legacy
one-request policy; TBO prefill passes ``>= 2`` here.
"""
if local_alignment_ready is None:
local_alignment_ready = local_prefillable

# Local "force allow" if KV cache is underutilized — don't delay
# when GPU is starving. Only meaningful if this rank actually has
# a prefill to push through (otherwise force_allow is a no-op).
Expand All @@ -154,22 +163,23 @@ def should_allow_prefill(
):
force = True

# Cross-DP MAX-reduce: 3 booleans encoded as int64.
# Cross-DP MAX-reduce: 4 booleans encoded as int64.
self._reduce_buf[0] = 1 if local_prefillable else 0
self._reduce_buf[1] = 1 if force else 0
self._reduce_buf[2] = 0 if local_prefillable else 1
self._reduce_buf[2] = 1 if local_alignment_ready else 0
self._reduce_buf[3] = 0 if local_alignment_ready else 1
torch.distributed.all_reduce(
self._reduce_buf,
op=torch.distributed.ReduceOp.MAX,
group=self.cpu_group,
)
any_prefillable = int(self._reduce_buf[0].item()) > 0
force_max = int(self._reduce_buf[1].item())
any_not_prefillable = int(self._reduce_buf[2].item()) > 0
any_ready = int(self._reduce_buf[2].item()) > 0
any_not_ready = int(self._reduce_buf[3].item()) > 0

# Derive 3-way status: all / none / mixed.
prefillable_max = 1 if any_prefillable else 0
prefillable_min = 0 if any_not_prefillable else 1
all_ready = any_ready and not any_not_ready

# Watermark short-circuit: ANY rank below the watermark forces all
# ranks to allow this tick. Without this the delayer can stall a
Expand All @@ -189,7 +199,7 @@ def should_allow_prefill(
return True

# status = "all" or "none" → no skew, just allow
if prefillable_min == prefillable_max:
if not any_prefillable or all_ready:
self._reset_delay()
self._stat_allow += 1
self._maybe_log()
Expand All @@ -211,7 +221,7 @@ def should_allow_prefill(
f"[PrefillDelayer] DELAY: count={self._delayed_count} "
f"elapsed={elapsed_ms:.1f}ms "
f"any_prefillable={any_prefillable} "
f"any_not_prefillable={any_not_prefillable}"
f"any_ready={any_ready} any_not_ready={any_not_ready}"
)
self._maybe_log()
return False
Expand Down
43 changes: 33 additions & 10 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,9 +507,8 @@ def __init__(self, config: Config):
def set_prefill_delayer(self, delayer) -> None:
self.prefill_delayer = delayer

def _can_admit_head_prefill(self) -> bool:
"""Match SGL's `local_prefillable=True` semantics: report True iff
this rank would *actually* admit a new prefill this tick.
def _count_admittable_head_prefills(self, limit: int) -> int:
"""Count how many head prefills this rank can admit this tick.

Just having `self.waiting` non-empty is too coarse — during a
concurrent-burst workload (e.g. 1k/1k @ high concurrency) every
Expand All @@ -521,10 +520,13 @@ def _can_admit_head_prefill(self) -> bool:

We peek the front of `waiting` (skipping a few unschedulable
entries) and check `can_allocate` + token-budget, mirroring the
same checks the admission while-loop runs below.
same checks the admission while-loop runs below. The count is capped
at ``limit`` so the helper stays cheap for the delayer gate.
"""
if not self.waiting:
return False
if limit <= 0 or not self.waiting:
return 0
count = 0
num_batched_tokens = 0
for i, seq in enumerate(self.waiting):
if i >= 4:
break
Expand All @@ -536,9 +538,26 @@ def _can_admit_head_prefill(self) -> bool:
if num_new_tokens > self.max_num_batched_tokens:
continue
if self.block_manager.can_allocate(seq) < 0:
return False # KV-pressured: definitely cannot prefill
return True
return False
break # KV-pressured: definitely cannot prefill more now.
Comment on lines 538 to +542
if num_batched_tokens + num_new_tokens > self.max_num_batched_tokens:
break
count += 1
if count >= limit:
break
num_batched_tokens += num_new_tokens
return count

def _prefill_delayer_readiness(self) -> tuple[bool, bool]:
"""Return the local presence and alignment bits for PrefillDelayer.

TBO prefill splitting needs at least two local prefill requests.
When TBO is enabled, wait for each DP rank to be able to admit two
requests before reporting "ready"; otherwise keep the legacy one
request threshold.
"""
required = 2 if self.config.enable_tbo else 1
count = self._count_admittable_head_prefills(required)
return count > 0, count >= required

def _kv_usage(self) -> float:
"""Fraction of KV-cache blocks currently in use ∈ [0, 1].
Expand Down Expand Up @@ -700,9 +719,13 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
# ─── Cross-DP prefill alignment (PrefillDelayer) ───────────────
_delayer_allows_prefill = True
if self.prefill_delayer is not None:
_local_prefillable, _local_alignment_ready = (
self._prefill_delayer_readiness()
)
_delayer_allows_prefill = self.prefill_delayer.should_allow_prefill(
local_prefillable=self._can_admit_head_prefill(),
local_prefillable=_local_prefillable,
token_usage=self._kv_usage(),
local_alignment_ready=_local_alignment_ready,
)

if not self.running and not self.waiting:
Expand Down
2 changes: 1 addition & 1 deletion atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ def _forward_decode(
paged_kv_indices = self.sparse_kv_indices_buffer

dp_size = get_dp_group().world_size
use_persistent_mode = not (dp_size > 1)
use_persistent_mode = dp_size <= 8
if envs.ATOM_MLA_PAGE_SIZE > 1:
use_persistent_mode = False

Expand Down
47 changes: 44 additions & 3 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def from_model_config(a_quant_dtype: str | None) -> "MoEActivationQuant":
return MoEActivationQuant.BF16


_TBO_KEEPALIVE: dict[tuple[str, int], tuple[torch.Tensor, ...]] = {}


class FusedMoeWeightScaleSupported(Enum):
"""Supported quantization strategies for MoE weight scales."""

Expand Down Expand Up @@ -231,7 +234,10 @@ def pad_for_all_gather(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
padding_shape[0] = max_batch_size
padded_x = torch.empty(padding_shape, device=x.device, dtype=x.dtype)
padded_x[:original_batch_size, :].copy_(x)
# padded_x[original_batch_size:, :].zero_()
# Padded rows still enter fused-MoE routing/sort/dispatch before being
# sliced away after reduce-scatter; uninitialized NaN/Inf rows can perturb
# expert buckets or shared scratch and corrupt real tokens.
padded_x[original_batch_size:, :].zero_()
return padded_x, original_batch_size


Expand Down Expand Up @@ -2375,6 +2381,17 @@ def __init__(
),
dim=0,
)
# In the DP-attn fallback path (dp>1, no MORI all2all), MoE runs
# after all_gather_with_padding, so the token dim can be dp_size times
# the per-rank max.
moe_max_num_tokens = atom_config.max_num_batched_tokens
if (
self.moe_parallel_config.dp_size > 1
and not self.moe_parallel_config.use_all2all_kernels
and atom_config.enable_dp_attention
):
moe_max_num_tokens *= self.moe_parallel_config.dp_size

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we need moe_max_num_tokens *= self.moe_parallel_config.dp_size here.. In all_gahter and model runner, we have padded, * dp_size here will make BS large and kernel bad perf

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we only increase the size of the preallocated internal buffer in FusedMoE, not the actual batch size used in the forward pass. This internal buffer needs to be large enough to accommodate tokens from all DP ranks, so we multiply by dp_size, similar to what we've already done for the all-gather / reduce-scatter buffers.


if fuse_shared_experts and self.num_fused_shared_experts > 0:
init_aiter_topK_meta_data(
n_routed_experts=self.global_num_experts,
Expand All @@ -2387,7 +2404,7 @@ def __init__(
if is_rocm_aiter_fuse_routed_scaling_factor()
else 1 / self.routed_scaling_factor
),
max_num_tokens=atom_config.max_num_batched_tokens,
max_num_tokens=moe_max_num_tokens,
is_EP=self.use_ep,
)
if fuse_shared_experts:
Expand Down Expand Up @@ -2427,7 +2444,7 @@ def __init__(
moe_parallel_config=self.moe_parallel_config,
in_dtype=atom_config.torch_dtype,
a_quant_dtype=a_quant_dtype,
max_num_tokens=atom_config.max_num_batched_tokens,
max_num_tokens=moe_max_num_tokens,
has_bias=self.has_bias,
# is_act_and_mul=True,
is_lora_enabled=False,
Expand Down Expand Up @@ -3325,6 +3342,26 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
hidden_states, router_logits, self.layer_name
)

def _tbo_keepalive_slot(self) -> int:
try:
from atom.utils.tbo.ubatching import tbo_current_ubatch_id

return tbo_current_ubatch_id()
except Exception:
return 0

def _hold_tbo_keepalive(self, role: str, *tensors: torch.Tensor) -> None:
tensors = tuple(tensor for tensor in tensors if tensor is not None)
if tensors:
# Keep one previous tensor set per ubatch/role alive globally.
# The next same-role hold, often in the next MoE layer, happens
# after this ubatch has waited on the prior comm work, so
# overwriting here is the delayed safe release point.
key = (role, self._tbo_keepalive_slot())
if key in _TBO_KEEPALIVE:
del _TBO_KEEPALIVE[key]
_TBO_KEEPALIVE[key] = tensors

def forward_impl_graph(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
Expand Down Expand Up @@ -3355,6 +3392,7 @@ def forward_impl_graph(
)

tbo_yield_and_switch_from_compute_to_comm()
self._hold_tbo_keepalive("ag_source", hidden_states, router_logits)

(
hidden_states,
Expand All @@ -3367,6 +3405,7 @@ def forward_impl_graph(

if _tbo:
tbo_switch_to_compute_sync()
self._hold_tbo_keepalive("ag_output", hidden_states, router_logits)
Comment on lines 3406 to +3408

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
Expand All @@ -3392,6 +3431,7 @@ def forward_impl_graph(
if use_dp_gather_scatter:
if _tbo:
tbo_yield_and_switch_from_compute_to_comm()
self._hold_tbo_keepalive("rs_source", final_hidden_states)
if dp_eager_mode:
final_hidden_states = reduce_scatterv(
final_hidden_states, sizes, dp_group
Expand All @@ -3402,6 +3442,7 @@ def forward_impl_graph(
)
if _tbo:
tbo_switch_to_compute_sync()
self._hold_tbo_keepalive("rs_output", final_hidden_states)
Comment on lines 3443 to +3445

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.)
Expand Down
8 changes: 7 additions & 1 deletion atom/model_ops/topK.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ def is_rocm_aiter_fusion_shared_expert_enabled_for_quant_config(
quant_config = config.quant_config

dp_size = config.parallel_config.data_parallel_size
if dp_size > 1 and _has_module("mori") and config.enable_dp_attention:
use_mori_all2all = (
dp_size > 1
and _has_module("mori")
and config.enable_dp_attention
and config.enable_expert_parallel
)
if use_mori_all2all:
return False

if quant_config is not None and shared_expert_prefix is not None:
Expand Down
1 change: 1 addition & 0 deletions atom/utils/tbo/ubatch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def _make_ubatch_context(
batch_size=ub_num_reqs,
graph_bs=graph_bs,
is_draft=ctx.context.is_draft,
dp_uniform_decode=ctx.context.dp_uniform_decode,
)

return ForwardContext(
Expand Down