Commit c0907aa
committed
Add AffineQuantizedTensor based workflow doc and examples
Summary: att
Test Plan: .
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent f8f74c7 commit c0907aa

File tree

1 file changed: +119 -0 lines changed

torchao/quantization/README.md

Lines changed: 119 additions & 0 deletions
@@ -164,6 +164,125 @@ model = torch.compile(model, mode='max-autotune')

## Affine Quantization
Affine quantization maps floating point numbers to quantized numbers (typically integers) with an affine transformation, i.e. `quantized_val = float_val / scale + zero_point`, where `scale` and `zero_point` are quantization parameters chosen at some granularity based on the data.
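
To make the formula concrete, here is a small self-contained sketch of the mapping using plain `torch` ops (per-tensor asymmetric quantization to a 4-bit range); it only illustrates the math above, not the library primitives:

```python
import torch

float_val = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
quant_min, quant_max = 0, 15  # 4-bit unsigned range

# choose qparams so that [min, max] of the data maps onto [quant_min, quant_max]
scale = (float_val.max() - float_val.min()) / (quant_max - quant_min)
zero_point = quant_min - torch.round(float_val.min() / scale)

# quantize and dequantize (the round trip is lossy, but close)
quantized_val = torch.clamp(torch.round(float_val / scale) + zero_point, quant_min, quant_max)
dequantized_val = (quantized_val - zero_point) * scale
```
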
### Quantization Primitives
We used to have different quantize and dequantize operators for each quantization granularity. Since these can all be expressed with a `block_size` argument, we unified the existing quant primitives into `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`, which can represent symmetric/asymmetric per tensor/channel/token/channel_group quantization. These primitives are used to implement the unified quantized tensor subclass.
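
As a rough sketch of how the unified primitives fit together (per-channel int8 symmetric quantization of a weight; the argument order and defaults shown here are an assumption based on `torchao/quantization/quant_primitives.py` at the time of writing, so double check against the current signatures):

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

weight = torch.randn(64, 128)

# one quantization block per output channel (row)
mapping_type = MappingType.SYMMETRIC
block_size = (1, weight.shape[1])
target_dtype = torch.int8

scale, zero_point = choose_qparams_affine(weight, mapping_type, block_size, target_dtype)
q = quantize_affine(weight, block_size, scale, zero_point, target_dtype)
dq = dequantize_affine(q, block_size, scale, zero_point, target_dtype)
```
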
### Quantized Tensor Subclass
We also have a unified quantized tensor subclass that implements how to get a quantized tensor from a floating point tensor and what it means to call linear ops (e.g. `F.linear` and `aten.addmm`) on an instance of the tensor. With this we can dispatch to different operators (e.g. the `int4mm` op) based on device (cpu, cuda), quantization settings (`int4`, `int8`) and packing format (e.g. a format optimized for the cpu int4 mm kernel).
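
A minimal sketch of what this looks like from the user side; the `to_aq` arguments simply mirror the int4 weight-only settings from the example below and should be treated as illustrative rather than a definitive recipe:

```python
import torch
import torch.nn.functional as F
from torchao.dtypes import to_aq
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")

# wrap the weight in the AffineQuantizedTensor subclass
# (int4 weight-only settings, same as the tinygemm example below)
qweight = to_aq(
    weight, MappingType.ASYMMETRIC, (1, 32), torch.int32, 0, 15, 1e-6,
    zero_point_dtype=torch.bfloat16, preserve_zero=False,
    zero_point_domain=ZeroPointDomain.FLOAT,
)

# qweight is still a torch.Tensor (a subclass), so F.linear works as usual;
# the subclass decides how the matmul is executed based on the quantization settings
out = F.linear(x, qweight)
```
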
### Quantization Flow
What we need to do afterwards is roughly the following:
```python
for n, m in model.named_modules():
    # or use some filter_fn
    if isinstance(m, torch.nn.Linear):
        # optional filtering for module name, shape etc.
        # to also quantize activations (needed for dynamic quantization), wrap with to_laq instead:
        # m.weight = nn.Parameter(to_laq(m.weight, device=..., layout=..., ...))
        m.weight = nn.Parameter(to_aq(m.weight, device=..., layout=..., ...))
```
The model/tensor subclass should also be compatible with AOTI and torch.export. Currently we can support
`torch.export.export` and `torch._export.aot_compile` with the following workaround:
```python
from torchao.quantization.utils import unwrap_tensor_subclass
m_unwrapped = unwrap_tensor_subclass(m)

# export
m = torch.export.export(m_unwrapped, example_inputs).module()

# aot_compile
torch._export.aot_compile(m_unwrapped, example_inputs)
```
But we expect this will be integrated into the export path by default in the future.

### Example
Let's use int4 weight-only quantization targeting the tinygemm int4 weight-only quantized matmul kernel
as an example:
```python
import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.dtypes import to_aq
import copy

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

# weight settings: 4-bit values (0..15) stored in an int32 tensor, groupwise along the
# input dimension (groups of 32), with a bfloat16 zero_point in the floating point domain
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

dtype = torch.bfloat16
m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")

m_bf16 = torch.compile(m_bf16, mode='max-autotune')

def apply_weight_quant(weight):
    return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

m = quantize(m, apply_weight_quant)

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.use_mixed_mm = True

# compile the model to improve performance
m = torch.compile(m, mode='max-autotune')


# benchmark to see the speedup
from torch.utils.benchmark import Timer
def benchmark(f, *args, **kwargs):
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    # blocked_autorange doesn't check for variance in times and would often only run the model a single
    # time, as a result many unstable times were showing up. adaptive_autorange solves the issue by checking
    # whether the IQR/median < .03 and repeating if not.
    res = t0.adaptive_autorange(.03, max_run_time=20)
    return res.median * 1e3  # milliseconds

bf16_time = benchmark(m_bf16, *example_inputs)
print(f"bf16 median time: {bf16_time}")
int4_time = benchmark(m, *example_inputs)
print(f"int4 weight only quantized median time: {int4_time}")
print(f"speedup: {bf16_time / int4_time}")


# output
# bf16 median time: 0.5524866282939911
# int4 weight only quantized median time: 0.47659454867243767
# speedup: 1.1592382452400098
```

## Notes