From e7307f599744055a1bd435fef8c9a0f8468e70fa Mon Sep 17 00:00:00 2001 From: jpy794 Date: Sat, 30 May 2026 13:36:34 +0000 Subject: [PATCH 01/11] enable persist mla --- atom/model_ops/attention_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index ee98185f4..9061e18d0 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -1016,7 +1016,7 @@ def _forward_decode( paged_kv_indices = self.sparse_kv_indices_buffer dp_size = get_dp_group().world_size - use_persistent_mode = not (dp_size > 1) + use_persistent_mode = dp_size <= 8 if envs.ATOM_MLA_PAGE_SIZE > 1: use_persistent_mode = False From f3d47e1db64fb0c511021ac5bc3bb833977e4cbc Mon Sep 17 00:00:00 2001 From: jpy794 Date: Sat, 16 May 2026 18:10:50 +0000 Subject: [PATCH 02/11] fix dpa fused moe --- atom/model_ops/moe.py | 15 +++++++++++++-- atom/model_ops/topK.py | 9 +++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index a92ce8449..bdb9bbe39 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -2375,6 +2375,17 @@ def __init__( ), dim=0, ) + # In the DP-attn fallback path (dp>1, no MORI all2all), MoE runs + # after all_gather_with_padding, so the token dim can be dp_size times + # the per-rank max. + moe_max_num_tokens = atom_config.max_num_batched_tokens + if ( + self.moe_parallel_config.dp_size > 1 + and not self.moe_parallel_config.use_all2all_kernels + and atom_config.enable_dp_attention + ): + moe_max_num_tokens *= self.moe_parallel_config.dp_size + if fuse_shared_experts and self.num_fused_shared_experts > 0: init_aiter_topK_meta_data( n_routed_experts=self.global_num_experts, @@ -2387,7 +2398,7 @@ def __init__( if is_rocm_aiter_fuse_routed_scaling_factor() else 1 / self.routed_scaling_factor ), - max_num_tokens=atom_config.max_num_batched_tokens, + max_num_tokens=moe_max_num_tokens, is_EP=self.use_ep, ) if fuse_shared_experts: @@ -2427,7 +2438,7 @@ def __init__( moe_parallel_config=self.moe_parallel_config, in_dtype=atom_config.torch_dtype, a_quant_dtype=a_quant_dtype, - max_num_tokens=atom_config.max_num_batched_tokens, + max_num_tokens=moe_max_num_tokens, has_bias=self.has_bias, # is_act_and_mul=True, is_lora_enabled=False, diff --git a/atom/model_ops/topK.py b/atom/model_ops/topK.py index 1966ee003..74cc2ddf0 100644 --- a/atom/model_ops/topK.py +++ b/atom/model_ops/topK.py @@ -62,6 +62,15 @@ def is_rocm_aiter_fusion_shared_expert_enabled_for_quant_config( return False break + dp_size = config.parallel_config.data_parallel_size + use_mori_all2all = ( + dp_size > 1 + and _has_module("mori") + and config.enable_dp_attention + and config.enable_expert_parallel + ) + if use_mori_all2all: + return False return True From e55cdc13b54f63a6b84b134b600620dc61299d2e Mon Sep 17 00:00:00 2001 From: jpy794 Date: Sat, 30 May 2026 21:43:57 +0000 Subject: [PATCH 03/11] fix tbo tenosr live range --- atom/model_ops/moe.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index bdb9bbe39..86a1264ea 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -80,6 +80,9 @@ def from_model_config(a_quant_dtype: str | None) -> "MoEActivationQuant": return MoEActivationQuant.BF16 +_TBO_KEEPALIVE: dict[tuple[str, int], tuple[torch.Tensor, ...]] = {} + + class FusedMoeWeightScaleSupported(Enum): """Supported quantization strategies for MoE weight scales.""" @@ -3336,6 +3339,26 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): hidden_states, router_logits, self.layer_name ) + def _tbo_keepalive_slot(self) -> int: + try: + from atom.utils.tbo.ubatching import tbo_current_ubatch_id + + return tbo_current_ubatch_id() + except Exception: + return 0 + + def _hold_tbo_keepalive(self, role: str, *tensors: torch.Tensor) -> None: + tensors = tuple(tensor for tensor in tensors if tensor is not None) + if tensors: + # Keep one previous tensor set per ubatch/role alive globally. + # The next same-role hold, often in the next MoE layer, happens + # after this ubatch has waited on the prior comm work, so + # overwriting here is the delayed safe release point. + key = (role, self._tbo_keepalive_slot()) + if key in _TBO_KEEPALIVE: + del _TBO_KEEPALIVE[key] + _TBO_KEEPALIVE[key] = tensors + def forward_impl_graph( self, hidden_states: torch.Tensor, router_logits: torch.Tensor ): @@ -3366,6 +3389,7 @@ def forward_impl_graph( ) tbo_yield_and_switch_from_compute_to_comm() + self._hold_tbo_keepalive("ag_source", hidden_states, router_logits) ( hidden_states, @@ -3378,6 +3402,7 @@ def forward_impl_graph( if _tbo: tbo_switch_to_compute_sync() + self._hold_tbo_keepalive("ag_output", hidden_states, router_logits) # Matrix multiply. final_hidden_states = self.quant_method.apply( @@ -3403,6 +3428,7 @@ def forward_impl_graph( if use_dp_gather_scatter: if _tbo: tbo_yield_and_switch_from_compute_to_comm() + self._hold_tbo_keepalive("rs_source", final_hidden_states) if dp_eager_mode: final_hidden_states = reduce_scatterv( final_hidden_states, sizes, dp_group @@ -3413,6 +3439,7 @@ def forward_impl_graph( ) if _tbo: tbo_switch_to_compute_sync() + self._hold_tbo_keepalive("rs_output", final_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs.) From 4d5a86b804fa4acb415e688ccd5e3326c5fda35f Mon Sep 17 00:00:00 2001 From: jpy794 Date: Sun, 31 May 2026 04:28:14 +0000 Subject: [PATCH 04/11] delay 2 batch in tbo --- atom/model_engine/prefill_delayer.py | 42 ++++++++++++++++----------- atom/model_engine/scheduler.py | 43 +++++++++++++++++++++------- 2 files changed, 59 insertions(+), 26 deletions(-) diff --git a/atom/model_engine/prefill_delayer.py b/atom/model_engine/prefill_delayer.py index 36a55c808..4ed05f21a 100644 --- a/atom/model_engine/prefill_delayer.py +++ b/atom/model_engine/prefill_delayer.py @@ -7,11 +7,12 @@ Mechanism (per scheduler tick): 1. Each DP rank reports its local state via cpu all_gather: - (local_prefillable, watermark_force_allow) + (local_prefillable, local_alignment_ready, watermark_force_allow) 2. Compute `prefillable_status` ∈ {all, none, mixed}: - - "all" → every rank has a new prefill ready → allow (8-way aligned) + - "all" → every rank has enough prefill ready → allow (8-way aligned) - "none" → no rank has any prefill → allow (vacuous) - - "mixed" → only some ranks have prefill ready → DELAY + - "mixed" → at least one rank has prefill, but + not all ranks are alignment-ready → DELAY 3. In "mixed", refuse the prefill (return False from `should_allow_prefill`) for up to `max_delay_passes` consecutive ticks (default 30, ≈ 255ms at 8.5ms/decode-tick) OR `max_delay_ms` wall-clock (default 5000ms), @@ -87,7 +88,7 @@ def __init__( self.max_delay_ms = max_delay_ms self.token_usage_low_watermark = token_usage_low_watermark - # 3-slot MAX-reduce buffer (gloo-friendly; mirrors the proven + # 4-slot MAX-reduce buffer (gloo-friendly; mirrors the proven # `_sync_dp_state` all_reduce path in engine_core.py rather than # relying on all_gather_into_tensor — the stateless gloo group's # docstring warns broadcast-like ops are unreliable, and the @@ -97,13 +98,14 @@ def __init__( # Encoding: # slot 0 = local_prefillable (MAX → "any rank prefillable") # slot 1 = local_force (MAX → "any rank forces allow") - # slot 2 = NOT local_prefillable (MAX → "any rank lacks prefill") + # slot 2 = local_alignment_ready (MAX → "any rank ready") + # slot 3 = NOT local_alignment_ready (MAX → "any rank not ready") # Then prefillable_status: - # any_prefillable AND any_not_prefillable → "mixed" - # any_prefillable AND NOT any_not_prefillable → "all" + # any_prefillable AND any_not_ready → "mixed" + # any_prefillable AND any_ready AND NOT any_not_ready → "all" # NOT any_prefillable → "none" - # Single all_reduce, 3 int64s on cpu — negligible overhead. - self._reduce_buf = torch.zeros(3, dtype=torch.int64, device="cpu") + # Single all_reduce, 4 int64s on cpu — negligible overhead. + self._reduce_buf = torch.zeros(4, dtype=torch.int64, device="cpu") self._delayed_count: int = 0 self._delay_start_ts: float = 0.0 @@ -132,6 +134,7 @@ def should_allow_prefill( self, local_prefillable: bool, token_usage: float, + local_alignment_ready: Optional[bool] = None, ) -> bool: """ Returns True iff this rank is allowed to admit new prefills this tick. @@ -142,7 +145,13 @@ def should_allow_prefill( token_usage: fraction of KV cache blocks currently in use (used_blocks / total_blocks ∈ [0, 1]). Used by the low-watermark safety valve. + local_alignment_ready: this rank meets the current alignment + threshold. Defaults to ``local_prefillable`` for the legacy + one-request policy; TBO prefill passes ``>= 2`` here. """ + if local_alignment_ready is None: + local_alignment_ready = local_prefillable + # Local "force allow" if KV cache is underutilized — don't delay # when GPU is starving. Only meaningful if this rank actually has # a prefill to push through (otherwise force_allow is a no-op). @@ -154,10 +163,11 @@ def should_allow_prefill( ): force = True - # Cross-DP MAX-reduce: 3 booleans encoded as int64. + # Cross-DP MAX-reduce: 4 booleans encoded as int64. self._reduce_buf[0] = 1 if local_prefillable else 0 self._reduce_buf[1] = 1 if force else 0 - self._reduce_buf[2] = 0 if local_prefillable else 1 + self._reduce_buf[2] = 1 if local_alignment_ready else 0 + self._reduce_buf[3] = 0 if local_alignment_ready else 1 torch.distributed.all_reduce( self._reduce_buf, op=torch.distributed.ReduceOp.MAX, @@ -165,11 +175,11 @@ def should_allow_prefill( ) any_prefillable = int(self._reduce_buf[0].item()) > 0 force_max = int(self._reduce_buf[1].item()) - any_not_prefillable = int(self._reduce_buf[2].item()) > 0 + any_ready = int(self._reduce_buf[2].item()) > 0 + any_not_ready = int(self._reduce_buf[3].item()) > 0 # Derive 3-way status: all / none / mixed. - prefillable_max = 1 if any_prefillable else 0 - prefillable_min = 0 if any_not_prefillable else 1 + all_ready = any_ready and not any_not_ready # Watermark short-circuit: ANY rank below the watermark forces all # ranks to allow this tick. Without this the delayer can stall a @@ -189,7 +199,7 @@ def should_allow_prefill( return True # status = "all" or "none" → no skew, just allow - if prefillable_min == prefillable_max: + if not any_prefillable or all_ready: self._reset_delay() self._stat_allow += 1 self._maybe_log() @@ -211,7 +221,7 @@ def should_allow_prefill( f"[PrefillDelayer] DELAY: count={self._delayed_count} " f"elapsed={elapsed_ms:.1f}ms " f"any_prefillable={any_prefillable} " - f"any_not_prefillable={any_not_prefillable}" + f"any_ready={any_ready} any_not_ready={any_not_ready}" ) self._maybe_log() return False diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 6db875046..3af02537f 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -507,9 +507,8 @@ def __init__(self, config: Config): def set_prefill_delayer(self, delayer) -> None: self.prefill_delayer = delayer - def _can_admit_head_prefill(self) -> bool: - """Match SGL's `local_prefillable=True` semantics: report True iff - this rank would *actually* admit a new prefill this tick. + def _count_admittable_head_prefills(self, limit: int) -> int: + """Count how many head prefills this rank can admit this tick. Just having `self.waiting` non-empty is too coarse — during a concurrent-burst workload (e.g. 1k/1k @ high concurrency) every @@ -521,10 +520,13 @@ def _can_admit_head_prefill(self) -> bool: We peek the front of `waiting` (skipping a few unschedulable entries) and check `can_allocate` + token-budget, mirroring the - same checks the admission while-loop runs below. + same checks the admission while-loop runs below. The count is capped + at ``limit`` so the helper stays cheap for the delayer gate. """ - if not self.waiting: - return False + if limit <= 0 or not self.waiting: + return 0 + count = 0 + num_batched_tokens = 0 for i, seq in enumerate(self.waiting): if i >= 4: break @@ -536,9 +538,26 @@ def _can_admit_head_prefill(self) -> bool: if num_new_tokens > self.max_num_batched_tokens: continue if self.block_manager.can_allocate(seq) < 0: - return False # KV-pressured: definitely cannot prefill - return True - return False + break # KV-pressured: definitely cannot prefill more now. + if num_batched_tokens + num_new_tokens > self.max_num_batched_tokens: + break + count += 1 + if count >= limit: + break + num_batched_tokens += num_new_tokens + return count + + def _prefill_delayer_readiness(self) -> tuple[bool, bool]: + """Return the local presence and alignment bits for PrefillDelayer. + + TBO prefill splitting needs at least two local prefill requests. + When TBO is enabled, wait for each DP rank to be able to admit two + requests before reporting "ready"; otherwise keep the legacy one + request threshold. + """ + required = 2 if self.config.enable_tbo else 1 + count = self._count_admittable_head_prefills(required) + return count > 0, count >= required def _kv_usage(self) -> float: """Fraction of KV-cache blocks currently in use ∈ [0, 1]. @@ -700,9 +719,13 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # ─── Cross-DP prefill alignment (PrefillDelayer) ─────────────── _delayer_allows_prefill = True if self.prefill_delayer is not None: + _local_prefillable, _local_alignment_ready = ( + self._prefill_delayer_readiness() + ) _delayer_allows_prefill = self.prefill_delayer.should_allow_prefill( - local_prefillable=self._can_admit_head_prefill(), + local_prefillable=_local_prefillable, token_usage=self._kv_usage(), + local_alignment_ready=_local_alignment_ready, ) if not self.running and not self.waiting: From 029692674bebcf37d75b9b96cbcfb72a3f1273e5 Mon Sep 17 00:00:00 2001 From: jpy794 Date: Thu, 25 Jun 2026 11:43:45 +0000 Subject: [PATCH 05/11] improv (don't upstream): optimize model loading speed in MI355X (this perf bug seems MI355X only) --- atom/model_loader/loader.py | 214 ++++++++++++++++++++++++++++++++++-- atom/model_ops/moe.py | 108 ++++++++++++++++++ atom/utils/envs.py | 7 ++ 3 files changed, 322 insertions(+), 7 deletions(-) diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 42b7fc121..4a19edf76 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -6,6 +6,7 @@ import os import logging import re +import threading import time from glob import glob from typing import Generator, Tuple @@ -401,11 +402,162 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: # rewritten name doesn't correspond to any model param. (orig, mapped) pairs. dropped_ckpt_keys: list[tuple[str, str]] = [] - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] + # --- Batched MoE expert loading --------------------------------------- + # Large MoE checkpoints (Kimi-K2.5 etc.) deliver each expert's weight as a + # separate tensor. The per-expert path then issues one tiny H2D copy per + # (expert, shard), i.e. hundreds of thousands of copies, each paying kernel + # launch + pageable-bounce-buffer + GIL overhead — this dominates load time + # (it stalls *after* the shard progress bar, in the futures drain). When + # ATOM_BATCH_EXPERT_LOAD is on, we instead accumulate every expert of a + # fused param into one pinned CPU staging buffer and flush it to the GPU + # with a single large H2D copy once all expected arrivals have landed. + batch_expert_load = envs.ATOM_BATCH_EXPERT_LOAD + + staging_map: dict = {} + staging_lock = threading.Lock() + + moe_module_cache: dict = {} + param_batchable: dict = {} + + def _lookup_moe_module(full_param_name: str): + module_path = full_param_name.rsplit(".", 1)[0] + if module_path in moe_module_cache: + return moe_module_cache[module_path] + try: + mod = model.get_submodule(module_path) + except AttributeError: + mod = None + moe_module_cache[module_path] = mod + return mod + + def _param_is_batchable(param, full_param_name: str) -> bool: + pid = id(param) + cached = param_batchable.get(pid) + if cached is not None: + return cached + moe = _lookup_moe_module(full_param_name) + ok = False + if moe is not None and hasattr(moe, "stage_expert_weight"): + expected = moe.expected_batched_arrivals(param) + ok = expected is not None and expected > 0 + param_batchable[pid] = ok + return ok + + def _do_flush(param, staging): + if staging.dtype != param.data.dtype: + param.data.view(torch.uint8).copy_(staging) + else: + param.data.copy_(staging) + + def _stage_task(param, full_param_name, shard_id, global_expert_id, loaded_weight): + pid = id(param) + + with staging_lock: + entry = staging_map.get(pid) + if entry is None: + moe = _lookup_moe_module(full_param_name) + expected = moe.expected_batched_arrivals(param) + + pin = torch.cuda.is_available() + + def _alloc_like(): + t = torch.empty( + param.data.shape, + dtype=param.data.dtype, + device="cpu", + pin_memory=pin, + ) + t.zero_() + return t + + def _alloc_uint8(): + t = torch.empty( + param.data.shape, + dtype=torch.uint8, + device="cpu", + pin_memory=pin, + ) + t.zero_() + return t + + try: + try: + staging = _alloc_like() + except NotImplementedError: + staging = _alloc_uint8() + except RuntimeError as e: + logger.warning( + "Pinned-memory allocation failed for %s (%s); " + "falling back to unpinned staging.", + full_param_name, + e, + ) + pin = False + try: + staging = _alloc_like() + except NotImplementedError: + staging = _alloc_uint8() + entry = [ + "active", + staging, + 0, + expected, + moe, + param, + threading.Lock(), + ] + staging_map[pid] = entry + + if entry[0] == "fallback": + wl = getattr(param, "weight_loader", default_weight_loader) + wl(param, loaded_weight, full_param_name, shard_id, global_expert_id) + return + + _, staging, _, expected, moe, _, entry_lock = entry + + local_eid = moe._map_global_expert_id_to_local_expert_id(global_expert_id) + if local_eid == -1: + return + + ok = moe.stage_expert_weight( + param=param, + staging=staging, + loaded_weight=loaded_weight, + local_expert_id=local_eid, + shard_id=shard_id, + weight_name=full_param_name, + ) + + if not ok: + with staging_lock: + cur = staging_map.get(pid) + if cur is entry: + staging_map[pid] = ("fallback",) + logger.warning( + "stage_expert_weight returned False mid-batch for %s " + "(shard_id=%s). Falling back to per-expert path.", + full_param_name, + shard_id, + ) + wl = getattr(param, "weight_loader", default_weight_loader) + wl(param, loaded_weight, full_param_name, shard_id, global_expert_id) + return + + flush_now = False + with entry_lock: + entry[2] += 1 + if entry[2] >= expected: + flush_now = True + + if flush_now: + _do_flush(param, staging) + with staging_lock: + if staging_map.get(pid) is entry: + del staging_map[pid] + use_threadpool = envs.ATOM_LOADER_USE_THREADPOOL if use_threadpool: - executor = concurrent.futures.ThreadPoolExecutor() + executor = concurrent.futures.ThreadPoolExecutor(max_workers=16) else: executor = None futures = [] @@ -593,13 +745,24 @@ def _submit(fn, *args): if "mtp" in name and not spec_decode: matched = True break - try: - param = model.get_parameter(name) - except AttributeError: + param = params_dict.get(name) + if param is None: # Parameter absent from model (e.g. weight scales for # an unquantized drafter MTP block); skip silently. matched = True break + if batch_expert_load and _param_is_batchable(param, name): + _submit( + _stage_task, + param, + name, + shard_id, + expert_id, + weight_tensor, + ) + loaded_weights_record.add(prefix + name) + matched = True + break weight_loader = getattr(param, "weight_loader") _submit( weight_loader, @@ -657,9 +820,46 @@ def _submit(fn, *args): ) _submit(weight_loader, param, weight_tensor) loaded_weights_record.add(prefix + name) + + if executor is not None: + enable_tqdm = ( + not torch.distributed.is_initialized() + or torch.distributed.get_rank() == 0 + ) + wait_iter = concurrent.futures.as_completed(futures) + if enable_tqdm: + wait_iter = tqdm( + wait_iter, + total=len(futures), + desc="Loading weights", + mininterval=1.0, + ) + for future in wait_iter: + future.result() + + # Any staging group that never reached its expected arrival count + # (e.g. a checkpoint missing some experts) is flushed here with whatever + # data landed, so the param isn't silently left at its init value. + with staging_lock: + pending = [e for e in staging_map.values() if e[0] == "active"] + staging_map.clear() + if pending: + logger.warning( + "Batched-load safety flush: %d group(s) did not reach " + "expected arrival count; flushing with partial data.", + len(pending), + ) + for entry in pending: + _, staging, arrived, expected, _moe, param, _ = entry + logger.warning( + " param shape=%s arrived=%d expected=%d", + tuple(param.shape), + arrived, + expected, + ) + _do_flush(param, staging) finally: if executor is not None: - concurrent.futures.wait(futures) executor.shutdown(wait=True) # Verify every model parameter actually got loaded from the checkpoint. diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 86a1264ea..8684acd2f 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -3076,6 +3076,114 @@ def mxf4_merged_weight_loader( narrow_weight = loaded_weight target_param[: narrow_weight.shape[0]].copy_(narrow_weight) + def stage_expert_weight( + self, + param: torch.nn.Parameter, + staging: torch.Tensor, + loaded_weight: torch.Tensor, + local_expert_id: int, + shard_id: str, + weight_name: str, + ) -> bool: + """Write one expert's slice into the CPU staging buffer (batched load). + + Mirrors the per-expert work `weight_loader` does, but targets a CPU + staging slice instead of the GPU param so the loader can flush all + experts of a fused param with a single H2D copy. Returns False for any + case not safe to batch here; the caller then falls back to the + per-expert `weight_loader` path. + """ + # Cases not handled by the batched path — caller falls back. + if ( + "input_scale" in weight_name + or "g_idx" in weight_name + or "weight_shape" in weight_name + or weight_name == "" + ): + return False + if shard_id not in ("w1", "w2", "w3"): + return False + + # compressed-tensors packed-weight flip (matches weight_loader) + if self.quant_method.__class__.__name__ in ( + "CompressedTensorsWNA16MarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): + loaded_weight = loaded_weight.t().contiguous() + + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = int(not shard_dim) + + if len(loaded_weight.shape) == 3: + return False + + expert_data = staging[local_expert_id] + + if ( + staging.dtype == torch.uint8 + and param.data.dtype == dtypes.fp4x2 + and loaded_weight.dtype == dtypes.fp4x2 + ): + loaded_weight = loaded_weight.view(torch.uint8) + + # --- scale / zero / offset paths --- + if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name: + quant_method = self.layer_quant_config.quant_type + if quant_method == QuantType.per_Token: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + ) + return True + if quant_method in (QuantType.per_1x128, QuantType.per_1x32): + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full=getattr(param, "load_full_w2", False), + ) + return True + # per_Tensor or unknown quant_method — not safe to batch here. + return False + + # --- main weight path --- + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + ) + return True + + return False + + def expected_batched_arrivals(self, param: torch.nn.Parameter) -> Optional[int]: + """How many checkpoint tensors will land in this fused param's staging + buffer before it is complete. w13 packs gate+up, so it receives two + shards (w1, w3) per local expert; w2 receives one per local expert.""" + w13_params = [ + getattr(self, n, None) + for n in ("w13_weight", "w13_weight_scale", "w13_bias") + ] + w2_params = [ + getattr(self, n, None) for n in ("w2_weight", "w2_weight_scale", "w2_bias") + ] + if any(param is p for p in w13_params if p is not None): + return self.local_num_experts * 2 + if any(param is p for p in w2_params if p is not None): + return self.local_num_experts + return None + def weight_loader( self, param: torch.nn.Parameter, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 3bd2a53ab..8931d734d 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -106,6 +106,13 @@ "ATOM_LOADER_USE_THREADPOOL": lambda: ( os.getenv("ATOM_LOADER_USE_THREADPOOL", "1") == "1" ), + # When enabled (default), MoE expert tensors that arrive from the + # checkpoint are accumulated into a per-fused-param CPU staging buffer, + # and a single large H2D copy is submitted once the group fills. This + # collapses the hundreds-of-thousands of tiny per-expert H2D copies (which + # dominate load time on large MoE models like Kimi-K2.5) into one big copy + # per fused param. Set to 0 to revert to the per-expert path. + "ATOM_BATCH_EXPERT_LOAD": lambda: os.getenv("ATOM_BATCH_EXPERT_LOAD", "1") == "1", # --- Attention Backend --- # Use unified_attention (flash-style) for MHA paged/prefill attention instead # of pa_decode_gluon. Set to 1 to enable the unified_attention path. From cc42200d621befea88d235fc00db64381ffcc61f Mon Sep 17 00:00:00 2001 From: jpy794 Date: Fri, 26 Jun 2026 08:18:31 +0000 Subject: [PATCH 06/11] Zero MoE all-gather padding rows --- atom/model_ops/moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 8684acd2f..42725e9a4 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -234,7 +234,10 @@ def pad_for_all_gather(x: torch.Tensor) -> Tuple[torch.Tensor, int]: padding_shape[0] = max_batch_size padded_x = torch.empty(padding_shape, device=x.device, dtype=x.dtype) padded_x[:original_batch_size, :].copy_(x) - # padded_x[original_batch_size:, :].zero_() + # Padded rows still enter fused-MoE routing/sort/dispatch before being + # sliced away after reduce-scatter; uninitialized NaN/Inf rows can perturb + # expert buckets or shared scratch and corrupt real tokens. + padded_x[original_batch_size:, :].zero_() return padded_x, original_batch_size From 32ae0db81761f4ce4586e983c214cd5eebc88858 Mon Sep 17 00:00:00 2001 From: jpy794 Date: Fri, 26 Jun 2026 08:19:48 +0000 Subject: [PATCH 07/11] Revert "improv (don't upstream): optimize model loading speed in MI355X (this perf bug seems MI355X only)" This reverts commit 0a7feabe08da10d9fae3218451e0fe4fc1990d33. --- atom/model_loader/loader.py | 214 ++---------------------------------- atom/model_ops/moe.py | 108 ------------------ atom/utils/envs.py | 7 -- 3 files changed, 7 insertions(+), 322 deletions(-) diff --git a/atom/model_loader/loader.py b/atom/model_loader/loader.py index 4a19edf76..42b7fc121 100644 --- a/atom/model_loader/loader.py +++ b/atom/model_loader/loader.py @@ -6,7 +6,6 @@ import os import logging import re -import threading import time from glob import glob from typing import Generator, Tuple @@ -402,162 +401,11 @@ def extract_expert_target_and_id(name: str) -> Tuple[str, int] | None: # rewritten name doesn't correspond to any model param. (orig, mapped) pairs. dropped_ckpt_keys: list[tuple[str, str]] = [] - # --- Batched MoE expert loading --------------------------------------- - # Large MoE checkpoints (Kimi-K2.5 etc.) deliver each expert's weight as a - # separate tensor. The per-expert path then issues one tiny H2D copy per - # (expert, shard), i.e. hundreds of thousands of copies, each paying kernel - # launch + pageable-bounce-buffer + GIL overhead — this dominates load time - # (it stalls *after* the shard progress bar, in the futures drain). When - # ATOM_BATCH_EXPERT_LOAD is on, we instead accumulate every expert of a - # fused param into one pinned CPU staging buffer and flush it to the GPU - # with a single large H2D copy once all expected arrivals have landed. - batch_expert_load = envs.ATOM_BATCH_EXPERT_LOAD - - staging_map: dict = {} - staging_lock = threading.Lock() - - moe_module_cache: dict = {} - param_batchable: dict = {} - - def _lookup_moe_module(full_param_name: str): - module_path = full_param_name.rsplit(".", 1)[0] - if module_path in moe_module_cache: - return moe_module_cache[module_path] - try: - mod = model.get_submodule(module_path) - except AttributeError: - mod = None - moe_module_cache[module_path] = mod - return mod - - def _param_is_batchable(param, full_param_name: str) -> bool: - pid = id(param) - cached = param_batchable.get(pid) - if cached is not None: - return cached - moe = _lookup_moe_module(full_param_name) - ok = False - if moe is not None and hasattr(moe, "stage_expert_weight"): - expected = moe.expected_batched_arrivals(param) - ok = expected is not None and expected > 0 - param_batchable[pid] = ok - return ok - - def _do_flush(param, staging): - if staging.dtype != param.data.dtype: - param.data.view(torch.uint8).copy_(staging) - else: - param.data.copy_(staging) - - def _stage_task(param, full_param_name, shard_id, global_expert_id, loaded_weight): - pid = id(param) - - with staging_lock: - entry = staging_map.get(pid) - if entry is None: - moe = _lookup_moe_module(full_param_name) - expected = moe.expected_batched_arrivals(param) - - pin = torch.cuda.is_available() - - def _alloc_like(): - t = torch.empty( - param.data.shape, - dtype=param.data.dtype, - device="cpu", - pin_memory=pin, - ) - t.zero_() - return t - - def _alloc_uint8(): - t = torch.empty( - param.data.shape, - dtype=torch.uint8, - device="cpu", - pin_memory=pin, - ) - t.zero_() - return t - - try: - try: - staging = _alloc_like() - except NotImplementedError: - staging = _alloc_uint8() - except RuntimeError as e: - logger.warning( - "Pinned-memory allocation failed for %s (%s); " - "falling back to unpinned staging.", - full_param_name, - e, - ) - pin = False - try: - staging = _alloc_like() - except NotImplementedError: - staging = _alloc_uint8() - entry = [ - "active", - staging, - 0, - expected, - moe, - param, - threading.Lock(), - ] - staging_map[pid] = entry - - if entry[0] == "fallback": - wl = getattr(param, "weight_loader", default_weight_loader) - wl(param, loaded_weight, full_param_name, shard_id, global_expert_id) - return - - _, staging, _, expected, moe, _, entry_lock = entry - - local_eid = moe._map_global_expert_id_to_local_expert_id(global_expert_id) - if local_eid == -1: - return - - ok = moe.stage_expert_weight( - param=param, - staging=staging, - loaded_weight=loaded_weight, - local_expert_id=local_eid, - shard_id=shard_id, - weight_name=full_param_name, - ) - - if not ok: - with staging_lock: - cur = staging_map.get(pid) - if cur is entry: - staging_map[pid] = ("fallback",) - logger.warning( - "stage_expert_weight returned False mid-batch for %s " - "(shard_id=%s). Falling back to per-expert path.", - full_param_name, - shard_id, - ) - wl = getattr(param, "weight_loader", default_weight_loader) - wl(param, loaded_weight, full_param_name, shard_id, global_expert_id) - return - - flush_now = False - with entry_lock: - entry[2] += 1 - if entry[2] >= expected: - flush_now = True - - if flush_now: - _do_flush(param, staging) - with staging_lock: - if staging_map.get(pid) is entry: - del staging_map[pid] - + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] use_threadpool = envs.ATOM_LOADER_USE_THREADPOOL if use_threadpool: - executor = concurrent.futures.ThreadPoolExecutor(max_workers=16) + executor = concurrent.futures.ThreadPoolExecutor() else: executor = None futures = [] @@ -745,24 +593,13 @@ def _submit(fn, *args): if "mtp" in name and not spec_decode: matched = True break - param = params_dict.get(name) - if param is None: + try: + param = model.get_parameter(name) + except AttributeError: # Parameter absent from model (e.g. weight scales for # an unquantized drafter MTP block); skip silently. matched = True break - if batch_expert_load and _param_is_batchable(param, name): - _submit( - _stage_task, - param, - name, - shard_id, - expert_id, - weight_tensor, - ) - loaded_weights_record.add(prefix + name) - matched = True - break weight_loader = getattr(param, "weight_loader") _submit( weight_loader, @@ -820,46 +657,9 @@ def _submit(fn, *args): ) _submit(weight_loader, param, weight_tensor) loaded_weights_record.add(prefix + name) - - if executor is not None: - enable_tqdm = ( - not torch.distributed.is_initialized() - or torch.distributed.get_rank() == 0 - ) - wait_iter = concurrent.futures.as_completed(futures) - if enable_tqdm: - wait_iter = tqdm( - wait_iter, - total=len(futures), - desc="Loading weights", - mininterval=1.0, - ) - for future in wait_iter: - future.result() - - # Any staging group that never reached its expected arrival count - # (e.g. a checkpoint missing some experts) is flushed here with whatever - # data landed, so the param isn't silently left at its init value. - with staging_lock: - pending = [e for e in staging_map.values() if e[0] == "active"] - staging_map.clear() - if pending: - logger.warning( - "Batched-load safety flush: %d group(s) did not reach " - "expected arrival count; flushing with partial data.", - len(pending), - ) - for entry in pending: - _, staging, arrived, expected, _moe, param, _ = entry - logger.warning( - " param shape=%s arrived=%d expected=%d", - tuple(param.shape), - arrived, - expected, - ) - _do_flush(param, staging) finally: if executor is not None: + concurrent.futures.wait(futures) executor.shutdown(wait=True) # Verify every model parameter actually got loaded from the checkpoint. diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 42725e9a4..cfc3998de 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -3079,114 +3079,6 @@ def mxf4_merged_weight_loader( narrow_weight = loaded_weight target_param[: narrow_weight.shape[0]].copy_(narrow_weight) - def stage_expert_weight( - self, - param: torch.nn.Parameter, - staging: torch.Tensor, - loaded_weight: torch.Tensor, - local_expert_id: int, - shard_id: str, - weight_name: str, - ) -> bool: - """Write one expert's slice into the CPU staging buffer (batched load). - - Mirrors the per-expert work `weight_loader` does, but targets a CPU - staging slice instead of the GPU param so the loader can flush all - experts of a fused param with a single H2D copy. Returns False for any - case not safe to batch here; the caller then falls back to the - per-expert `weight_loader` path. - """ - # Cases not handled by the batched path — caller falls back. - if ( - "input_scale" in weight_name - or "g_idx" in weight_name - or "weight_shape" in weight_name - or weight_name == "" - ): - return False - if shard_id not in ("w1", "w2", "w3"): - return False - - # compressed-tensors packed-weight flip (matches weight_loader) - if self.quant_method.__class__.__name__ in ( - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", - ): - loaded_weight = loaded_weight.t().contiguous() - - SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} - is_transposed = getattr(param, "is_transposed", False) - shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] - if is_transposed: - shard_dim = int(not shard_dim) - - if len(loaded_weight.shape) == 3: - return False - - expert_data = staging[local_expert_id] - - if ( - staging.dtype == torch.uint8 - and param.data.dtype == dtypes.fp4x2 - and loaded_weight.dtype == dtypes.fp4x2 - ): - loaded_weight = loaded_weight.view(torch.uint8) - - # --- scale / zero / offset paths --- - if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name: - quant_method = self.layer_quant_config.quant_type - if quant_method == QuantType.per_Token: - self._load_per_channel_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank, - ) - return True - if quant_method in (QuantType.per_1x128, QuantType.per_1x32): - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank, - load_full=getattr(param, "load_full_w2", False), - ) - return True - # per_Tensor or unknown quant_method — not safe to batch here. - return False - - # --- main weight path --- - if "weight" in weight_name: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank, - ) - return True - - return False - - def expected_batched_arrivals(self, param: torch.nn.Parameter) -> Optional[int]: - """How many checkpoint tensors will land in this fused param's staging - buffer before it is complete. w13 packs gate+up, so it receives two - shards (w1, w3) per local expert; w2 receives one per local expert.""" - w13_params = [ - getattr(self, n, None) - for n in ("w13_weight", "w13_weight_scale", "w13_bias") - ] - w2_params = [ - getattr(self, n, None) for n in ("w2_weight", "w2_weight_scale", "w2_bias") - ] - if any(param is p for p in w13_params if p is not None): - return self.local_num_experts * 2 - if any(param is p for p in w2_params if p is not None): - return self.local_num_experts - return None - def weight_loader( self, param: torch.nn.Parameter, diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 8931d734d..3bd2a53ab 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -106,13 +106,6 @@ "ATOM_LOADER_USE_THREADPOOL": lambda: ( os.getenv("ATOM_LOADER_USE_THREADPOOL", "1") == "1" ), - # When enabled (default), MoE expert tensors that arrive from the - # checkpoint are accumulated into a per-fused-param CPU staging buffer, - # and a single large H2D copy is submitted once the group fills. This - # collapses the hundreds-of-thousands of tiny per-expert H2D copies (which - # dominate load time on large MoE models like Kimi-K2.5) into one big copy - # per fused param. Set to 0 to revert to the per-expert path. - "ATOM_BATCH_EXPERT_LOAD": lambda: os.getenv("ATOM_BATCH_EXPERT_LOAD", "1") == "1", # --- Attention Backend --- # Use unified_attention (flash-style) for MHA paged/prefill attention instead # of pa_decode_gluon. Set to 1 to enable the unified_attention path. From 426f176f108b668452604e66eccaac041b19185a Mon Sep 17 00:00:00 2001 From: jpy794 Date: Fri, 26 Jun 2026 11:53:14 +0000 Subject: [PATCH 08/11] Propagate DP mode to TBO ubatch contexts --- atom/utils/tbo/ubatch_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/atom/utils/tbo/ubatch_wrapper.py b/atom/utils/tbo/ubatch_wrapper.py index 3873ac64e..c1d4d16d7 100644 --- a/atom/utils/tbo/ubatch_wrapper.py +++ b/atom/utils/tbo/ubatch_wrapper.py @@ -452,6 +452,7 @@ def _make_ubatch_context( batch_size=ub_num_reqs, graph_bs=graph_bs, is_draft=ctx.context.is_draft, + dp_uniform_decode=ctx.context.dp_uniform_decode, ) return ForwardContext( From 05cdd91bfbdfba803741ed797b5197b944327ff4 Mon Sep 17 00:00:00 2001 From: jpy794 Date: Fri, 26 Jun 2026 13:31:59 +0000 Subject: [PATCH 09/11] enable shared expert fusion in DP gather/scatter mode --- atom/model_ops/topK.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/atom/model_ops/topK.py b/atom/model_ops/topK.py index 74cc2ddf0..11b72df4f 100644 --- a/atom/model_ops/topK.py +++ b/atom/model_ops/topK.py @@ -21,7 +21,13 @@ def is_rocm_aiter_fusion_shared_expert_enabled_for_quant_config( quant_config = config.quant_config dp_size = config.parallel_config.data_parallel_size - if dp_size > 1 and _has_module("mori") and config.enable_dp_attention: + use_mori_all2all = ( + dp_size > 1 + and _has_module("mori") + and config.enable_dp_attention + and config.enable_expert_parallel + ) + if use_mori_all2all: return False if quant_config is not None and shared_expert_prefix is not None: @@ -62,15 +68,6 @@ def is_rocm_aiter_fusion_shared_expert_enabled_for_quant_config( return False break - dp_size = config.parallel_config.data_parallel_size - use_mori_all2all = ( - dp_size > 1 - and _has_module("mori") - and config.enable_dp_attention - and config.enable_expert_parallel - ) - if use_mori_all2all: - return False return True From a97aec0447c3c3268686223b48b2808dc272a08a Mon Sep 17 00:00:00 2001 From: jpy794 Date: Sat, 16 May 2026 15:38:34 +0000 Subject: [PATCH 10/11] enable fused vup rope --- atom/model_ops/attention_mla.py | 172 +++++++++++++++++--------------- 1 file changed, 91 insertions(+), 81 deletions(-) diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index 9061e18d0..0ee1c509c 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -1171,92 +1171,102 @@ def forward_impl( prefill_q, k_nope, k_rope, kv_cache, attn_metadata ) else: - q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) - - if self.use_seg_mla: - # Seg path: allocate q_out with a padded last dim so each head row - # has a 768-byte stride (required by the gfx1250 decode asm). The - # kernel only writes the first kv_lora_rank + qk_rope_head_dim - # columns; the padding tail is left untouched and never read. - q_out = torch.empty( - ( - q_nope.shape[0], - self.num_heads, - _MLA_Q_OUT_PADDED_DIM, - ), - dtype=attn_metadata.dtype_q, - device=q_nope.device, - ) - else: - q_out = torch.empty( - ( - q_nope.shape[0], - self.num_heads, - self.kv_lora_rank + self.qk_rope_head_dim, - ), - dtype=attn_metadata.dtype_q, - device=q_nope.device, - ) - if kv_cache.numel() > 0: - if envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV: - shuffled_cache = self._shuffled_kv_view(kv_cache) - triton_fused_qk_rope_cat_and_cache_mla( - q_nope, - q_rope, - k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank), - k_rope.view(-1, self.num_kv_heads, self.qk_rope_head_dim), - shuffled_cache, - attn_metadata.slot_mapping, - positions, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - self._k_scale, - self.rotary_emb.is_neox_style, - num_decode_toks_for_zeros=0, - apply_scale=True, - q_out=q_out, - shuffled_kv_cache=True, - ) - elif self.use_seg_mla: - kv_cache_seg = self._seg_kv_cache_view(kv_cache) - fused_qk_rope_concat_and_cache_mla_seg( - q_nope, - q_rope, - k_nope, - k_rope, - # Flat seg layout: [num_blocks, page_size*(kv_lora + pe)]. - kv_cache_seg, - q_out, - attn_metadata.slot_mapping, - self._k_scale, - self._q_scale, - positions, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - is_neox=self.rotary_emb.is_neox_style, + use_shuffle_kv = ( + envs.ATOM_USE_TRITON_MLA and envs.ATOM_USE_TRITON_MLA_SHUFFLE_KV + ) + needs_explicit_q_cache = ( + attn_metadata.max_seqlen_q > 1 or self.use_seg_mla or use_shuffle_kv + ) + if needs_explicit_q_cache: + q_nope, q_rope = self._q_proj_and_k_up_proj(q, x_scale=q_scale) + + if self.use_seg_mla: + # Seg path: allocate q_out with a padded last dim so each head row + # has a 768-byte stride (required by the gfx1250 decode asm). The + # kernel only writes the first kv_lora_rank + qk_rope_head_dim + # columns; the padding tail is left untouched and never read. + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, + _MLA_Q_OUT_PADDED_DIM, + ), + dtype=attn_metadata.dtype_q, + device=q_nope.device, ) else: - fused_qk_rope_concat_and_cache_mla( - q_nope, - q_rope, - k_nope, - k_rope, - kv_cache.view( - kv_cache.shape[0], - -1, + q_out = torch.empty( + ( + q_nope.shape[0], + self.num_heads, self.kv_lora_rank + self.qk_rope_head_dim, ), - q_out, - attn_metadata.slot_mapping, - self._k_scale, - self._q_scale, - positions, - self.rotary_emb.cos_cache, - self.rotary_emb.sin_cache, - is_neox=self.rotary_emb.is_neox_style, - is_nope_first=True, + dtype=attn_metadata.dtype_q, + device=q_nope.device, ) - # q_out = self.fused_kv_bmm(q, q_scale, k_nope, k_rope, positions, kv_cache, attn_metadata) + if kv_cache.numel() > 0: + if use_shuffle_kv: + shuffled_cache = self._shuffled_kv_view(kv_cache) + triton_fused_qk_rope_cat_and_cache_mla( + q_nope, + q_rope, + k_nope.view(-1, self.num_kv_heads, self.kv_lora_rank), + k_rope.view(-1, self.num_kv_heads, self.qk_rope_head_dim), + shuffled_cache, + attn_metadata.slot_mapping, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + self._k_scale, + self.rotary_emb.is_neox_style, + num_decode_toks_for_zeros=0, + apply_scale=True, + q_out=q_out, + shuffled_kv_cache=True, + ) + elif self.use_seg_mla: + kv_cache_seg = self._seg_kv_cache_view(kv_cache) + fused_qk_rope_concat_and_cache_mla_seg( + q_nope, + q_rope, + k_nope, + k_rope, + # Flat seg layout: [num_blocks, page_size*(kv_lora + pe)]. + kv_cache_seg, + q_out, + attn_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + ) + else: + fused_qk_rope_concat_and_cache_mla( + q_nope, + q_rope, + k_nope, + k_rope, + kv_cache.view( + kv_cache.shape[0], + -1, + self.kv_lora_rank + self.qk_rope_head_dim, + ), + q_out, + attn_metadata.slot_mapping, + self._k_scale, + self._q_scale, + positions, + self.rotary_emb.cos_cache, + self.rotary_emb.sin_cache, + is_neox=self.rotary_emb.is_neox_style, + is_nope_first=True, + ) + elif kv_cache.numel() > 0: + q_out = self.fused_kv_bmm( + q, q_scale, k_nope, k_rope, positions, kv_cache, attn_metadata + ) if context.is_prefill: output = self._forward_prefill_mla(q_out, kv_cache, attn_metadata) From 995a56d7c1d0ea5e1354ead4af8431a303684fbd Mon Sep 17 00:00:00 2001 From: jpy794 Date: Mon, 29 Jun 2026 08:03:53 +0000 Subject: [PATCH 11/11] configire prefill delay batch from env --- atom/model_engine/scheduler.py | 12 ++++++------ atom/utils/envs.py | 5 +++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 3af02537f..2603dc469 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -29,6 +29,7 @@ from atom.model_engine.block_manager import BlockManager from atom.model_engine.request import RequestOutput from atom.model_engine.sequence import Sequence, SequenceStatus, SequenceType +from atom.utils import envs logger = logging.getLogger("atom") @@ -550,13 +551,12 @@ def _count_admittable_head_prefills(self, limit: int) -> int: def _prefill_delayer_readiness(self) -> tuple[bool, bool]: """Return the local presence and alignment bits for PrefillDelayer. - TBO prefill splitting needs at least two local prefill requests. - When TBO is enabled, wait for each DP rank to be able to admit two - requests before reporting "ready"; otherwise keep the legacy one - request threshold. + ``ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS`` controls how many local + prefill requests this rank must be able to admit before reporting + alignment-ready. A value of 0 disables the local-count threshold. """ - required = 2 if self.config.enable_tbo else 1 - count = self._count_admittable_head_prefills(required) + required = max(envs.ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS, 0) + count = self._count_admittable_head_prefills(max(required, 1)) return count > 0, count >= required def _kv_usage(self) -> float: diff --git a/atom/utils/envs.py b/atom/utils/envs.py index 3bd2a53ab..4e07a669f 100644 --- a/atom/utils/envs.py +++ b/atom/utils/envs.py @@ -230,6 +230,11 @@ if os.getenv("ATOM_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK", "") == "" else float(os.getenv("ATOM_PREFILL_DELAYER_TOKEN_USAGE_LOW_WATERMARK")) ), + # Number of local prefill requests required before reporting alignment-ready. + # Default 0 disables this local-count threshold. + "ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS": lambda: int( + os.getenv("ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS", "0") + ), # --- TBO prefill ubatch splitting --- # Split prefill ubatches at the exact token midpoint (vLLM-DBO style), # cutting through a request if needed for perfectly balanced 50/50 ubatches.