Skip to content

Commit f0dd811

Browse files
committed
Refactor rest of tinygemm quant primitive ops
Summary: This PR replaces the remaining tinygemm specific quant primitive ops with the general quant primitive ops that we want to use for everything, we could delete these ops in a separate PR if needed Test Plan: python test/quantization/test_quant_primitives.py -k test_get_groupwise_affine_qparams python test/quantization/test_quant_primitives.py -k test_groupwise_affine_quantize_tensor_from_qparams python test/quantization/test_quant_primitives.py -k test_groupwise_affine_dequantize_tensor_from_qparams accuracy: perf: no diff for generated code with `TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py`
1 parent 729fa4d commit f0dd811

File tree

10 files changed

+538
-212
lines changed

10 files changed

+538
-212
lines changed

benchmarks/benchmark_aq.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
2+
"""
3+
import torch
4+
from torchao.quantization.subclass import (
5+
Int8WeightOnlyQuantizedLinearWeight,
6+
Int4WeightOnlyQuantizedLinearWeight,
7+
)
8+
from torchao.quantization.utils import (
9+
TORCH_VERSION_AFTER_2_4,
10+
)
11+
from torchao.quantization.quant_api import (
12+
_replace_with_custom_fn_if_matches_filter,
13+
)
14+
import copy
15+
16+
class ToyLinearModel(torch.nn.Module):
17+
def __init__(self, m=64, n=32, k=64):
18+
super().__init__()
19+
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
20+
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
21+
22+
def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
23+
return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
24+
25+
def forward(self, x):
26+
x = self.linear1(x)
27+
x = self.linear2(x)
28+
return x
29+
30+
def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
31+
"""
32+
The deprecated implementation for int8 dynamic quant API, used as a reference for
33+
numerics and performance
34+
"""
35+
from torchao.quantization.quant_api import _in_features_greater_than_16
36+
from torchao.quantization.quant_api import _is_linear
37+
from torchao.quantization.quant_api import _get_subclass_inserter
38+
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
39+
40+
if filter_fn is None:
41+
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
42+
*args
43+
)
44+
45+
_replace_with_custom_fn_if_matches_filter(
46+
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
47+
)
48+
49+
def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
50+
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
51+
"""
52+
The deprecated implementation for weight only quant API, used as a reference for
53+
numerics and performance
54+
"""
55+
from torchao.quantization.quant_api import _is_linear
56+
from torchao.quantization.quant_api import _get_subclass_inserter
57+
58+
filter_fn = kwargs.pop("filter_fn", _is_linear)
59+
60+
_replace_with_custom_fn_if_matches_filter(
61+
model,
62+
_get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
63+
filter_fn,
64+
)
65+
66+
return _ref_change_linear_weights_to_woqtensors
67+
68+
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
69+
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
70+
71+
72+
def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
73+
if kwargs is None:
74+
kwargs = {}
75+
76+
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
77+
m_ref = copy.deepcopy(m)
78+
# setting batch_size to 20 to be compatible with the kernel
79+
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
80+
81+
api(m, **kwargs)
82+
83+
# reference
84+
ref_api(m_ref, **kwargs)
85+
86+
res = m(*example_inputs)
87+
ref = m_ref(*example_inputs)
88+
89+
assert torch.equal(res, ref)
90+
91+
# perf comparison
92+
from torchao.utils import benchmark_model
93+
# warmup
94+
WARMUP = 5
95+
RUNS = 100
96+
input_tensor = example_inputs[0]
97+
m = torch.compile(m, mode='max-autotune', fullgraph=True)
98+
99+
benchmark_model(m, WARMUP, input_tensor)
100+
elapsed_time = benchmark_model(m, RUNS, input_tensor)
101+
102+
m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
103+
benchmark_model(m_ref, WARMUP, input_tensor)
104+
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
105+
106+
print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
107+
assert elapsed_time < 1.05 * ref_elapsed_time
108+
109+
if __name__ == "__main__" and TORCH_VERSION_AFTER_2_4 and torch.cuda.is_available():
110+
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
111+
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
112+
113+
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
114+
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
115+
116+
kwargs = {"groupsize": 32}
117+
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
118+
_bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)

test/integration/test_integration.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,7 @@ def _test_lin_weight_subclass_impl(
930930
)
931931

932932
@parameterized.expand(COMMON_DEVICE_DTYPE)
933+
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
933934
def test_int8_dynamic_quant_subclass(self, device, dtype):
934935
self._test_lin_weight_subclass_impl(
935936
Int8DynamicallyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
@@ -1217,6 +1218,8 @@ def forward(self, x):
12171218
@parameterized.expand(COMMON_DEVICE_DTYPE)
12181219
@torch.no_grad()
12191220
def test_save_load_dqtensors(self, device, dtype):
1221+
if device == "cpu":
1222+
self.skipTest(f"indcutor failed for cpu right now")
12201223
self._test_handle_save_load_meta_impl(change_linear_weights_to_int8_dqtensors, device, test_dtype=dtype)
12211224

12221225
@parameterized.expand(COMMON_DEVICE_DTYPE)

test/quantization/test_quant_api.py

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from torchao.quantization.subclass import (
3030
to_laq,
3131
LinearActQuantizedTensor,
32+
Int8WeightOnlyQuantizedLinearWeight,
33+
Int4WeightOnlyQuantizedLinearWeight,
3234
)
3335
from torchao.quantization.quant_api import (
3436
_replace_with_custom_fn_if_matches_filter,
@@ -138,6 +140,28 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
138140
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
139141
)
140142

143+
def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
144+
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
145+
"""
146+
The deprecated implementation for weight only quant API, used as a reference for
147+
numerics and performance
148+
"""
149+
from torchao.quantization.quant_api import _is_linear
150+
from torchao.quantization.quant_api import _get_subclass_inserter
151+
152+
filter_fn = kwargs.pop("filter_fn", _is_linear)
153+
154+
_replace_with_custom_fn_if_matches_filter(
155+
model,
156+
_get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
157+
filter_fn,
158+
)
159+
160+
return _ref_change_linear_weights_to_woqtensors
161+
162+
_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
163+
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
164+
141165
class TestQuantFlow(unittest.TestCase):
142166
def test_dynamic_quant_gpu_singleline(self):
143167
m = ToyLinearModel().eval()
@@ -478,8 +502,7 @@ def test_quantized_tensor_subclass_int4(self):
478502
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
479503

480504
# reference
481-
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
482-
change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
505+
_ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
483506

484507
res = m(*example_inputs)
485508
ref = m_copy(*example_inputs)
@@ -489,7 +512,7 @@ def test_quantized_tensor_subclass_int4(self):
489512

490513
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
491514
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
492-
def test_quantized_tensor_subclass_int8(self):
515+
def test_quantized_tensor_subclass_int8_wo(self):
493516
m = ToyLinearModel().eval().to(torch.bfloat16)
494517
m_copy = copy.deepcopy(m)
495518
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
@@ -500,13 +523,13 @@ def test_quantized_tensor_subclass_int8(self):
500523
assert isinstance(m.linear2.weight, AffineQuantizedTensor)
501524

502525
# reference
503-
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
504-
change_linear_weights_to_int8_woqtensors(m_copy)
526+
_ref_change_linear_weights_to_int8_woqtensors(m_copy)
527+
505528

506529
res = m(*example_inputs)
507530
ref = m_copy(*example_inputs)
508531

509-
torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
532+
self.assertTrue(torch.equal(res, ref))
510533

511534

512535
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@@ -525,8 +548,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
525548
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
526549

527550
# reference
528-
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
529-
change_linear_weights_to_int8_dqtensors(m_copy)
551+
_ref_change_linear_weights_to_int8_dqtensors(m_copy)
530552

531553
res = m(*example_inputs)
532554
ref = m_copy(*example_inputs)
@@ -545,45 +567,5 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
545567
# make sure it compiles
546568
torch._export.aot_compile(m_unwrapped, example_inputs)
547569

548-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
549-
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
550-
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
551-
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
552-
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
553-
m_ref = copy.deepcopy(m)
554-
# setting batch_size to 20 to be compatible with the kernel
555-
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
556-
557-
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
558-
change_linear_weights_to_int8_dqtensors(m)
559-
560-
# reference
561-
_ref_change_linear_weights_to_int8_dqtensors(m_ref)
562-
563-
res = m(*example_inputs)
564-
ref = m_ref(*example_inputs)
565-
566-
self.assertTrue(torch.equal(res, ref))
567-
568-
# perf comparison
569-
from torchao.utils import benchmark_model
570-
# warmup
571-
WARMUP = 5
572-
RUNS = 100
573-
input_tensor = example_inputs[0]
574-
m = torch.compile(m, mode='max-autotune', fullgraph=True)
575-
576-
benchmark_model(m, WARMUP, input_tensor)
577-
elapsed_time = benchmark_model(m, RUNS, input_tensor)
578-
579-
m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
580-
benchmark_model(m_ref, WARMUP, input_tensor)
581-
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)
582-
583-
print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
584-
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)
585-
586-
587-
588570
if __name__ == "__main__":
589571
unittest.main()

0 commit comments

Comments
 (0)