
Commit d256ce2 (1 parent: e1d89e7)

Improve QAT nvfp4 numerics

**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to 36 with `use_per_tensor_scale`, and from 12 to inf without. This is achieved by mimicking the PTQ flow more closely, specifically (in descending order of significance):

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`, but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to the original dtype (e.g. bf16), which loses some fidelity relative to fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.

ghstack-source-id: d8f7eff
Pull Request resolved: #3050
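For context on the SQNR numbers above: the tests compare outputs with `compute_error`, which reports SQNR in dB (higher is better; inf means the QAT-prepared and PTQ outputs match exactly). A minimal standalone equivalent, with a hypothetical helper name `sqnr_db`, is sketched below; it is not the torchao implementation.

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    """Signal-to-quantization-noise ratio in dB (hypothetical stand-in
    for torchao's compute_error). Returns inf on an exact match."""
    noise = reference - candidate
    if noise.abs().max() == 0:
        return float("inf")
    return (20 * torch.log10(reference.norm() / noise.norm())).item()
```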

File tree: 4 files changed, +57 −13 lines


test/quantization/test_qat.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
         quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
         out_prepared = m(*example_inputs)
         prepare_sqnr = compute_error(out_prepared, out_baseline)
-
         self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)
 
         # compare convert
@@ -2086,9 +2085,14 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
         """
         from torchao.prototype.mx_formats import NVFP4InferenceConfig
 
+        if use_per_tensor_scale:
+            target_prepare_sqnr = 36
+        else:
+            target_prepare_sqnr = float("inf")
+
         self._test_quantize_api_against_ptq(
             NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
-            target_prepare_sqnr=12,
+            target_prepare_sqnr=target_prepare_sqnr,
             target_convert_sqnr=float("inf"),
         )
 
@@ -2098,11 +2102,16 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         """
         Test QAT with `NVFP4FakeQuantizeConfig`.
         """
+        from torchao.prototype.mx_formats import NVFP4InferenceConfig
         from torchao.prototype.qat import NVFP4FakeQuantizeConfig
 
         torch.manual_seed(self.SEED)
         m = M().cuda()
         baseline_model = copy.deepcopy(m)
+        quantize_(
+            baseline_model,
+            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+        )
         qat_config = QATConfig(
             activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
             weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2125,11 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
         out = m(*x)
         baseline_out = baseline_model(*x)
         sqnr = compute_error(out, baseline_out).item()
-        self.assertGreater(sqnr, 24)
+        if use_per_tensor_scale:
+            target_sqnr = 130
+        else:
+            target_sqnr = float("inf")
+        self.assertGreaterEqual(sqnr, target_sqnr)
 
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     @unittest.skipIf(
```
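Outside the test harness, the prepare-vs-PTQ comparison exercised above can be reproduced roughly as in the sketch below. The import paths for `quantize_` and `QATConfig`, and the `step="prepare"` keyword, are assumptions based on the helper shown in the first hunk; the NVFP4 config imports match the diff.

```python
import copy
import torch
from torchao.quantization import quantize_          # import path assumed
from torchao.quantization.qat import QATConfig      # import path assumed
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
baseline = copy.deepcopy(model)

# PTQ baseline: real NVFP4 quantization.
quantize_(baseline, NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True))

# QAT "prepare": fake quantization that should now track the PTQ numerics.
quantize_(
    model,
    QATConfig(
        activation_config=NVFP4FakeQuantizeConfig(True),
        weight_config=NVFP4FakeQuantizeConfig(True),
        step="prepare",  # assumed, mirroring QATConfig(base_config, step="prepare") above
    ),
)

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)
print(sqnr_db(baseline(x).float(), model(x).float()))  # sqnr_db from the sketch near the top
```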

torchao/prototype/custom_fp_utils.py

Lines changed: 13 additions & 7 deletions
```diff
@@ -24,7 +24,12 @@ def _n_ones(n: int) -> int:
 F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
 
 
-def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+def _f32_to_floatx_unpacked(
+    x: Tensor,
+    ebits: int,
+    mbits: int,
+    compute_dtype: torch.dtype = torch.uint8,
+) -> Tensor:
     """Convert FP32 numbers to sub-byte floating point numbers with the given
     number of exponent and mantissa bits.
 
@@ -44,6 +49,7 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
     """
     assert x.dtype == torch.float
+    assert compute_dtype in [torch.uint8, torch.int32]
     assert 1 + ebits + mbits <= 8
 
     # calculate constants
@@ -105,7 +111,7 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     denormal_x = x + denorm_mask_float
     denormal_x = denormal_x.view(torch.int32)
     denormal_x -= denorm_mask_int
-    denormal_x = denormal_x.to(torch.uint8)
+    denormal_x = denormal_x.to(compute_dtype)
 
     #
     # branch 3: stay in normal range, adjust the exponent and round
@@ -120,26 +126,26 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     normal_x += mant_odd
     # take the bits!
     normal_x = normal_x >> (MBITS_F32 - mbits)
-    normal_x = normal_x.to(torch.uint8)
+    normal_x = normal_x.to(compute_dtype)
 
     #
     # combine the branches
     #
-    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.full_like(x, max_int, dtype=compute_dtype)
     x = torch.where(denormal_mask, denormal_x, x)
     x = torch.where(normal_mask, normal_x, x)
 
     # add sign back
     sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
-    sign_lp = sign_lp.to(torch.uint8)
+    sign_lp = sign_lp.to(compute_dtype)
     # Right shift of a negative signed integer can fill the least significant
     # bits with either 1s or 0s, depending on the implementation. Since PyTorch
     # doesn't have an uint32 dtype, we mask out these bits to get just the
     # f4 sign bit
     sign_lp = sign_lp & sign_mask
     x = x | sign_lp
 
-    return x.to(torch.uint8)
+    return x.to(compute_dtype)
 
 
 # TODO(future): check if LUT for everything is faster than bit shifting,
@@ -154,7 +160,7 @@ def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
       fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    assert x.dtype == torch.uint8
+    assert x.dtype in [torch.uint8, torch.int32]
     assert 1 + ebits + mbits <= 8
 
     sign_mask = 1 << (ebits + mbits)
```
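The new `compute_dtype` argument only changes the integer container that holds the unpacked sub-byte bits; the encoded values are identical for `torch.uint8` and `torch.int32`. A small hedged round-trip check using the fp4 e2m1 constants imported in the QAT diff further down (all sample values are exactly representable in e2m1):

```python
import torch
from torchao.prototype.custom_fp_utils import (
    _f32_to_floatx_unpacked,
    _floatx_unpacked_to_f32,
)
from torchao.prototype.mx_formats.kernels import EBITS_F4_E2M1, MBITS_F4_E2M1

x = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -6.0])
q_u8 = _f32_to_floatx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1)  # default uint8 path
q_i32 = _f32_to_floatx_unpacked(
    x, EBITS_F4_E2M1, MBITS_F4_E2M1, compute_dtype=torch.int32
)
assert torch.equal(q_u8.to(torch.int32), q_i32)  # same bits, different container
dq = _floatx_unpacked_to_f32(q_i32, EBITS_F4_E2M1, MBITS_F4_E2M1)
assert torch.equal(dq, x)  # exact round-trip for e2m1-representable values
```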

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -798,7 +798,6 @@ def _nvfp4_quantize(
     assert data_hp.is_contiguous(), "Only support contiguous data for now"
     assert block_size == 16, "NVFP4 requires block_size=16"
 
-    orig_dtype = data_hp.dtype
     orig_shape = data_hp.shape
     # Convert to float32 early for consistent precision with Triton implementation
     data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -834,7 +833,7 @@ def _nvfp4_quantize(
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
     if skip_dtype_cast_and_packing:
-        return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
+        return _Float8Round.apply(out_scales), data_scaled
     else:
         data_lp = f32_to_f4_unpacked(data_scaled)
         # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
```
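`_Float8Round` itself is defined elsewhere in `nvfp4_tensor.py` and is not part of this diff. Conceptually it fake-rounds the fp32 blockwise scales through float8 e4m3 with a straight-through gradient; a hedged sketch of that shape (not the actual implementation) looks like:

```python
import torch

class _Float8RoundSketch(torch.autograd.Function):
    """Hypothetical stand-in for _Float8Round: round fp32 scales through
    float8_e4m3fn and back to fp32, passing gradients straight through."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x.to(torch.float8_e4m3fn).to(torch.float32)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return grad_output
```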

torchao/prototype/qat/nvfp4.py

Lines changed: 27 additions & 1 deletion
```diff
@@ -2,6 +2,14 @@
 
 import torch
 
+from torchao.prototype.custom_fp_utils import (
+    _f32_to_floatx_unpacked,
+    _floatx_unpacked_to_f32,
+)
+from torchao.prototype.mx_formats.kernels import (
+    EBITS_F4_E2M1,
+    MBITS_F4_E2M1,
+)
 from torchao.prototype.mx_formats.nvfp4_tensor import (
     _nvfp4_quantize,
     per_tensor_amax_to_scale,
@@ -12,6 +20,24 @@
 )
 
 
+class _FP4Round(torch.autograd.Function):
+    """
+    Cast an fp32 tensor to fp4 and back with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        q = _f32_to_floatx_unpacked(
+            x, EBITS_F4_E2M1, MBITS_F4_E2M1, compute_dtype=torch.int32
+        )
+        dq = _floatx_unpacked_to_f32(q, EBITS_F4_E2M1, MBITS_F4_E2M1)
+        return dq
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
 @dataclass
 class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     """
@@ -56,9 +82,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             per_tensor_scale=per_tensor_scale,
             skip_dtype_cast_and_packing=True,
         )
+        q = _FP4Round.apply(q)
         if self.config.use_per_tensor_scale:
             scale = scale * per_tensor_scale
-        assert q.dtype == x.dtype
         assert scale.dtype == torch.float32
 
         # dequantize
```
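A quick hedged sanity check of the `_FP4Round` straight-through behavior added above, assuming the class and imports from this diff are in scope (the sample values are all exactly representable in fp4 e2m1):

```python
import torch

x = torch.tensor([0.0, 0.5, 1.5, 3.0, 6.0, -4.0], requires_grad=True)
y = _FP4Round.apply(x)
assert y.dtype == torch.float32 and torch.equal(y, x.detach())  # fp32 in, fp32 out
y.sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))  # STE: gradient passes through unchanged
```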
