diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py
index 8c7180d74b..1df9a532d4 100644
--- a/atom/model_ops/linear.py
+++ b/atom/model_ops/linear.py
@@ -32,7 +32,11 @@
 )
 from atom.utils import envs
 from atom.utils.decorators import mark_trace
-from atom.quantization.quark.utils import weight_dequant_fp8, weight_dequant_mxfp8
+from atom.quantization.quark.utils import (
+    weight_dequant_fp8,
+    weight_dequant_mxfp8,
+    quantize_weight_to_fp8_128x128_blockscale,
+)
 from torch import nn
 
 logger = logging.getLogger("atom")
@@ -469,6 +473,7 @@ def online_quantize_weight(self):
         online_quant_func = get_hip_quant(online_quant_type)
         assert online_quant_dtype in [
             torch.float8_e4m3fn,
+            torch.float8_e4m3fnuz,
             torch.float4_e2m1fn_x2,
         ], (
             f"Unsupported online quant: "
@@ -519,22 +524,35 @@ def online_quantize_weight(self):
         elif self.quant_type == QuantType.per_1x32:
             # dequant MXFP8 (FP8 elements + 1x32 E8M0 shared scale)
             weight = weight_dequant_mxfp8(weight, weight_scale)
-        q_weight, weight_scale = online_quant_func(
-            weight, quant_dtype=online_quant_dtype
-        )
+
+        if online_quant_type == QuantType.per_1x128:
+            # Linear per_1x128 path uses blockscale GEMM, which consumes
+            # 128x128 weight scales shaped as (N//128, K//128).
+            q_weight, weight_scale = quantize_weight_to_fp8_128x128_blockscale(
+                weight, online_quant_dtype
+            )
+        else:
+            q_weight, weight_scale = online_quant_func(
+                weight, quant_dtype=online_quant_dtype
+            )
         if need_gather:
             q_weight, weight_scale = self._shard_quantized_weight(
                 q_weight, weight_scale
             )
         self.weight = nn.Parameter(q_weight, requires_grad=False)
         self.weight_scale = nn.Parameter(weight_scale, requires_grad=False)
+        self.weight.weight_loader_process = self.weight_loader_process
+        self.weight_scale.weight_loader_process = self.weight_loader_process
 
         # Update quant state
         self.quant_type = online_quant_type
         self.params_dtype = online_quant_dtype
         self.quant_func = online_quant_func
+        # online_quant_func already returns fnuz when quant_dtype=fnuz on gfx942;
+        # only normalize when output is still non-fnuz.
         self.need_normalize_e4m3fn_to_e4m3fnuz = (
             online_quant_dtype == torch.float8_e4m3fnuz
+            and q_weight.dtype != torch.float8_e4m3fnuz
         )
         self._online_quant_info = {
             "layer": self.prefix,
diff --git a/atom/quant_spec.py b/atom/quant_spec.py
index e1d1389e25..f9a7fadb89 100644
--- a/atom/quant_spec.py
+++ b/atom/quant_spec.py
@@ -218,8 +218,9 @@ def parse(self, online_quant_config: dict) -> ParsedQuantConfig:
         if not isinstance(online_quant_config, dict):
             raise TypeError("online_quant_config must be a dict parsed from JSON.")
 
-        SCHEME_MAP = {
+        scheme_map = {
             "ptpc": QuantType.per_Token,
+            "per_block": QuantType.per_1x128,
         }
 
         def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig:
@@ -231,10 +232,14 @@ def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig:
                 quant_type = QuantType.per_1x32
                 dtype_str = quant_format_str[2:]
             else:
-                parts = quant_format_str.split("_", 1)
-                if len(parts) == 2 and parts[0] in SCHEME_MAP:
-                    quant_type = SCHEME_MAP[parts[0]]
-                    dtype_str = parts[1]
+                matched_scheme = None
+                for scheme in sorted(scheme_map, key=len, reverse=True):
+                    if quant_format_str.startswith(scheme + "_"):
+                        matched_scheme = scheme
+                        break
+                if matched_scheme is not None:
+                    quant_type = scheme_map[matched_scheme]
+                    dtype_str = quant_format_str[len(matched_scheme) + 1 :]
                 else:
                     raise ValueError(
                         f"Unsupported online quant format: '{quant_format_str}'. "
diff --git a/atom/quantization/quark/utils.py b/atom/quantization/quark/utils.py
index 3312fce74c..a0e58bb119 100644
--- a/atom/quantization/quark/utils.py
+++ b/atom/quantization/quark/utils.py
@@ -124,3 +124,57 @@ def weight_dequant_mxfp8(
     y = x.to(torch.float32).reshape(M, n_blocks, block_size)
     y = y * scale.unsqueeze(-1)
     return y.reshape(M, K).to(out_dtype)
+
+
+def quantize_weight_to_fp8_128x128_blockscale(
+    weight: torch.Tensor, quant_dtype: torch.dtype
+):
+    """Quantize a 2D weight to FP8 with 128x128 block scales.
+
+    The weight is zero-padded up to a multiple of 128 in both dimensions,
+    each 128x128 block is divided by ``amax / finfo(quant_dtype).max``, and
+    the padding is cropped from the quantized result.
+
+    Args:
+        weight: 2D tensor of shape ``(N, K)``.
+        quant_dtype: target floating-point dtype; it is passed to
+            ``torch.finfo`` to obtain the representable range.
+
+    Returns:
+        q_weight: quantized weight with the same shape as input ``weight``.
+        scale: float32 per-block scale with shape
+            ``(ceil(N/128), ceil(K/128))``.
+
+    Raises:
+        ValueError: if ``weight`` is not 2D.
+    """
+    if weight.dim() != 2:
+        raise ValueError(f"expected 2D weight, got shape={tuple(weight.shape)}")
+
+    block = 128
+    w = weight.to(torch.float32).contiguous()
+    n, k = w.shape
+    n_blocks = (n + block - 1) // block
+    k_blocks = (k + block - 1) // block
+    n_padded = n_blocks * block
+    k_padded = k_blocks * block
+
+    if n_padded != n or k_padded != k:
+        # Zero padding never raises a block's amax, so scales are unaffected.
+        w = torch.nn.functional.pad(w, (0, k_padded - k, 0, n_padded - n))
+
+    # Reduce over the two in-block axes of a (n_blocks, 128, k_blocks, 128)
+    # view directly, so no permuted copy of the weight is materialized.
+    w_blocks = w.view(n_blocks, block, k_blocks, block)
+
+    finfo = torch.finfo(quant_dtype)
+    block_amax = w_blocks.abs().amax(dim=(1, 3))
+    # Clamp so an all-zero block does not yield a zero scale (divide by zero).
+    scale = (block_amax / finfo.max).clamp_min(torch.finfo(torch.float32).tiny)
+
+    q_blocks = torch.clamp(
+        w_blocks / scale[:, None, :, None], min=finfo.min, max=finfo.max
+    ).to(quant_dtype)
+
+    q_weight = q_blocks.view(n_padded, k_padded)[:n, :k].contiguous()
+    return q_weight, scale.contiguous()
diff --git a/atom/rollout/weight_updater.py b/atom/rollout/weight_updater.py
index 8fec5e10cf..754c60dcfd 100644
--- a/atom/rollout/weight_updater.py
+++ b/atom/rollout/weight_updater.py
@@ -231,27 +231,21 @@ def _requantize_fp8_weight(
         quant_type = getattr(module, "quant_type", None)
 
         if quant_type is not None and quant_type.value == _QT.per_1x128.value:
-            N, K = tensor_gpu.shape
-            block_k = 128
-            K_blocks = (K + block_k - 1) // block_k
-            if K % block_k != 0:
-                padded = torch.zeros(
-                    N, K_blocks * block_k, dtype=torch.float32, device=self.device
-                )
-                padded[:, :K] = tensor_gpu
-            else:
-                padded = tensor_gpu
-            blocks = padded.reshape(N, K_blocks, block_k)
-            block_amax = blocks.abs().amax(dim=-1)
-            scale = (block_amax / fp8_max).clamp(min=1e-12)
-            quantized = (blocks / scale.unsqueeze(-1)).to(fp8_dtype)
-            quantized = quantized.reshape(N, K_blocks * block_k)[:, :K].contiguous()
-            param.data.copy_(quantized)
-            ws = weight_scale.data
-            weight_scale.data.copy_(
-                scale[: ws.shape[0], : ws.shape[1]].contiguous().to(ws.dtype)
+            # Must match the load-time online_quantize_weight layout: a true
+            # 128x128 block scale of shape (N//128, K//128). The previous code
+            # produced a 1x128-along-K scale (N, K//128) and sliced it into the
+            # (N//128, K//128) buffer, which is inconsistent with the blockscale
+            # GEMM and collapses generation after the first weight update.
+            from atom.quantization.quark.utils import (
+                quantize_weight_to_fp8_128x128_blockscale,
             )
 
+            q_weight, scale = quantize_weight_to_fp8_128x128_blockscale(
+                tensor_gpu, fp8_dtype
+            )
+            param.data.copy_(q_weight)
+            weight_scale.data.copy_(scale.to(weight_scale.dtype))
+
         elif quant_type is not None and quant_type.value == _QT.per_Tensor.value:
             amax = tensor_gpu.abs().max()
             scale = (amax / fp8_max).clamp(min=1e-12)