Commit e16506d
Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and from 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular through the following changes, in descending order
of significance (a combined sketch follows the list):
1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
using `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to the original
dtype (e.g. bf16), which loses some fidelity relative to fp32
3. Fake round blockwise scales to float8
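
A minimal sketch combining all three changes, assuming NVFP4's 16-element blocks with float8 e4m3 scales and an input whose numel is divisible by the block size. `fake_quantize_nvfp4` is a hypothetical name, and the lookup-table snap to e2m1 values stands in for the bit-level `f32_to_f4_unpacked`/`f4_unpacked_to_f32` round-trip (which manipulates bit patterns, here in `torch.int32` rather than `torch.uint8`); this is not the actual torchao implementation:

```python
import torch

# The 8 non-negative values representable in float4 e2m1; negation gives the rest.
_F4_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_nvfp4(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    orig_shape = x.shape
    # (2) Work in fp32 throughout; do not round-trip through e.g. bf16.
    blocks = x.to(torch.float32).reshape(-1, block_size)
    # (3) Compute blockwise scales and fake-round them to float8 e4m3,
    # as the PTQ convert path does.
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0  # 6.0 = max e2m1 magnitude
    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
    scaled = blocks / scale.clamp(min=1e-12)
    # (1) Snap each value to the nearest representable e2m1 value, emulating
    # the f32 -> f4 -> f32 round-trip (ties here break by table order rather
    # than round-to-nearest-even, which is close enough for a sketch).
    candidates = _F4_E2M1_VALUES.to(scaled.device)
    idx = (scaled.abs().unsqueeze(-1) - candidates).abs().argmin(dim=-1)
    snapped = candidates[idx] * scaled.sign()
    # (2) Return fp32 rather than casting back to the original dtype.
    return (snapped * scale).reshape(orig_shape)
```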
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
[ghstack-poisoned]