diff --git a/atom/model_engine/prefill_delayer.py b/atom/model_engine/prefill_delayer.py index 36a55c808..4ed05f21a 100644 --- a/atom/model_engine/prefill_delayer.py +++ b/atom/model_engine/prefill_delayer.py @@ -7,11 +7,12 @@ Mechanism (per scheduler tick): 1. Each DP rank reports its local state via cpu all_gather: - (local_prefillable, watermark_force_allow) + (local_prefillable, local_alignment_ready, watermark_force_allow) 2. Compute `prefillable_status` ∈ {all, none, mixed}: - - "all" → every rank has a new prefill ready → allow (8-way aligned) + - "all" → every rank has enough prefill ready → allow (8-way aligned) - "none" → no rank has any prefill → allow (vacuous) - - "mixed" → only some ranks have prefill ready → DELAY + - "mixed" → at least one rank has prefill, but + not all ranks are alignment-ready → DELAY 3. In "mixed", refuse the prefill (return False from `should_allow_prefill`) for up to `max_delay_passes` consecutive ticks (default 30, ≈ 255ms at 8.5ms/decode-tick) OR `max_delay_ms` wall-clock (default 5000ms), @@ -87,7 +88,7 @@ def __init__( self.max_delay_ms = max_delay_ms self.token_usage_low_watermark = token_usage_low_watermark - # 3-slot MAX-reduce buffer (gloo-friendly; mirrors the proven + # 4-slot MAX-reduce buffer (gloo-friendly; mirrors the proven # `_sync_dp_state` all_reduce path in engine_core.py rather than # relying on all_gather_into_tensor — the stateless gloo group's # docstring warns broadcast-like ops are unreliable, and the @@ -97,13 +98,14 @@ def __init__( # Encoding: # slot 0 = local_prefillable (MAX → "any rank prefillable") # slot 1 = local_force (MAX → "any rank forces allow") - # slot 2 = NOT local_prefillable (MAX → "any rank lacks prefill") + # slot 2 = local_alignment_ready (MAX → "any rank ready") + # slot 3 = NOT local_alignment_ready (MAX → "any rank not ready") # Then prefillable_status: - # any_prefillable AND any_not_prefillable → "mixed" - # any_prefillable AND NOT any_not_prefillable → "all" + # any_prefillable AND any_not_ready → "mixed" + # any_prefillable AND any_ready AND NOT any_not_ready → "all" # NOT any_prefillable → "none" - # Single all_reduce, 3 int64s on cpu — negligible overhead. - self._reduce_buf = torch.zeros(3, dtype=torch.int64, device="cpu") + # Single all_reduce, 4 int64s on cpu — negligible overhead. + self._reduce_buf = torch.zeros(4, dtype=torch.int64, device="cpu") self._delayed_count: int = 0 self._delay_start_ts: float = 0.0 @@ -132,6 +134,7 @@ def should_allow_prefill( self, local_prefillable: bool, token_usage: float, + local_alignment_ready: Optional[bool] = None, ) -> bool: """ Returns True iff this rank is allowed to admit new prefills this tick. @@ -142,7 +145,13 @@ def should_allow_prefill( token_usage: fraction of KV cache blocks currently in use (used_blocks / total_blocks ∈ [0, 1]). Used by the low-watermark safety valve. + local_alignment_ready: this rank meets the current alignment + threshold. Defaults to ``local_prefillable`` for the legacy + one-request policy; TBO prefill passes ``>= 2`` here. """ + if local_alignment_ready is None: + local_alignment_ready = local_prefillable + # Local "force allow" if KV cache is underutilized — don't delay # when GPU is starving. Only meaningful if this rank actually has # a prefill to push through (otherwise force_allow is a no-op). @@ -154,10 +163,11 @@ def should_allow_prefill( ): force = True - # Cross-DP MAX-reduce: 3 booleans encoded as int64. + # Cross-DP MAX-reduce: 4 booleans encoded as int64. self._reduce_buf[0] = 1 if local_prefillable else 0 self._reduce_buf[1] = 1 if force else 0 - self._reduce_buf[2] = 0 if local_prefillable else 1 + self._reduce_buf[2] = 1 if local_alignment_ready else 0 + self._reduce_buf[3] = 0 if local_alignment_ready else 1 torch.distributed.all_reduce( self._reduce_buf, op=torch.distributed.ReduceOp.MAX, @@ -165,11 +175,11 @@ def should_allow_prefill( ) any_prefillable = int(self._reduce_buf[0].item()) > 0 force_max = int(self._reduce_buf[1].item()) - any_not_prefillable = int(self._reduce_buf[2].item()) > 0 + any_ready = int(self._reduce_buf[2].item()) > 0 + any_not_ready = int(self._reduce_buf[3].item()) > 0 # Derive 3-way status: all / none / mixed. - prefillable_max = 1 if any_prefillable else 0 - prefillable_min = 0 if any_not_prefillable else 1 + all_ready = any_ready and not any_not_ready # Watermark short-circuit: ANY rank below the watermark forces all # ranks to allow this tick. Without this the delayer can stall a @@ -189,7 +199,7 @@ def should_allow_prefill( return True # status = "all" or "none" → no skew, just allow - if prefillable_min == prefillable_max: + if not any_prefillable or all_ready: self._reset_delay() self._stat_allow += 1 self._maybe_log() @@ -211,7 +221,7 @@ def should_allow_prefill( f"[PrefillDelayer] DELAY: count={self._delayed_count} " f"elapsed={elapsed_ms:.1f}ms " f"any_prefillable={any_prefillable} " - f"any_not_prefillable={any_not_prefillable}" + f"any_ready={any_ready} any_not_ready={any_not_ready}" ) self._maybe_log() return False diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 6db875046..2603dc469 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -29,6 +29,7 @@ from atom.model_engine.block_manager import BlockManager from atom.model_engine.request import RequestOutput from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType +from atom.utils import envs logger = logging.getLogger("atom") @@ -507,9 +508,8 @@ def __init__(self, config: Config): def set_prefill_delayer(self, delayer) -> None: self.prefill_delayer = delayer - def _can_admit_head_prefill(self) -> bool: - """Match SGL's `local_prefillable=True` semantics: report True iff - this rank would *actually* admit a new prefill this tick. + def _count_admittable_head_prefills(self, limit: int) -> int: + """Count how many head prefills this rank can admit this tick. Just having `self.waiting` non-empty is too coarse — during a concurrent-burst workload (e.g. 1k/1k @ high concurrency) every @@ -521,10 +521,13 @@ def _can_admit_head_prefill(self) -> bool: We peek the front of `waiting` (skipping a few unschedulable entries) and check `can_allocate` + token-budget, mirroring the - same checks the admission while-loop runs below. + same checks the admission while-loop runs below. The count is capped + at ``limit`` so the helper stays cheap for the delayer gate. """ - if not self.waiting: - return False + if limit <= 0 or not self.waiting: + return 0 + count = 0 + num_batched_tokens = 0 for i, seq in enumerate(self.waiting): if i >= 4: break @@ -536,9 +539,25 @@ def _can_admit_head_prefill(self) -> bool: if num_new_tokens > self.max_num_batched_tokens: continue if self.block_manager.can_allocate(seq) < 0: - return False # KV-pressured: definitely cannot prefill - return True - return False + break # KV-pressured: definitely cannot prefill more now. + if num_batched_tokens + num_new_tokens > self.max_num_batched_tokens: + break + count += 1 + if count >= limit: + break + num_batched_tokens += num_new_tokens + return count + + def _prefill_delayer_readiness(self) -> tuple[bool, bool]: + """Return the local presence and alignment bits for PrefillDelayer. + + ``ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS`` controls how many local + prefill requests this rank must be able to admit before reporting + alignment-ready. A value of 0 disables the local-count threshold. + """ + required = max(envs.ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS, 0) + count = self._count_admittable_head_prefills(max(required, 1)) + return count > 0, count >= required def _kv_usage(self) -> float: """Fraction of KV-cache blocks currently in use ∈ [0, 1]. @@ -700,9 +719,13 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # ─── Cross-DP prefill alignment (PrefillDelayer) ─────────────── _delayer_allows_prefill = True if self.prefill_delayer is not None: + _local_prefillable, _local_alignment_ready = ( + self._prefill_delayer_readiness() + ) _delayer_allows_prefill = self.prefill_delayer.should_allow_prefill( - local_prefillable=self._can_admit_head_prefill(), + local_prefillable=_local_prefillable, token_usage=self._kv_usage(), + local_alignment_ready=_local_alignment_ready, ) if not self.running and not self.waiting: diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index ee98185f4..0ee1c509c 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -1016,7 +1016,7 @@ def _forward_decode( paged_kv_indices = self.sparse_kv_indices_buffer dp_size = get_dp_group().world_size - use_persistent_mode = not (dp_size > 1) + use_persistent_mode = dp_size <= 8 if envs.ATOM_MLA_PAGE_SIZE > 1: use_persistent_mode = False @@ -1171,92 +1171,102 @@ def forward_impl( prefill_q, k_nope, k_rope, kv_cache, attn_metadata ) else: - q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) - - if self.use_seg_mla: - # Seg path: allocate q_out with a padded last dim so each head row - # has a 768-byte stride (required by the gfx1250 decode asm). The - # kernel only writes the first kv_lora_rank + qk_rope_head_dim - # columns; the padding tail is left untouched and never read. - q_out = torch.empty( - ( - q_nope.shape[0], - self.num_heads, - _MLA_Q_OUT_PADDED_DIM, - ), - dtype=attn_metadata.dtype_q, - device=q_nope.device, - ) - else: - q_out = torch.empty( - ( - q_nope.shape[0], - self.num_heads, - self.kv_lora_rank + self.qk_rope_head_dim, - ), - dtype=attn_metadata.dtype_q, - device=q_nope.device, - ) - if kv_cache.numel() > 0: - if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV: - shuffled_cache = self._shuffled_kv_view(kv_cache) - triton_fused_qk_rope_cat_and_cache_mla( - q_nope, - q_rope, - k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank), - k_rope.view(-1, self.num_kv_heads, self.qk_rope_head_dim), - shuffled_cache, - attn_metadata.slot_mapping, - positions, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - self._k_scale, - self.rotary_emb.is_neox_style, - num_decode_toks_for_zeros=0, - apply_scale=True, - q_out=q_out, - shuffled_kv_cache=True, - ) - elif self.use_seg_mla: - kv_cache_seg = self._seg_kv_cache_view(kv_cache) - fused_qk_rope_concat_and_cache_mla_seg( - q_nope, - q_rope, - k_nope, - k_rope, - # Flat seg layout: [num_blocks, page_size*(kv_lora + pe)]. - kv_cache_seg, - q_out, - attn_metadata.slot_mapping, - self._k_scale, - self._q_scale, - positions, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - is_neox=self.rotary_emb.is_neox_style, + use_shuffle_kv = ( + envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV + ) + needs_explicit_q_cache = ( + attn_metadata.max_seqlen_q > 1 or self.use_seg_mla or use_shuffle_kv + ) + if needs_explicit_q_cache: + q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) + + if self.use_seg_mla: + # Seg path: allocate q_out with a padded last dim so each head row + # has a 768-byte stride (required by the gfx1250 decode asm). The + # kernel only writes the first kv_lora_rank + qk_rope_head_dim + # columns; the padding tail is left untouched and never read. + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, + _MLA_Q_OUT_PADDED_DIM, + ), + dtype=attn_metadata.dtype_q, + device=q_nope.device, ) else: - fused_qk_rope_concat_and_cache_mla( - q_nope, - q_rope, - k_nope, - k_rope, - kv_cache.view( - kv_cache.shape[0], - -1, + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim, ), - q_out, - attn_metadata.slot_mapping, - self._k_scale, - self._q_scale, - positions, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - is_neox=self.rotary_emb.is_neox_style, - is_nope_first=True, + dtype=attn_metadata.dtype_q, + device=q_nope.device, ) - # q_out = self.fused_kv_bmm(q, q_scale, k_nope, k_rope, positions, kv_cache, attn_metadata) + if kv_cache.numel() > 0: + if use_shuffle_kv: + shuffled_cache = self._shuffled_kv_view(kv_cache) + triton_fused_qk_rope_cat_and_cache_mla( + q_nope, + q_rope, + k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank), + k_rope.view(-1, self.num_kv_heads, self.qk_rope_head_dim), + shuffled_cache, + attn_metadata.slot_mapping, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + self._k_scale, + self.rotary_emb.is_neox_style, + num_decode_toks_for_zeros=0, + apply_scale=True, + q_out=q_out, + shuffled_kv_cache=True, + ) + elif self.use_seg_mla: + kv_cache_seg = self._seg_kv_cache_view(kv_cache) + fused_qk_rope_concat_and_cache_mla_seg( + q_nope, + q_rope, + k_nope, + k_rope, + # Flat seg layout: [num_blocks, page_size*(kv_lora + pe)]. + kv_cache_seg, + q_out, + attn_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + ) + else: + fused_qk_rope_concat_and_cache_mla( + q_nope, + q_rope, + k_nope, + k_rope, + kv_cache.view( + kv_cache.shape[0], + -1, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + q_out, + attn_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + is_nope_first=True, + ) + elif kv_cache.numel() > 0: + q_out = self.fused_kv_bmm( + q, q_scale, k_nope, k_rope, positions, kv_cache, attn_metadata + ) if context.is_prefill: output = self._forward_prefill_mla(q_out, kv_cache, attn_metadata) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index a92ce8449..cfc3998de 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -80,6 +80,9 @@ def from_model_config(a_quant_dtype: str | None) -> "MoEActivationQuant": return MoEActivationQuant.BF16 +_TBO_KEEPALIVE: dict[tuple[str, int], tuple[torch.Tensor, ...]] = {} + + class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -231,7 +234,10 @@ def pad_for_all_gather(x: torch.Tensor) -> Tuple[torch.Tensor, int]: padding_shape[0] = max_batch_size padded_x = torch.empty(padding_shape, device=x.device, dtype=x.dtype) padded_x[:original_batch_size, :].copy_(x) - # padded_x[original_batch_size:, :].zero_() + # Padded rows still enter fused-MoE routing/sort/dispatch before being + # sliced away after reduce-scatter; uninitialized NaN/Inf rows can perturb + # expert buckets or shared scratch and corrupt real tokens. + padded_x[original_batch_size:, :].zero_() return padded_x, original_batch_size @@ -2375,6 +2381,17 @@ def __init__( ), dim=0, ) + # In the DP-attn fallback path (dp>1, no MORI all2all), MoE runs + # after all_gather_with_padding, so the token dim can be dp_size times + # the per-rank max. + moe_max_num_tokens = atom_config.max_num_batched_tokens + if ( + self.moe_parallel_config.dp_size > 1 + and not self.moe_parallel_config.use_all2all_kernels + and atom_config.enable_dp_attention + ): + moe_max_num_tokens *= self.moe_parallel_config.dp_size + if fuse_shared_experts and self.num_fused_shared_experts > 0: init_aiter_topK_meta_data( n_routed_experts=self.global_num_experts, @@ -2387,7 +2404,7 @@ def __init__( if is_rocm_aiter_fuse_routed_scaling_factor() else 1 / self.routed_scaling_factor ), - max_num_tokens=atom_config.max_num_batched_tokens, + max_num_tokens=moe_max_num_tokens, is_EP=self.use_ep, ) if fuse_shared_experts: @@ -2427,7 +2444,7 @@ def __init__( moe_parallel_config=self.moe_parallel_config, in_dtype=atom_config.torch_dtype, a_quant_dtype=a_quant_dtype, - max_num_tokens=atom_config.max_num_batched_tokens, + max_num_tokens=moe_max_num_tokens, has_bias=self.has_bias, # is_act_and_mul=True, is_lora_enabled=False, @@ -3325,6 +3342,26 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): hidden_states, router_logits, self.layer_name ) + def _tbo_keepalive_slot(self) -> int: + try: + from atom.utils.tbo.ubatching import tbo_current_ubatch_id + + return tbo_current_ubatch_id() + except Exception: + return 0 + + def _hold_tbo_keepalive(self, role: str, *tensors: torch.Tensor) -> None: + tensors = tuple(tensor for tensor in tensors if tensor is not None) + if tensors: + # Keep one previous tensor set per ubatch/role alive globally. + # The next same-role hold, often in the next MoE layer, happens + # after this ubatch has waited on the prior comm work, so + # overwriting here is the delayed safe release point. + key = (role, self._tbo_keepalive_slot()) + if key in _TBO_KEEPALIVE: + del _TBO_KEEPALIVE[key] + _TBO_KEEPALIVE[key] = tensors + def forward_impl_graph( self, hidden_states: torch.Tensor, router_logits: torch.Tensor ): @@ -3355,6 +3392,7 @@ def forward_impl_graph( ) tbo_yield_and_switch_from_compute_to_comm() + self._hold_tbo_keepalive("ag_source", hidden_states, router_logits) ( hidden_states, @@ -3367,6 +3405,7 @@ def forward_impl_graph( if _tbo: tbo_switch_to_compute_sync() + self._hold_tbo_keepalive("ag_output", hidden_states, router_logits) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -3392,6 +3431,7 @@ def forward_impl_graph( if use_dp_gather_scatter: if _tbo: tbo_yield_and_switch_from_compute_to_comm() + self._hold_tbo_keepalive("rs_source", final_hidden_states) if dp_eager_mode: final_hidden_states = reduce_scatterv( final_hidden_states, sizes, dp_group @@ -3402,6 +3442,7 @@ def forward_impl_graph( ) if _tbo: tbo_switch_to_compute_sync() + self._hold_tbo_keepalive("rs_output", final_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) diff --git a/atom/model_ops/topK.py b/atom/model_ops/topK.py index 1966ee003..11b72df4f 100644 --- a/atom/model_ops/topK.py +++ b/atom/model_ops/topK.py @@ -21,7 +21,13 @@ def is_rocm_aiter_fusion_shared_expert_enabled_for_quant_config( quant_config = config.quant_config dp_size = config.parallel_config.data_parallel_size - if dp_size > 1 and _has_module("mori") and config.enable_dp_attention: + use_mori_all2all = ( + dp_size > 1 + and _has_module("mori") + and config.enable_dp_attention + and config.enable_expert_parallel + ) + if use_mori_all2all: return False if quant_config is not None and shared_expert_prefix is not None: diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 3bd2a53ab..4e07a669f 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -230,6 +230,11 @@ if os.getenv("ATOM_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK", "") == "" else float(os.getenv("ATOM_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK")) ), + # Number of local prefill requests required before reporting alignment-ready. + # Default 0 disables this local-count threshold. + "ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS": lambda: int( + os.getenv("ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS", "0") + ), # --- TBO prefill ubatch splitting --- # Split prefill ubatches at the exact token midpoint (vLLM-DBO style), # cutting through a request if needed for perfectly balanced 50/50 ubatches. diff --git a/atom/utils/tbo/ubatch_wrapper.py b/atom/utils/tbo/ubatch_wrapper.py index 3873ac64e..c1d4d16d7 100644 --- a/atom/utils/tbo/ubatch_wrapper.py +++ b/atom/utils/tbo/ubatch_wrapper.py @@ -452,6 +452,7 @@ def _make_ubatch_context( batch_size=ub_num_reqs, graph_bs=graph_bs, is_draft=ctx.context.is_draft, + dp_uniform_decode=ctx.context.dp_uniform_decode, ) return ForwardContext(