Commit a369ffa
committed
Support NVFP4 dynamic per tensor scale
**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to call this. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.
In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.
**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```
Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig
m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")
start = time.time()
for _ in range(1000):
m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")
start = time.time()
for _ in range(1000):
m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```
On a single B200:
```
No per_tensor_scale = 1.2855589389801025 seconds
With per_tensor_scale = 1.3009123802185059 seconds
```
ghstack-source-id: e6c06b6
Pull Request resolved: #30491 parent 8e2ca35 commit a369ffa
File tree
5 files changed
+49
-17
lines changed- test
- prototype/mx_formats
- quantization
- torchao
- prototype/mx_formats
- quantization/qat
5 files changed
+49
-17
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
105 | 105 | | |
106 | 106 | | |
107 | 107 | | |
| 108 | + | |
108 | 109 | | |
109 | 110 | | |
110 | 111 | | |
| |||
126 | 127 | | |
127 | 128 | | |
128 | 129 | | |
| 130 | + | |
129 | 131 | | |
130 | 132 | | |
131 | 133 | | |
| |||
147 | 149 | | |
148 | 150 | | |
149 | 151 | | |
150 | | - | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
151 | 155 | | |
152 | 156 | | |
153 | 157 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
2077 | 2077 | | |
2078 | 2078 | | |
2079 | 2079 | | |
2080 | | - | |
| 2080 | + | |
| 2081 | + | |
2081 | 2082 | | |
2082 | 2083 | | |
2083 | 2084 | | |
| |||
2086 | 2087 | | |
2087 | 2088 | | |
2088 | 2089 | | |
2089 | | - | |
2090 | | - | |
| 2090 | + | |
| 2091 | + | |
2091 | 2092 | | |
2092 | 2093 | | |
2093 | 2094 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
134 | 135 | | |
135 | 136 | | |
136 | 137 | | |
137 | | - | |
| 138 | + | |
| 139 | + | |
138 | 140 | | |
139 | 141 | | |
140 | 142 | | |
| |||
145 | 147 | | |
146 | 148 | | |
147 | 149 | | |
| 150 | + | |
148 | 151 | | |
149 | 152 | | |
150 | 153 | | |
| |||
175 | 178 | | |
176 | 179 | | |
177 | 180 | | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
178 | 186 | | |
179 | 187 | | |
180 | | - | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
181 | 191 | | |
182 | 192 | | |
183 | 193 | | |
| 194 | + | |
184 | 195 | | |
185 | 196 | | |
186 | 197 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
50 | 51 | | |
51 | 52 | | |
52 | 53 | | |
| |||
245 | 246 | | |
246 | 247 | | |
247 | 248 | | |
248 | | - | |
| 249 | + | |
249 | 250 | | |
250 | 251 | | |
251 | 252 | | |
| |||
645 | 646 | | |
646 | 647 | | |
647 | 648 | | |
| 649 | + | |
| 650 | + | |
| 651 | + | |
| 652 | + | |
| 653 | + | |
648 | 654 | | |
649 | 655 | | |
650 | 656 | | |
651 | | - | |
| 657 | + | |
652 | 658 | | |
653 | 659 | | |
654 | 660 | | |
| |||
672 | 678 | | |
673 | 679 | | |
674 | 680 | | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
675 | 686 | | |
676 | 687 | | |
677 | 688 | | |
678 | | - | |
| 689 | + | |
679 | 690 | | |
680 | 691 | | |
681 | 692 | | |
| |||
697 | 708 | | |
698 | 709 | | |
699 | 710 | | |
| 711 | + | |
700 | 712 | | |
701 | 713 | | |
| 714 | + | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
702 | 719 | | |
703 | 720 | | |
704 | 721 | | |
705 | | - | |
| 722 | + | |
706 | 723 | | |
707 | 724 | | |
708 | 725 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
442 | 442 | | |
443 | 443 | | |
444 | 444 | | |
445 | | - | |
446 | | - | |
447 | | - | |
448 | | - | |
449 | | - | |
450 | 445 | | |
451 | | - | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
452 | 449 | | |
453 | 450 | | |
454 | | - | |
| 451 | + | |
| 452 | + | |
| 453 | + | |
455 | 454 | | |
456 | 455 | | |
457 | 456 | | |
| |||
0 commit comments