diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 0367630d0..7d605af03 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -331,6 +331,34 @@ def forward( x, self.weight, self.eps, residual, self.x_pad_to_multiple ) return x, residual + if ( + self.fused_allreduce + and self.tp_size > 1 + and self.use_fused_quant + and residual is not None + ): + # Combined AllReduce + RMSNorm + FP8 quant: the downstream GEMM + # (e.g. qkv_proj) consumes the (fp8, scale) output directly, skipping + # a separate quant kernel. `_aiter_transpose_scale` (resolved at + # init) selects the scale layout matching that GEMM (column-major for + # the preshuffle path, row-major otherwise). The fused kernel does + # not support non-contiguous input. + assert self.quant_type.value in ( + _QV_PER_1X128, + _QV_PER_TOKEN, + ), "Unsupported quant type for fused allreduce rmsnorm quant" + # Unified kernel: `quant_type` selects the epilogue (per-token vs + # per-group); group_size / transpose_scale apply to per-group only. + x, residual, x_scale = tensor_model_parallel_fused_allreduce_rmsnorm_quant( + x.contiguous(), + residual, + self.weight, + self.eps, + quant_type=self.quant_type, + group_size=128, + transpose_scale=self._aiter_transpose_scale, + ) + return (x, x_scale), residual if self.fused_allreduce and self.tp_size > 1: assert ( residual is not None diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index bd72e39cb..9bb99e8e6 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -2141,12 +2141,31 @@ def __init__( reduce_results=not self.fuse_ar_input_norm, prefix=f"{prefix}.mlp", ) + # Fuse activation quant into AR+RMSNorm when fused_qkv_a_proj is + # per-1x128/per-token FP8, so the kernel emits the (fp8, scale) it + # consumes directly. Independent of the non-AR fuse_input_norm_quant path. + qkv_proj = getattr(self.self_attn, "fused_qkv_a_proj", None) + input_norm_fused_quant = ( + not self.self_attn.is_v32 + and qkv_proj is not None + and qkv_proj.params_dtype == dtypes.fp8 + and qkv_proj.quant_type.value + in (QuantType.per_1x128.value, QuantType.per_Token.value) + ) + fused_allreduce = ( + self.fuse_ar_input_norm and self.layer_idx > 0 and not is_mtp_block + ) self.input_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, - fused_allreduce=self.fuse_ar_input_norm - and self.layer_idx > 0 - and not is_mtp_block, + fused_allreduce=fused_allreduce, + fused_quant=fused_allreduce and input_norm_fused_quant, + quant_config=quant_config, + prefix=( + f"{prefix}.self_attn.fused_qkv_a_proj" + if qkv_proj is not None + else f"{prefix}.self_attn.q_a_proj" + ), ) self.post_attention_layernorm = RMSNorm( config.hidden_size,