
Commit 155c576

Update on "Deprecate config functions like int4_weight_only"
**Summary:** These have been superseded by `AOBaseConfig` objects for several releases already, but we never deprecated them. We will keep them around for another release before breaking BC and removing them.

**Test Plan:**

```
python test/quantization/test_quant_api.py -k test_config_deprecation
```

```
>>> int4_weight_only()
/home/andrewor/local/ao/torchao/utils.py:446: UserWarning: `int4_weight_only` is deprecated and will be removed in a future release. Please use `Int4WeightOnlyConfig` instead. Example usage:
    quantize_(model, Int4WeightOnlyConfig(...))
  warnings.warn(
Int4WeightOnlyConfig(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, zero_point_domain=<ZeroPointDomain.NONE: 3>, set_inductor_config=True, preserve_zero=None, int4_packing_format=<Int4PackingFormat.PLAIN: 'plain'>, int4_choose_qparams_algorithm=<Int4ChooseQParamsAlgorithm.TINYGEMM: 'tinygemm'>, version=2)
```

[ghstack-poisoned]
2 parents 5e93d50 + 9e30e7f commit 155c576
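
For downstream users, a minimal before/after sketch of the migration the new warning points to (assumes a CUDA device, since int4 weight-only targets the tinygemm kernels; the model and `group_size` value are illustrative):

```python
import torch

from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# Deprecated: the bare config function now emits the UserWarning shown above.
#   from torchao.quantization import int4_weight_only
#   quantize_(model, int4_weight_only(group_size=128))

# Preferred: pass the equivalent AOBaseConfig object directly.
quantize_(model, Int4WeightOnlyConfig(group_size=128))
```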

File tree: 9 files changed, +517 −97 lines


.github/workflows/regression_test_aarch64.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -55,6 +55,7 @@ jobs:
           pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py
           pytest -s test/prototype/test_embedding.py
           pytest -s test/prototype/test_int8_lut_tensor.py
+          pytest -s test/prototype/test_tensor_conversion.py
           pytest -s test/prototype/test_groupwise_lowbit_weight_lut_quantizer.py
           pytest -s test/prototype/test_parq.py
       - name: torchao/csrc/cpu - build and run C++ tests
```
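
For reference, the newly wired suite can be run locally the same way CI invokes it (assuming the torchao kernel library is built, since the tests skip otherwise):

```
pytest -s test/prototype/test_tensor_conversion.py
```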
test/prototype/test_tensor_conversion.py

Lines changed: 180 additions & 0 deletions (new file)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


import pytest
import torch

from torchao.prototype.parq.quant import (
    StretchedIntxWeightConfig,
    StretchedUnifTorchaoQuantizer,
)
from torchao.prototype.quantization.int8_lut_tensor.int8_lut_tensor import Int8LutTensor
from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
from torchao.quantization import MappingType
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    IntxWeightOnlyConfig,
    quantize_,
)
from torchao.quantization.quantize_.workflows.intx.intx_opaque_tensor import (
    IntxOpaqueTensor,
    _is_kernel_library_loaded,
)
from torchao.quantization.utils import compute_error


class ToyLinearModelWithTiedEmbedding(torch.nn.Module):
    def __init__(self, d0=512, d1=512, d2=256, d3=128, d4=32):
        super().__init__()
        self.embedding1 = torch.nn.Embedding(d0, d1)
        self.embedding2 = torch.nn.Embedding(d0, d1)
        self.embedding3 = torch.nn.Embedding(d0, d1)

        self.linear1 = torch.nn.Linear(d1, d2, bias=False)
        self.linear2 = torch.nn.Linear(d2, d3, bias=True)
        self.linear3 = torch.nn.Linear(d3, d4, bias=False)
        self.linear4 = torch.nn.Linear(d4, d1, bias=False)

        self.lm_head1 = torch.nn.Linear(d1, d0, bias=False)
        self.lm_head2 = torch.nn.Linear(d1, d0, bias=False)
        self.lm_head3 = torch.nn.Linear(d1, d0, bias=False)

        # Tie weights
        # lm_head1 / lm_head2 form one tied weight group
        self.embedding2.weight = self.embedding1.weight
        self.lm_head1.weight = self.embedding1.weight
        self.lm_head2.weight = self.embedding1.weight

        # lm_head3 forms a separate tied weight group
        self.lm_head3.weight = self.embedding3.weight

    def example_inputs(
        self,
        lead_dim=(1,),
        dtype=torch.bfloat16,
    ):
        return (
            torch.randint(
                0,
                self.embedding1.num_embeddings,
                size=lead_dim,
                dtype=torch.int64,
                device="cpu",
            ),
        )

    def forward(self, x):
        x = self.embedding1(x) + self.embedding2(x) + self.embedding3(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.linear4(x)
        x = self.lm_head1(x) + self.lm_head2(x) + self.lm_head3(x)
        return x


@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    yield
    torch._dynamo.reset()  # reset cache between tests


@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("granularity", [PerGroup(32), PerAxis(0)])
@pytest.mark.parametrize("bit_width", [1, 2, 3, 4])
@pytest.mark.parametrize(
    "lead_dim",
    [
        (1,),
        (5,),
        (7, 2),
    ],
)
@pytest.mark.skipif(
    not _is_kernel_library_loaded(), reason="Kernel library is not loaded"
)
def test_aarch64_conversion(dtype, granularity, bit_width, lead_dim):
    torch.manual_seed(0)

    model = ToyLinearModelWithTiedEmbedding()
    model = model.to(dtype)
    example_inputs = model.example_inputs(lead_dim, dtype)

    # Quantize linear 2 and 3 with PARQ
    quantizer = StretchedUnifTorchaoQuantizer(bit_width)
    config = StretchedIntxWeightConfig(
        b=bit_width,
        quant_min=quantizer.quant_min,
        quant_max=quantizer.quant_max,
        granularity=granularity,
        activation_quantization="int8_asym_per_token",
    )
    quantize_(model, config, filter_fn=lambda m, fqn: fqn in ["linear2", "linear3"])

    # Quantize linear 1 and 4 with int8 dynamic activation
    config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=granularity,
        weight_mapping_type=MappingType.SYMMETRIC,
    )
    quantize_(
        model,
        config,
        filter_fn=lambda m, fqn: fqn
        in ["linear1", "linear4", "lm_head1", "lm_head2", "lm_head3"],
    )

    # Quantize embedding 1, 2, and 3 with weight only
    config = IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=granularity,
        mapping_type=MappingType.SYMMETRIC,
    )
    quantize_(
        model,
        config,
        filter_fn=lambda m, fqn: fqn in ["embedding1", "embedding2", "embedding3"],
    )
    model_out = model(*example_inputs)

    # Convert to optimized model
    _convert_model_for_aarch64(model)

    # Check expected tensor subclass
    assert isinstance(model.linear2.weight, Int8LutTensor)
    assert isinstance(model.linear3.weight, Int8LutTensor)
    assert isinstance(model.linear1.weight, IntxOpaqueTensor)
    assert isinstance(model.linear4.weight, IntxOpaqueTensor)

    # Assert tied params
    tied_group1_id = id(model.embedding1.weight)
    assert id(model.embedding2.weight) == tied_group1_id
    assert id(model.lm_head1.weight) == tied_group1_id
    assert id(model.lm_head2.weight) == tied_group1_id

    assert id(model.lm_head3.weight) == id(model.embedding3.weight)
    assert id(model.lm_head3.weight) != tied_group1_id

    # Compare converted out with original out
    converted_out = model(*example_inputs)
    sqnr = compute_error(model_out, converted_out)
    sqnr_threshold = 30
    assert sqnr > sqnr_threshold, f"sqnr: {sqnr}"

    # Check exported graph for correct ops
    ep = torch.export.export(model, example_inputs)
    expected_counts = {
        "torch.ops.torchao._shared_embedding_": 3,
        "torch.ops.torchao._linear_8bit_act_": 7,
        "torch.ops.aten.linear.default": 0,
        "torch.ops.aten.embedding.default": 0,
    }
    for line, cnt in expected_counts.items():
        assert ep.graph_module.code.count(line) == cnt, (
            f"expected {cnt} {line} in {ep.graph_module.code}"
        )
```
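
Distilled from the test above, a minimal sketch of the flow it exercises — quantize with a torchao config, then convert in place for aarch64 — assuming the same prototype APIs and a kernel-enabled build; the toy `nn.Sequential` model is illustrative:

```python
import torch

from torchao.prototype.tensor_conversion.api import _convert_model_for_aarch64
from torchao.quantization import MappingType
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False)).to(torch.bfloat16)

# int4 weights with per-group scales, int8 dynamically quantized activations
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(32),
        weight_mapping_type=MappingType.SYMMETRIC,
    ),
)

# Rewrite the weight tensor subclasses into aarch64-optimized layouts in place;
# the test above additionally checks that tied embedding/lm_head weights stay tied.
_convert_model_for_aarch64(model)
```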

test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py

Lines changed: 55 additions & 67 deletions

```diff
@@ -10,8 +10,12 @@
 import unittest
 
 import torch
-from parameterized import param, parameterized
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+)
 
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
@@ -34,42 +38,35 @@
 
 
 @unittest.skipIf(not _is_kernel_library_loaded(), "Kernel library not loaded")
-class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
-    TEST_ACCURACY_CASES = [
-        param(
-            layout=layout,
-            weight_dtype=weight_dtype,
-            weight_mapping_type=weight_mapping_type,
-            weight_granularity=weight_granularity,
-        )
-        for layout in [
-            PackedLinearInt8DynamicActivationIntxWeightLayout(),
-            PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"),
-        ]
-        for weight_dtype in [
-            torch.int1,
-            torch.int2,
-            torch.int3,
-            torch.int4,
-            torch.int5,
-            torch.int6,
-            torch.int7,
-            torch.int8,
-        ]
-        for weight_mapping_type in [
-            MappingType.SYMMETRIC,
-            MappingType.ASYMMETRIC,
-            MappingType.SYMMETRIC_NO_CLIPPING_ERR,
-        ]
-        for weight_granularity in [
-            PerGroup(128),
-            PerAxis(0),
-        ]
-    ]
-
-    @parameterized.expand(
-        TEST_ACCURACY_CASES,
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+class TestInt8DynamicActivationIntxWeight(TestCase):
+    @parametrize(
+        "layout, weight_dtype, weight_mapping_type, weight_granularity",
+        [
+            (layout, weight_dtype, weight_mapping_type, weight_granularity)
+            for layout in [
+                PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"),
+            ]
+            for weight_dtype in [
+                torch.int1,
+                torch.int2,
+                torch.int3,
+                torch.int4,
+                torch.int5,
+                torch.int6,
+                torch.int7,
+                torch.int8,
+            ]
+            for weight_mapping_type in [
+                MappingType.SYMMETRIC,
+                MappingType.ASYMMETRIC,
+                MappingType.SYMMETRIC_NO_CLIPPING_ERR,
+            ]
+            for weight_granularity in [
+                PerGroup(128),
+                PerAxis(0),
+            ]
+        ],
     )
     def test_accuracy(
         self, layout, weight_dtype, weight_mapping_type, weight_granularity
@@ -396,15 +393,12 @@ def test_export_QDQLayout(self):
             exported.graph_module.code
         )
 
-    @parameterized.expand(
+    @parametrize(
+        "layout",
         [
-            param(layout=layout)
-            for layout in [
-                PackedLinearInt8DynamicActivationIntxWeightLayout(),
-                QDQLayout(),
-            ]
+            PackedLinearInt8DynamicActivationIntxWeightLayout(),
+            QDQLayout(),
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_serialization(self, layout):
         layers = [
@@ -436,20 +430,16 @@ def test_serialization(self, layout):
         actual = model2(activations)
         self.assertTrue(torch.allclose(expected, actual))
 
-    @parameterized.expand(
+    @parametrize(
+        "group_size, mapping_type, act_mapping_type",
         [
-            param(
-                group_size=group_size,
-                mapping_type=mapping_type,
-                act_mapping_type=act_mapping_type,
-            )
+            (group_size, mapping_type, act_mapping_type)
             for group_size, mapping_type, act_mapping_type in zip(
                 [32, 64],
                 [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
                 [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
             )
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_identical_to_Int8DynamicActivationInt4WeightConfig(
         self, group_size, mapping_type, act_mapping_type
@@ -490,15 +480,16 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
         sqnr = compute_error(model(activations), model_copy(activations)).item()
         self.assertTrue(sqnr == float("inf"))
 
-    @parameterized.expand(
+    @parametrize(
+        "weight_dtype, group_size, mapping_type, act_mapping_type, scale_dtype, model_dtype",
         [
-            param(
-                weight_dtype=weight_dtype,
-                group_size=group_size,
-                mapping_type=mapping_type,
-                act_mapping_type=act_mapping_type,
-                scale_dtype=scale_dtype,
-                model_dtype=model_dtype,
+            (
+                weight_dtype,
+                group_size,
+                mapping_type,
+                act_mapping_type,
+                scale_dtype,
+                model_dtype,
             )
             for weight_dtype in list(getattr(torch, f"int{x}") for x in range(1, 9))
             for group_size in [32, 64, 128]
@@ -507,7 +498,6 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_identical_to_IntXQuantizationAwareTrainingConfig(
         self,
@@ -582,18 +572,14 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         sqnr = compute_error(prepared_out, converted_out).item()
         self.assertTrue(sqnr == float("inf"))
 
-    @parameterized.expand(
+    @parametrize(
+        "group_size, scale_dtype, model_dtype",
         [
-            param(
-                group_size=group_size,
-                scale_dtype=scale_dtype,
-                model_dtype=model_dtype,
-            )
+            (group_size, scale_dtype, model_dtype)
             for group_size in [32, 64, 128]
             for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         self, group_size, scale_dtype, model_dtype
@@ -690,5 +676,7 @@ def test_moe_quant_intx(self):
         self.assertGreater(compute_error(out_qc, out), 30)
 
 
+instantiate_parametrized_tests(TestInt8DynamicActivationIntxWeight)
+
 if __name__ == "__main__":
     unittest.main()
```
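
The pattern this file migrates to, as a minimal self-contained sketch (the `TestExample` class and its parameters are hypothetical): `parametrize` declares one subtest per value tuple, and `instantiate_parametrized_tests` materializes the concrete test methods on the class, replacing `parameterized.expand` and its `name_func`:

```python
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
)


class TestExample(TestCase):
    @parametrize(
        "dtype, group_size",
        [
            (dtype, group_size)
            for dtype in [torch.int4, torch.int8]
            for group_size in [32, 64]
        ],
    )
    def test_example(self, dtype, group_size):
        # Placeholder check; a real test would quantize a model here.
        self.assertIn(group_size, (32, 64))


# Generates one named test method per parameter combination on the class.
instantiate_parametrized_tests(TestExample)

if __name__ == "__main__":
    unittest.main()
```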
