Commit d256ce2
committed
Improve QAT nvfp4 numerics
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:
1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8
**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
End-to-end tests TBD.
ghstack-source-id: d8f7eff
Pull Request resolved: #30501 parent e1d89e7 commit d256ce2
File tree
4 files changed
+57
-13
lines changed- test/quantization
- torchao/prototype
- mx_formats
- qat
4 files changed
+57
-13
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1910 | 1910 | | |
1911 | 1911 | | |
1912 | 1912 | | |
1913 | | - | |
1914 | 1913 | | |
1915 | 1914 | | |
1916 | 1915 | | |
| |||
2086 | 2085 | | |
2087 | 2086 | | |
2088 | 2087 | | |
| 2088 | + | |
| 2089 | + | |
| 2090 | + | |
| 2091 | + | |
| 2092 | + | |
2089 | 2093 | | |
2090 | 2094 | | |
2091 | | - | |
| 2095 | + | |
2092 | 2096 | | |
2093 | 2097 | | |
2094 | 2098 | | |
| |||
2098 | 2102 | | |
2099 | 2103 | | |
2100 | 2104 | | |
| 2105 | + | |
2101 | 2106 | | |
2102 | 2107 | | |
2103 | 2108 | | |
2104 | 2109 | | |
2105 | 2110 | | |
| 2111 | + | |
| 2112 | + | |
| 2113 | + | |
| 2114 | + | |
2106 | 2115 | | |
2107 | 2116 | | |
2108 | 2117 | | |
| |||
2116 | 2125 | | |
2117 | 2126 | | |
2118 | 2127 | | |
2119 | | - | |
| 2128 | + | |
| 2129 | + | |
| 2130 | + | |
| 2131 | + | |
| 2132 | + | |
2120 | 2133 | | |
2121 | 2134 | | |
2122 | 2135 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
24 | 24 | | |
25 | 25 | | |
26 | 26 | | |
27 | | - | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
28 | 33 | | |
29 | 34 | | |
30 | 35 | | |
| |||
44 | 49 | | |
45 | 50 | | |
46 | 51 | | |
| 52 | + | |
47 | 53 | | |
48 | 54 | | |
49 | 55 | | |
| |||
105 | 111 | | |
106 | 112 | | |
107 | 113 | | |
108 | | - | |
| 114 | + | |
109 | 115 | | |
110 | 116 | | |
111 | 117 | | |
| |||
120 | 126 | | |
121 | 127 | | |
122 | 128 | | |
123 | | - | |
| 129 | + | |
124 | 130 | | |
125 | 131 | | |
126 | 132 | | |
127 | 133 | | |
128 | | - | |
| 134 | + | |
129 | 135 | | |
130 | 136 | | |
131 | 137 | | |
132 | 138 | | |
133 | 139 | | |
134 | | - | |
| 140 | + | |
135 | 141 | | |
136 | 142 | | |
137 | 143 | | |
138 | 144 | | |
139 | 145 | | |
140 | 146 | | |
141 | 147 | | |
142 | | - | |
| 148 | + | |
143 | 149 | | |
144 | 150 | | |
145 | 151 | | |
| |||
154 | 160 | | |
155 | 161 | | |
156 | 162 | | |
157 | | - | |
| 163 | + | |
158 | 164 | | |
159 | 165 | | |
160 | 166 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
798 | 798 | | |
799 | 799 | | |
800 | 800 | | |
801 | | - | |
802 | 801 | | |
803 | 802 | | |
804 | 803 | | |
| |||
834 | 833 | | |
835 | 834 | | |
836 | 835 | | |
837 | | - | |
| 836 | + | |
838 | 837 | | |
839 | 838 | | |
840 | 839 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2 | 2 | | |
3 | 3 | | |
4 | 4 | | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
5 | 13 | | |
6 | 14 | | |
7 | 15 | | |
| |||
12 | 20 | | |
13 | 21 | | |
14 | 22 | | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
15 | 41 | | |
16 | 42 | | |
17 | 43 | | |
| |||
56 | 82 | | |
57 | 83 | | |
58 | 84 | | |
| 85 | + | |
59 | 86 | | |
60 | 87 | | |
61 | | - | |
62 | 88 | | |
63 | 89 | | |
64 | 90 | | |
| |||
0 commit comments