-
Notifications
You must be signed in to change notification settings - Fork 82
Enable TBO Support & Fix Accuracy Regressions for Kimi K2.5 #1369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
e7307f5
f3d47e1
e55cdc1
4d5a86b
0296926
cc42200
32ae0db
426f176
05cdd91
a97aec0
995a56d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -80,6 +80,9 @@ def from_model_config(a_quant_dtype: str | None) -> "MoEActivationQuant": | |
| return MoEActivationQuant.BF16 | ||
|
|
||
|
|
||
| _TBO_KEEPALIVE: dict[tuple[str, int], tuple[torch.Tensor, ...]] = {} | ||
|
|
||
|
|
||
| class FusedMoeWeightScaleSupported(Enum): | ||
| """Supported quantization strategies for MoE weight scales.""" | ||
|
|
||
|
|
@@ -231,7 +234,10 @@ def pad_for_all_gather(x: torch.Tensor) -> Tuple[torch.Tensor, int]: | |
| padding_shape[0] = max_batch_size | ||
| padded_x = torch.empty(padding_shape, device=x.device, dtype=x.dtype) | ||
| padded_x[:original_batch_size, :].copy_(x) | ||
| # padded_x[original_batch_size:, :].zero_() | ||
| # Padded rows still enter fused-MoE routing/sort/dispatch before being | ||
| # sliced away after reduce-scatter; uninitialized NaN/Inf rows can perturb | ||
| # expert buckets or shared scratch and corrupt real tokens. | ||
| padded_x[original_batch_size:, :].zero_() | ||
| return padded_x, original_batch_size | ||
|
|
||
|
|
||
|
|
@@ -2375,6 +2381,17 @@ def __init__( | |
| ), | ||
| dim=0, | ||
| ) | ||
| # In the DP-attn fallback path (dp>1, no MORI all2all), MoE runs | ||
| # after all_gather_with_padding, so the token dim can be dp_size times | ||
| # the per-rank max. | ||
| moe_max_num_tokens = atom_config.max_num_batched_tokens | ||
| if ( | ||
| self.moe_parallel_config.dp_size > 1 | ||
| and not self.moe_parallel_config.use_all2all_kernels | ||
| and atom_config.enable_dp_attention | ||
| ): | ||
| moe_max_num_tokens *= self.moe_parallel_config.dp_size | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why we need moe_max_num_tokens *= self.moe_parallel_config.dp_size here.. In all_gahter and model runner, we have padded, * dp_size here will make BS large and kernel bad perf
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, we only increase the size of the preallocated internal buffer in FusedMoE, not the actual batch size used in the forward pass. This internal buffer needs to be large enough to accommodate tokens from all DP ranks, so we multiply by dp_size, similar to what we've already done for the all-gather / reduce-scatter buffers. |
||
|
|
||
| if fuse_shared_experts and self.num_fused_shared_experts > 0: | ||
| init_aiter_topK_meta_data( | ||
| n_routed_experts=self.global_num_experts, | ||
|
|
@@ -2387,7 +2404,7 @@ def __init__( | |
| if is_rocm_aiter_fuse_routed_scaling_factor() | ||
| else 1 / self.routed_scaling_factor | ||
| ), | ||
| max_num_tokens=atom_config.max_num_batched_tokens, | ||
| max_num_tokens=moe_max_num_tokens, | ||
| is_EP=self.use_ep, | ||
| ) | ||
| if fuse_shared_experts: | ||
|
|
@@ -2427,7 +2444,7 @@ def __init__( | |
| moe_parallel_config=self.moe_parallel_config, | ||
| in_dtype=atom_config.torch_dtype, | ||
| a_quant_dtype=a_quant_dtype, | ||
| max_num_tokens=atom_config.max_num_batched_tokens, | ||
| max_num_tokens=moe_max_num_tokens, | ||
| has_bias=self.has_bias, | ||
| # is_act_and_mul=True, | ||
| is_lora_enabled=False, | ||
|
|
@@ -3325,6 +3342,26 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): | |
| hidden_states, router_logits, self.layer_name | ||
| ) | ||
|
|
||
| def _tbo_keepalive_slot(self) -> int: | ||
| try: | ||
| from atom.utils.tbo.ubatching import tbo_current_ubatch_id | ||
|
|
||
| return tbo_current_ubatch_id() | ||
| except Exception: | ||
| return 0 | ||
|
|
||
| def _hold_tbo_keepalive(self, role: str, *tensors: torch.Tensor) -> None: | ||
| tensors = tuple(tensor for tensor in tensors if tensor is not None) | ||
| if tensors: | ||
| # Keep one previous tensor set per ubatch/role alive globally. | ||
| # The next same-role hold, often in the next MoE layer, happens | ||
| # after this ubatch has waited on the prior comm work, so | ||
| # overwriting here is the delayed safe release point. | ||
| key = (role, self._tbo_keepalive_slot()) | ||
| if key in _TBO_KEEPALIVE: | ||
| del _TBO_KEEPALIVE[key] | ||
| _TBO_KEEPALIVE[key] = tensors | ||
|
|
||
| def forward_impl_graph( | ||
| self, hidden_states: torch.Tensor, router_logits: torch.Tensor | ||
| ): | ||
|
|
@@ -3355,6 +3392,7 @@ def forward_impl_graph( | |
| ) | ||
|
|
||
| tbo_yield_and_switch_from_compute_to_comm() | ||
| self._hold_tbo_keepalive("ag_source", hidden_states, router_logits) | ||
|
|
||
| ( | ||
| hidden_states, | ||
|
|
@@ -3367,6 +3405,7 @@ def forward_impl_graph( | |
|
|
||
| if _tbo: | ||
| tbo_switch_to_compute_sync() | ||
| self._hold_tbo_keepalive("ag_output", hidden_states, router_logits) | ||
|
Comment on lines
3406
to
+3408
|
||
|
|
||
| # Matrix multiply. | ||
| final_hidden_states = self.quant_method.apply( | ||
|
|
@@ -3392,6 +3431,7 @@ def forward_impl_graph( | |
| if use_dp_gather_scatter: | ||
| if _tbo: | ||
| tbo_yield_and_switch_from_compute_to_comm() | ||
| self._hold_tbo_keepalive("rs_source", final_hidden_states) | ||
| if dp_eager_mode: | ||
| final_hidden_states = reduce_scatterv( | ||
| final_hidden_states, sizes, dp_group | ||
|
|
@@ -3402,6 +3442,7 @@ def forward_impl_graph( | |
| ) | ||
| if _tbo: | ||
| tbo_switch_to_compute_sync() | ||
| self._hold_tbo_keepalive("rs_output", final_hidden_states) | ||
|
Comment on lines
3443
to
+3445
|
||
|
|
||
| if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): | ||
| # Default set to False. (May have to add shared expert outputs.) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your changes seem to require >=2 bs for TBO to be ready; does this approach has performance improvement? Or whether it will affect old performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’ve made this behavior configurable via an environment variable (disabled by default to avoid affecting existing performance).
Below is a Kimi K2.5 performance comparison (conc=128, isl=8k, osl=1k) with
ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=0/2, about 25% throughput gain observed.ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=0ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=2There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure, that's good news. We will test
ATOM_PREFILL_DELAYER_REQUIRED_PREFILLS=2on deepseek v4 and other models if it's indeed effective