Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions tensorrt_llm/_torch/models/modeling_deepseekv4.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,18 @@ def _resolve_enable_fused_hc(config: PretrainedConfig) -> bool:
return bool(getattr(config, "enable_fused_hc", True))


def _resolve_skip_premoe_allreduce() -> bool:
"""Resolve whether to skip the redundant PRE_MOE_FUSION allreduce.

When enabled (default), the RMSNorm is folded into hc_ffn.fused_hc's
epilogue and the allreduce in forward_MoE is skipped — the data is
already full after attention's internal o_b_proj allreduce.
Set TRTLLM_DSV4_SKIP_PREMOE_ALLREDUCE=0 for ablation (baseline behavior).
"""
env = os.environ.get("TRTLLM_DSV4_SKIP_PREMOE_ALLREDUCE", "1")
return env not in ("0", "false", "False")


def _copy_deepseek_v4_fused_a_weight_scale(
module: Linear, fused_a: torch.Tensor, fused_a_scale: torch.Tensor
) -> None:
Expand Down Expand Up @@ -1860,6 +1872,11 @@ def __init__(
# layers already take). Env var TRTLLM_MHC_ENABLE_FUSED_HC overrides the
# config attr (set to "0" to force-disable for validation/rollback).
self.enable_fused_hc = _resolve_enable_fused_hc(config)
# Skip-premoe-allreduce: when fused_hc is active AND PRE_MOE_FUSION is
# on, fold post_attention_layernorm into hc_ffn.fused_hc and skip the
# redundant allreduce in forward_MoE (the data is already full after
# attention's internal o_b_proj allreduce).
self.skip_premoe_allreduce = _resolve_skip_premoe_allreduce() and self.enable_fused_hc
self.next_layer_layernorm: RMSNorm = None
# Finalized in DeepseekV4ForCausalLM.post_load_weights once the full layer
# list is visible: a layer may defer its hc_ffn.post_mapping only if
Expand Down Expand Up @@ -2000,11 +2017,24 @@ def forward(
if spec_metadata is not None and spec_metadata.is_layer_capture(self.layer_idx):
self.fusion_config.POST_MOE_FUSION = False
if self.enable_fused_hc:
# When skip_premoe_allreduce is active, fold post_attention_layernorm
# into fused_hc so layer_input emerges already RMSNorm-normalized.
# This lets us skip the redundant allreduce+norm in forward_MoE.
_norm_w = (
self.post_attention_layernorm.weight
if self.skip_premoe_allreduce and self.fusion_config.PRE_MOE_FUSION
else None
)
_norm_eps = (
self.post_attention_layernorm.variance_epsilon if _norm_w is not None else 0.0
)
residual, post_mix, comb_mix, layer_input = self.hc_ffn.fused_hc(
x_prev=x_attn,
residual_prev=residual,
post_mix_prev=post_mix,
comb_mix_prev=comb_mix,
norm_weight=_norm_w,
norm_eps=_norm_eps,
)
else:
# Break fused_hc into post_mapping and pre_mapping as separate ops.
Expand Down Expand Up @@ -2104,18 +2134,24 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize, input_ids):
)

if self.fusion_config.PRE_MOE_FUSION:
# In DeepSeek-V4 the external residual connection is handled by mHC
# (hc_ffn.post_mapping), so there is no residual to add here.
# Use fused allreduce + RMSNorm (no residual addition).
hidden_states = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RMS_NORM,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
trigger_completion_at_end=False,
),
)
if self.skip_premoe_allreduce:
# Optimization: RMSNorm was already folded into hc_ffn.fused_hc's
# epilogue, and the data is already full (attention's o_b_proj
# allreduced it). Skip the redundant allreduce+norm entirely.
pass
else:
# Baseline: fused allreduce + RMSNorm (no residual addition).
# In DeepSeek-V4 the external residual connection is handled by
# mHC (hc_ffn.post_mapping), so there is no residual to add here.
hidden_states = self.allreduce(
hidden_states,
all_reduce_params=AllReduceParams(
fusion_op=AllReduceFusionOp.RMS_NORM,
norm_weight=self.post_attention_layernorm.weight,
eps=self.post_attention_layernorm.variance_epsilon,
trigger_completion_at_end=False,
),
)
else:
# No fusion: just normalize.
hidden_states = self.post_attention_layernorm(hidden_states)
Expand Down
Loading