From 2b092ba3b249a69d73048e510fd1cd832b209e3e Mon Sep 17 00:00:00 2001 From: ganyi Date: Sun, 28 Jun 2026 16:48:52 +0000 Subject: [PATCH 1/2] perf(deepseek): fuse AllReduce + RMSNorm + FP8 quant on input_layernorm (#1226) The combined HIP kernel emits the quantized (fp8, scale) activation that the downstream qkv GEMM consumes directly, removing a standalone per-token/per-group quant kernel from the hot path. - layernorm.py: RMSNorm.forward gains a fused AR+RMS+quant branch dispatching to `..._rmsnorm_quant_per_group` (per_1x128) or `..._rmsnorm_quant` (per_Token), returning ((fp8, scale), residual). - deepseek_v2.py: enable `fused_quant` on input_layernorm when AR fusion is on and fused_qkv_a_proj is per_1x128/per_Token FP8. Mutually exclusive with the existing non-AR `fuse_input_norm_quant` path; the attention forward already unpacks the (fp8, scale) tuple. Co-Authored-By: Claude Opus 4.8 --- atom/model_ops/layernorm.py | 39 +++++++++++++++++++++++++++++++++++++ atom/models/deepseek_v2.py | 24 ++++++++++++++++++++--- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/layernorm.py b/atom/model_ops/layernorm.py index 8460ad114d..9d159d10ee 100644 --- a/atom/model_ops/layernorm.py +++ b/atom/model_ops/layernorm.py @@ -14,6 +14,8 @@ ) from aiter.dist.communication_op import ( tensor_model_parallel_fused_allreduce_rmsnorm, + tensor_model_parallel_fused_allreduce_rmsnorm_quant, + tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group, ) from aiter.dist.parallel_state import get_tensor_model_parallel_world_size from aiter.jit.utils.torch_guard import torch_compile_guard @@ -330,6 +332,43 @@ def forward( x, self.weight, self.eps, residual, self.x_pad_to_multiple ) return x, residual + if ( + self.fused_allreduce + and self.tp_size > 1 + and self.use_fused_quant + and residual is not None + ): + # Combined AllReduce + RMSNorm + FP8 quant: the downstream GEMM + # (e.g. qkv_proj) consumes the (fp8, scale) output directly, skipping + # a separate quant kernel. `_aiter_transpose_scale` (resolved at + # init) selects the scale layout matching that GEMM (column-major for + # the preshuffle path, row-major otherwise). The fused kernel does + # not support non-contiguous input. + assert self.quant_type.value in ( + _QV_PER_1X128, + _QV_PER_TOKEN, + ), "Unsupported quant type for fused allreduce rmsnorm quant" + if self.quant_type.value == _QV_PER_1X128: + x, residual, x_scale = ( + tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group( + x.contiguous(), + residual, + self.weight, + self.eps, + group_size=128, + transpose_scale=self._aiter_transpose_scale, + ) + ) + else: # _QV_PER_TOKEN + x, residual, x_scale = ( + tensor_model_parallel_fused_allreduce_rmsnorm_quant( + x.contiguous(), + residual, + self.weight, + self.eps, + ) + ) + return (x, x_scale), residual if self.fused_allreduce and self.tp_size > 1: assert ( residual is not None diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index ff47a5fcc1..943dbe7c7d 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -2115,12 +2115,30 @@ def __init__( reduce_results=not self.fuse_ar_input_norm, prefix=f"{prefix}.mlp", ) + # Fuse activation quant into AR+RMSNorm when fused_qkv_a_proj is + # per-1x128/per-token FP8, so the kernel emits the (fp8, scale) it + # consumes directly. Independent of the non-AR fuse_input_norm_quant path. + qkv_proj = getattr(self.self_attn, "fused_qkv_a_proj", None) + input_norm_fused_quant = ( + qkv_proj is not None + and qkv_proj.params_dtype == dtypes.fp8 + and qkv_proj.quant_type.value + in (QuantType.per_1x128.value, QuantType.per_Token.value) + ) + fused_allreduce = ( + self.fuse_ar_input_norm and self.layer_idx > 0 and not is_mtp_block + ) self.input_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps, - fused_allreduce=self.fuse_ar_input_norm - and self.layer_idx > 0 - and not is_mtp_block, + fused_allreduce=fused_allreduce, + fused_quant=fused_allreduce and input_norm_fused_quant, + quant_config=quant_config, + prefix=( + f"{prefix}.self_attn.fused_qkv_a_proj" + if qkv_proj is not None + else f"{prefix}.self_attn.q_a_proj" + ), ) self.post_attention_layernorm = RMSNorm( config.hidden_size, From f12f32810b976a35faf094b3514bffc62c4adc0a Mon Sep 17 00:00:00 2001 From: ganyi Date: Tue, 30 Jun 2026 01:09:45 -0500 Subject: [PATCH 2/2] fix GLM-5.1 --- atom/models/deepseek_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/atom/models/deepseek_v2.py b/atom/models/deepseek_v2.py index 943dbe7c7d..e535145aa0 100644 --- a/atom/models/deepseek_v2.py +++ b/atom/models/deepseek_v2.py @@ -2120,7 +2120,8 @@ def __init__( # consumes directly. Independent of the non-AR fuse_input_norm_quant path. qkv_proj = getattr(self.self_attn, "fused_qkv_a_proj", None) input_norm_fused_quant = ( - qkv_proj is not None + not self.self_attn.is_v32 + and qkv_proj is not None and qkv_proj.params_dtype == dtypes.fp8 and qkv_proj.quant_type.value in (QuantType.per_1x128.value, QuantType.per_Token.value)