Commit 7f06046
Update on "Improve QAT nvfp4 numerics"
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class that
incorporates both the quantization and the mm logic.
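The new linear class itself lives in the diff; as a rough sketch of
the idea (all names and details here are illustrative, not torchao's
actual implementation), such a QAT linear can apply the same
block-scaled NVFP4 rounding to activations and weights that the PTQ
path applies, then run a single `F.linear`:
```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Representable magnitudes of the FP4 (E2M1) format used by NVFP4.
_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def _fake_quantize_nvfp4(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    """Round x to the NVFP4 grid: FP4 (E2M1) values with one FP8 (E4M3)
    scale per contiguous block of 16 elements along the last dim.
    Assumes the last dim is divisible by block_size; no STE shown."""
    blocks = x.reshape(-1, block_size)
    # Per-block scale so the block max maps to the largest E2M1 value (6.0),
    # round-tripped through FP8 E4M3 as the PTQ path does.
    scale = blocks.abs().amax(dim=-1, keepdim=True) / 6.0
    scale = scale.to(torch.float8_e4m3fn).to(x.dtype).clamp(min=1e-12)
    scaled = blocks / scale
    # Snap each scaled element to the nearest representable E2M1 magnitude.
    grid = _E2M1_VALUES.to(x.device, x.dtype)
    idx = (scaled.abs().unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return (grid[idx] * scaled.sign() * scale).reshape(x.shape)

class NVFP4FakeQuantizedLinear(nn.Linear):
    """Hypothetical QAT linear: fake-quantize input and weight with the
    same rounding the PTQ kernel uses, then do one matmul, so prepare-time
    outputs match convert-time outputs exactly."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(
            _fake_quantize_nvfp4(x),
            _fake_quantize_nvfp4(self.weight),
            self.bias,
        )
```
For training, the real QAT path would additionally need a
straight-through estimator around the rounding; the sketch only shows
why folding quantization and mm into one module makes prepare and
convert numerically identical.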
**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
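For context, "prepare vs convert SQNR" compares the output of the
prepared (fake-quantized) model against the converted (actually
quantized) model on the same input; inf means bit-identical outputs.
A minimal sketch of such a check (helper written here for
illustration):
```
import torch

def sqnr(ref: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Signal-to-quantization-noise ratio in dB; inf when tensors match."""
    noise = (ref - other).pow(2).mean()
    if noise == 0:
        return torch.tensor(float("inf"))
    return 10 * torch.log10(ref.pow(2).mean() / noise)

# e.g. compare prepared vs converted model outputs on one batch:
# assert sqnr(converted_model(x), prepared_model(x)) == float("inf")
```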
**End-to-end tests:**
Fine-tuned Llama3.2-3B with and without this PR in axolotl:
- fine-tuned for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation
Wikitext:
- With this PR, the QAT nvfp4 quantized model recovered about 15%
of the quantized baseline's perplexity degradation relative to
the float baseline (see logs below)
- Without this PR, the QAT nvfp4 quantized model scored about the
same as the quantized baseline
```
# No QAT (float)
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ |9.418|± | N/A|
# No QAT (quantized)
==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|
# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```
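The eval logs above look like lm-evaluation-harness wikitext output;
assuming so, a command along these lines (model path hypothetical)
produces such a table:
```
lm_eval --model hf \
    --model_args pretrained=Llama3.2-3B_qat_bs512 \
    --tasks wikitext \
    --batch_size 8
```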
[ghstack-poisoned]