Commit 898ca19
committed
Improve QAT nvfp4 numerics
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.
**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation
Wikitext:
- With this PR, QAT nvfp4 quantized model achieved 15% lower
perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
same as the quantized baseline
```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
| | |none | 0|word_perplexity|↓ |9.418|± | N/A|
==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.3681|± | N/A|
# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.2281|± | N/A|
```
ghstack-source-id: bb1356c
Pull Request resolved: #30501 parent 5cbbd73 commit 898ca19
File tree
7 files changed
+197
-86
lines changed- test/quantization
- torchao
- prototype
- mx_formats
- qat
- quantization/qat
7 files changed
+197
-86
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1910 | 1910 | | |
1911 | 1911 | | |
1912 | 1912 | | |
1913 | | - | |
1914 | 1913 | | |
1915 | 1914 | | |
1916 | 1915 | | |
| |||
2088 | 2087 | | |
2089 | 2088 | | |
2090 | 2089 | | |
2091 | | - | |
| 2090 | + | |
2092 | 2091 | | |
2093 | 2092 | | |
2094 | 2093 | | |
| 2094 | + | |
2095 | 2095 | | |
2096 | 2096 | | |
2097 | 2097 | | |
2098 | 2098 | | |
2099 | 2099 | | |
2100 | 2100 | | |
| 2101 | + | |
2101 | 2102 | | |
2102 | 2103 | | |
2103 | 2104 | | |
2104 | 2105 | | |
2105 | 2106 | | |
| 2107 | + | |
| 2108 | + | |
| 2109 | + | |
| 2110 | + | |
2106 | 2111 | | |
2107 | 2112 | | |
2108 | 2113 | | |
| |||
2116 | 2121 | | |
2117 | 2122 | | |
2118 | 2123 | | |
2119 | | - | |
| 2124 | + | |
2120 | 2125 | | |
2121 | 2126 | | |
2122 | 2127 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
771 | 771 | | |
772 | 772 | | |
773 | 773 | | |
774 | | - | |
775 | | - | |
776 | | - | |
777 | | - | |
778 | | - | |
779 | | - | |
780 | | - | |
781 | | - | |
782 | | - | |
783 | | - | |
784 | | - | |
785 | | - | |
786 | | - | |
787 | | - | |
788 | | - | |
789 | | - | |
790 | | - | |
791 | | - | |
792 | | - | |
793 | | - | |
794 | | - | |
795 | | - | |
796 | | - | |
797 | 774 | | |
798 | 775 | | |
799 | 776 | | |
800 | 777 | | |
801 | 778 | | |
802 | 779 | | |
803 | 780 | | |
804 | | - | |
805 | 781 | | |
806 | 782 | | |
807 | 783 | | |
| |||
813 | 789 | | |
814 | 790 | | |
815 | 791 | | |
816 | | - | |
817 | | - | |
| 792 | + | |
| 793 | + | |
| 794 | + | |
| 795 | + | |
818 | 796 | | |
819 | 797 | | |
820 | 798 | | |
| |||
826 | 804 | | |
827 | 805 | | |
828 | 806 | | |
829 | | - | |
830 | | - | |
| 807 | + | |
| 808 | + | |
831 | 809 | | |
832 | 810 | | |
833 | 811 | | |
| |||
836 | 814 | | |
837 | 815 | | |
838 | 816 | | |
839 | | - | |
840 | | - | |
841 | | - | |
842 | | - | |
843 | | - | |
844 | | - | |
845 | | - | |
846 | | - | |
| 817 | + | |
| 818 | + | |
| 819 | + | |
| 820 | + | |
| 821 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3 | 3 | | |
4 | 4 | | |
5 | 5 | | |
6 | | - | |
| 6 | + | |
7 | 7 | | |
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
11 | | - | |
| 11 | + | |
12 | 12 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
| 2 | + | |
2 | 3 | | |
3 | 4 | | |
4 | 5 | | |
5 | 6 | | |
6 | | - | |
| 7 | + | |
| 8 | + | |
7 | 9 | | |
8 | 10 | | |
9 | | - | |
10 | | - | |
11 | | - | |
12 | | - | |
| 11 | + | |
13 | 12 | | |
14 | 13 | | |
15 | 14 | | |
| |||
23 | 22 | | |
24 | 23 | | |
25 | 24 | | |
| 25 | + | |
| 26 | + | |
26 | 27 | | |
27 | 28 | | |
28 | 29 | | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
29 | 75 | | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
30 | 79 | | |
31 | | - | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
32 | 102 | | |
33 | | - | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
34 | 124 | | |
35 | 125 | | |
36 | | - | |
37 | | - | |
38 | | - | |
39 | | - | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
40 | 149 | | |
41 | 150 | | |
42 | | - | |
43 | | - | |
44 | 151 | | |
| 152 | + | |
45 | 153 | | |
46 | | - | |
47 | | - | |
48 | | - | |
49 | 154 | | |
50 | | - | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
51 | 164 | | |
52 | | - | |
53 | | - | |
54 | | - | |
55 | | - | |
56 | | - | |
57 | | - | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
58 | 180 | | |
59 | | - | |
60 | | - | |
61 | | - | |
62 | | - | |
63 | | - | |
64 | | - | |
65 | | - | |
66 | | - | |
67 | | - | |
68 | | - | |
69 | | - | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
208 | 208 | | |
209 | 209 | | |
210 | 210 | | |
211 | | - | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
212 | 229 | | |
213 | 230 | | |
214 | 231 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
444 | 444 | | |
445 | 445 | | |
446 | 446 | | |
447 | | - | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
448 | 450 | | |
449 | 451 | | |
450 | 452 | | |
451 | 453 | | |
452 | | - | |
| 454 | + | |
| 455 | + | |
| 456 | + | |
453 | 457 | | |
454 | 458 | | |
455 | 459 | | |
| |||
0 commit comments