
Commit 7f06046

Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic. **Unit tests:** ``` python test/quantization/test_qat.py -k test_qat_nvfp4 python test/quantization/test_qat.py -k test_quantize_api_nvfp4 ``` **End-to-end tests:** Fine-tuning Llama3.2-3B with and without this PR in axolotl: - fine-tune for 1 epoch on yahma/alpaca-cleaned - batch size 512, learning rate 2e-5, no gradient accumulation Wikitext: - With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline - Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline ``` ==> Llama3.2-3B_baseline_bs512/eval_float.log <== | | |none | 0|word_perplexity|↓ |9.418|± | N/A| ==> Llama3.2-3B_baseline_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.3681|± | N/A| # QAT with this PR (quantized) ==> Llama3.2-3B_qat_bs512/eval_quantized.log <== | | |none | 0|word_perplexity|↓ |10.2281|± | N/A| ``` [ghstack-poisoned]
2 parents 90d6af0 + ef3682b commit 7f06046
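For context, the "prepare vs convert SQNR" metric above compares the outputs of the fake-quantized (prepare) model against the truly quantized (convert) model; an SQNR of inf means the two match bit-exactly. A minimal sketch of how such an SQNR is typically computed (`sqnr_db` is an illustrative helper written for this note, not necessarily the exact utility torchao uses):

```python
import torch

def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    """Signal-to-quantization-noise ratio in dB; inf when the tensors match exactly."""
    noise_power = (reference - candidate).pow(2).sum()
    if noise_power == 0:
        return float("inf")
    signal_power = reference.pow(2).sum()
    return (10 * torch.log10(signal_power / noise_power)).item()

x = torch.randn(1024)
print(sqnr_db(x, x))             # exact numerical match -> inf
print(sqnr_db(x, x + 0.01 * x))  # 1% relative error -> ~40 dB
```

Under this metric, "12 to inf" means the prepare-path outputs went from a roughly 12 dB approximation of the convert path to matching it exactly.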

File tree

1 file changed: +7 −4 lines

torchao/prototype/qat/nvfp4.py

Lines changed: 7 additions & 4 deletions
```diff
@@ -31,9 +31,10 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
     use_triton_kernel: bool = False
 
 
-class _NVFP4FakeQuantizedLinearForward(torch.autograd.Function):
+class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
     """
-    Autograd function for NVFP4 fake quantization + addmm.
+    Autograd function for NVFP4 quantization + addmm in low precision during forward,
+    and fake quantization in high precision during backward.
     """
 
     @staticmethod
@@ -100,7 +101,9 @@ class NVFP4FakeQuantizedLinear(torch.nn.Linear):
     """
     Linear module for fake quantized NVFP4 weights and/or activations.
 
-    The forward pass follows quantization and addmm numerics in `NVFP4Tensor` exactly.
+    The forward pass follows quantization and addmm numerics in `NVFP4Tensor`
+    in lower precision exactly, while the backward pass uses dequantized
+    (fake quantized) values in high precision.
 
     Example usage::
@@ -146,7 +149,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             x = x.view(-1, x.shape[-1])
         else:
             batch_size = None
-        fq = _NVFP4FakeQuantizedLinearForward.apply(
+        fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
            x, self.weight, self.bias, self.activation_config, self.weight_config
         )
         assert fq.dtype == x.dtype
```
