Skip to content
Draft

[WIP] #1411

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
)
from atom.utils import envs
from atom.utils.decorators import mark_trace
from atom.quantization.quark.utils import weight_dequant_fp8, weight_dequant_mxfp8
from atom.quantization.quark.utils import (
weight_dequant_fp8,
weight_dequant_mxfp8,
quantize_weight_to_fp8_128x128_blockscale,
)
from torch import nn

logger = logging.getLogger("atom")
Expand Down Expand Up @@ -469,6 +473,7 @@ def online_quantize_weight(self):
online_quant_func = get_hip_quant(online_quant_type)
assert online_quant_dtype in [
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float4_e2m1fn_x2,
], (
f"Unsupported online quant: "
Expand Down Expand Up @@ -519,22 +524,35 @@ def online_quantize_weight(self):
elif self.quant_type == QuantType.per_1x32:
# dequant MXFP8 (FP8 elements + 1x32 E8M0 shared scale)
weight = weight_dequant_mxfp8(weight, weight_scale)
q_weight, weight_scale = online_quant_func(
weight, quant_dtype=online_quant_dtype
)

if online_quant_type == QuantType.per_1x128:
# Linear per_1x128 path uses blockscale GEMM, which consumes
# 128x128 weight scales shaped as (N//128, K//128).
q_weight, weight_scale = quantize_weight_to_fp8_128x128_blockscale(
weight, online_quant_dtype
)
else:
q_weight, weight_scale = online_quant_func(
weight, quant_dtype=online_quant_dtype
)
if need_gather:
q_weight, weight_scale = self._shard_quantized_weight(
q_weight, weight_scale
)
self.weight = nn.Parameter(q_weight, requires_grad=False)
self.weight_scale = nn.Parameter(weight_scale, requires_grad=False)
self.weight.weight_loader_process = self.weight_loader_process
self.weight_scale.weight_loader_process = self.weight_loader_process

# Update quant state
self.quant_type = online_quant_type
self.params_dtype = online_quant_dtype
self.quant_func = online_quant_func
# online_quant_func already returns fnuz when quant_dtype=fnuz on gfx942;
# only normalize when output is still non-fnuz.
self.need_normalize_e4m3fn_to_e4m3fnuz = (
online_quant_dtype == torch.float8_e4m3fnuz
and q_weight.dtype != torch.float8_e4m3fnuz
)
self._online_quant_info = {
"layer": self.prefix,
Expand Down
15 changes: 10 additions & 5 deletions atom/quant_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,9 @@ def parse(self, online_quant_config: dict) -> ParsedQuantConfig:
if not isinstance(online_quant_config, dict):
raise TypeError("online_quant_config must be a dict parsed from JSON.")

SCHEME_MAP = {
scheme_map = {
"ptpc": QuantType.per_Token,
"per_block": QuantType.per_1x128,
}

def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig:
Expand All @@ -231,10 +232,14 @@ def _parse_online_quant_format(quant_format_str: str) -> LayerQuantConfig:
quant_type = QuantType.per_1x32
dtype_str = quant_format_str[2:]
else:
parts = quant_format_str.split("_", 1)
if len(parts) == 2 and parts[0] in SCHEME_MAP:
quant_type = SCHEME_MAP[parts[0]]
dtype_str = parts[1]
matched_scheme = None
for scheme in sorted(scheme_map, key=len, reverse=True):
if quant_format_str.startswith(scheme + "_"):
matched_scheme = scheme
break
if matched_scheme is not None:
quant_type = scheme_map[matched_scheme]
dtype_str = quant_format_str[len(matched_scheme) + 1 :]
else:
raise ValueError(
f"Unsupported online quant format: '{quant_format_str}'. "
Expand Down
38 changes: 38 additions & 0 deletions atom/quantization/quark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,41 @@ def weight_dequant_mxfp8(
y = x.to(torch.float32).reshape(M, n_blocks, block_size)
y = y * scale.unsqueeze(-1)
return y.reshape(M, K).to(out_dtype)


def quantize_weight_to_fp8_128x128_blockscale(weight, quant_dtype):
"""Quantize a 2D weight to FP8 with 128x128 block scales.

Returns:
q_weight: quantized weight with the same shape as input ``weight``.
scale: per-block scale with shape ``(ceil(N/128), ceil(K/128))``.
"""
assert weight.dim() == 2, f"expected 2D weight, got shape={tuple(weight.shape)}"

w = weight.to(torch.float32).contiguous()
n, k = w.shape
n_blocks = (n + 127) // 128
k_blocks = (k + 127) // 128
n_padded = n_blocks * 128
k_padded = k_blocks * 128

if n_padded != n or k_padded != k:
w = torch.nn.functional.pad(w, (0, k_padded - k, 0, n_padded - n))

w_blocks = w.view(n_blocks, 128, k_blocks, 128).permute(0, 2, 1, 3).contiguous()

finfo = torch.finfo(quant_dtype)
block_amax = w_blocks.abs().amax(dim=(2, 3))
scale = (block_amax / finfo.max).clamp_min(torch.finfo(torch.float32).tiny)

q_blocks = torch.clamp(
w_blocks / scale.unsqueeze(-1).unsqueeze(-1), min=finfo.min, max=finfo.max
).to(quant_dtype)

q_weight = (
q_blocks.permute(0, 2, 1, 3)
.contiguous()
.view(n_padded, k_padded)[:n, :k]
.contiguous()
)
return q_weight, scale.contiguous()
32 changes: 13 additions & 19 deletions atom/rollout/weight_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,27 +231,21 @@ def _requantize_fp8_weight(
quant_type = getattr(module, "quant_type", None)

if quant_type is not None and quant_type.value == _QT.per_1x128.value:
N, K = tensor_gpu.shape
block_k = 128
K_blocks = (K + block_k - 1) // block_k
if K % block_k != 0:
padded = torch.zeros(
N, K_blocks * block_k, dtype=torch.float32, device=self.device
)
padded[:, :K] = tensor_gpu
else:
padded = tensor_gpu
blocks = padded.reshape(N, K_blocks, block_k)
block_amax = blocks.abs().amax(dim=-1)
scale = (block_amax / fp8_max).clamp(min=1e-12)
quantized = (blocks / scale.unsqueeze(-1)).to(fp8_dtype)
quantized = quantized.reshape(N, K_blocks * block_k)[:, :K].contiguous()
param.data.copy_(quantized)
ws = weight_scale.data
weight_scale.data.copy_(
scale[: ws.shape[0], : ws.shape[1]].contiguous().to(ws.dtype)
# Must match the load-time online_quantize_weight layout: a true
# 128x128 block scale of shape (N//128, K//128). The previous code
# produced a 1x128-along-K scale (N, K//128) and sliced it into the
# (N//128, K//128) buffer, which is inconsistent with the blockscale
# GEMM and collapses generation after the first weight update.
from atom.quantization.quark.utils import (
quantize_weight_to_fp8_128x128_blockscale,
)

q_weight, scale = quantize_weight_to_fp8_128x128_blockscale(
tensor_gpu, fp8_dtype
)
param.data.copy_(q_weight)
weight_scale.data.copy_(scale.to(weight_scale.dtype))

elif quant_type is not None and quant_type.value == _QT.per_Tensor.value:
amax = tensor_gpu.abs().max()
scale = (amax / fp8_max).clamp(min=1e-12)
Expand Down