Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions atom/model_ops/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
)
from aiter.dist.communication_op import (
tensor_model_parallel_fused_allreduce_rmsnorm,
tensor_model_parallel_fused_allreduce_rmsnorm_quant,
tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group,
)
from aiter.dist.parallel_state import get_tensor_model_parallel_world_size
from aiter.jit.utils.torch_guard import torch_compile_guard
Expand Down Expand Up @@ -330,6 +332,43 @@ def forward(
x, self.weight, self.eps, residual, self.x_pad_to_multiple
)
return x, residual
if (
self.fused_allreduce
and self.tp_size > 1
and self.use_fused_quant
and residual is not None
):
# Combined AllReduce + RMSNorm + FP8 quant: the downstream GEMM
# (e.g. qkv_proj) consumes the (fp8, scale) output directly, skipping
# a separate quant kernel. `_aiter_transpose_scale` (resolved at
# init) selects the scale layout matching that GEMM (column-major for
# the preshuffle path, row-major otherwise). The fused kernel does
# not support non-contiguous input.
assert self.quant_type.value in (
_QV_PER_1X128,
_QV_PER_TOKEN,
), "Unsupported quant type for fused allreduce rmsnorm quant"
if self.quant_type.value == _QV_PER_1X128:
x, residual, x_scale = (
tensor_model_parallel_fused_allreduce_rmsnorm_quant_per_group(
x.contiguous(),
residual,
self.weight,
self.eps,
group_size=128,
transpose_scale=self._aiter_transpose_scale,
)
)
else: # _QV_PER_TOKEN
x, residual, x_scale = (
tensor_model_parallel_fused_allreduce_rmsnorm_quant(
x.contiguous(),
residual,
self.weight,
self.eps,
)
)
return (x, x_scale), residual
if self.fused_allreduce and self.tp_size > 1:
assert (
residual is not None
Expand Down
25 changes: 22 additions & 3 deletions atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2115,12 +2115,31 @@ def __init__(
reduce_results=not self.fuse_ar_input_norm,
prefix=f"{prefix}.mlp",
)
# Fuse activation quant into AR+RMSNorm when fused_qkv_a_proj is
# per-1x128/per-token FP8, so the kernel emits the (fp8, scale) it
# consumes directly. Independent of the non-AR fuse_input_norm_quant path.
qkv_proj = getattr(self.self_attn, "fused_qkv_a_proj", None)
input_norm_fused_quant = (
not self.self_attn.is_v32
and qkv_proj is not None
and qkv_proj.params_dtype == dtypes.fp8
and qkv_proj.quant_type.value
in (QuantType.per_1x128.value, QuantType.per_Token.value)
)
fused_allreduce = (
self.fuse_ar_input_norm and self.layer_idx > 0 and not is_mtp_block
)
self.input_layernorm = RMSNorm(
config.hidden_size,
eps=config.rms_norm_eps,
fused_allreduce=self.fuse_ar_input_norm
and self.layer_idx > 0
and not is_mtp_block,
fused_allreduce=fused_allreduce,
fused_quant=fused_allreduce and input_norm_fused_quant,
quant_config=quant_config,
prefix=(
f"{prefix}.self_attn.fused_qkv_a_proj"
if qkv_proj is not None
else f"{prefix}.self_attn.q_a_proj"
),
)
self.post_attention_layernorm = RMSNorm(
config.hidden_size,
Expand Down
Loading