
Commit 8669213

Introduce IntxOpaqueTensor to replace PackedInt8DynamicActivationIntxWeightLayout in AQT (#2742)
* Refactor packed format to remove AQT
* up (repeated incremental commits)
1 parent e2514dd commit 8669213
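In user code, the change surfaces through the version-2 config path exercised by the new tests: `PackingFormat.OPAQUE` plus a `compute_target` is now backed by `IntxOpaqueTensor` rather than the old AQT layout. A minimal sketch distilled from the tests below (the model, shapes, and the chosen compute target are illustrative, not part of this commit):

    import torch
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        quantize_,
    )
    from torchao.quantization.quantize_.common import PackingFormat

    model = torch.nn.Sequential(torch.nn.Linear(2048, 1024, bias=False))

    # version=2 with PackingFormat.OPAQUE is the path backed by IntxOpaqueTensor;
    # compute_target selects the kernel backend ("aten", "torchao_auto",
    # "torchao_lowbit", or "torchao_kleidiai" per the test matrix below).
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(128),
            packing_format=PackingFormat.OPAQUE,
            compute_target="torchao_auto",
            version=2,
        ),
    )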

File tree

6 files changed: +706 −0 lines changed

.github/workflows/torchao_experimental_test.yml

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ jobs:
           python torchao/experimental/tests/test_embedding_xbit_quantizer.py
           python torchao/experimental/tests/test_quant_passes.py
           pytest -s test/prototype/test_dynamic_activation_lut.py
+          pytest -s test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py
       - name: Run kernels/cpu/aarch64/tests
         if: runner.os == 'macOS'
         run: |
test/quantization/quantize_/workflows/intx/test_intx_opaque_tensor.py

Lines changed: 339 additions & 0 deletions
@@ -0,0 +1,339 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import tempfile
import unittest

import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao.experimental.op_lib_utils import _check_torchao_ops_loaded
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.quantization.quantize_.common import PackingFormat
from torchao.quantization.utils import compute_error


def _get_accuracy_test_cases():
    MODEL_DTYPES = [
        torch.float32,
        torch.bfloat16,
    ]

    PACKING_FORMATS = [
        (PackingFormat.UNPACKED_TO_INT8, None),
        (PackingFormat.OPAQUE, "aten"),
        (PackingFormat.OPAQUE, "torchao_auto"),
        (PackingFormat.OPAQUE, "torchao_lowbit"),
        (PackingFormat.OPAQUE, "torchao_kleidiai"),
    ]

    WEIGHT_DTYPES = [
        torch.int1,
        torch.int2,
        torch.int3,
        torch.int4,
        torch.int5,
        torch.int6,
        torch.int7,
        torch.int8,
    ]

    MAPPING_TYPES = [
        MappingType.SYMMETRIC,
        MappingType.ASYMMETRIC,
        MappingType.SYMMETRIC_NO_CLIPPING_ERR,
    ]

    GRANULARITIES = [PerGroup(128), PerAxis(0)]

    def _is_valid_test_combination(
        model_dtype,
        packing_format,
        compute_target,
        weight_dtype,
        weight_mapping_type,
        weight_granularity,
    ):
        # ATEN restrictions
        if (packing_format == PackingFormat.OPAQUE) and (compute_target == "aten"):
            if weight_dtype != torch.int4:
                return False
            if weight_mapping_type == MappingType.ASYMMETRIC:
                return False
            if model_dtype != torch.float32:
                return False

        # TORCHAO_KLEIDIAI restrictions
        if (packing_format == PackingFormat.OPAQUE) and (
            compute_target == "torchao_kleidiai"
        ):
            if weight_dtype != torch.int4:
                return False
            if weight_mapping_type == MappingType.ASYMMETRIC:
                return False

        # SYMMETRIC_NO_CLIPPING_ERR does not work well with int1
        if (
            weight_dtype == torch.int1
            and weight_mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR
        ):
            return False

        return True

    test_cases = [
        param(
            model_dtype=mdt,
            packing_format=pf,
            compute_target=ct,
            weight_dtype=dt,
            weight_mapping_type=mt,
            weight_granularity=gr,
        )
        for mdt in MODEL_DTYPES
        for pf, ct in PACKING_FORMATS
        for dt in WEIGHT_DTYPES
        for mt in MAPPING_TYPES
        for gr in GRANULARITIES
        # first argument is the model dtype, not the weight dtype
        if _is_valid_test_combination(mdt, pf, ct, dt, mt, gr)
    ]

    return test_cases


_TORCHAO_OPS_LOADED = False
try:
    _check_torchao_ops_loaded()
    _TORCHAO_OPS_LOADED = True
except Exception:
    pass


@unittest.skipIf(not _TORCHAO_OPS_LOADED, "Need torchao ops")
class TestIntxOpaqueTensor(TestCase):
    @parameterized.expand(
        _get_accuracy_test_cases(),
        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
    )
    def test_accuracy(
        self,
        model_dtype,
        packing_format,
        compute_target,
        weight_dtype,
        weight_mapping_type,
        weight_granularity,
    ):
        """
        Checks accuracy of each packing format against the
        UNPACKED_TO_INT8 reference
        """
        m = 3
        n = 1071
        k = 2048
        activations = torch.randn(m, k).to(model_dtype)
        model = torch.nn.Sequential(
            *[torch.nn.Linear(k, k, bias=False), torch.nn.Linear(k, n, bias=True)]
        ).to(model_dtype)

        quantized_model = copy.deepcopy(model)
        quantize_(
            quantized_model,
            Int8DynamicActivationIntxWeightConfig(
                weight_dtype=weight_dtype,
                weight_granularity=weight_granularity,
                weight_mapping_type=weight_mapping_type,
                packing_format=packing_format,
                compute_target=compute_target,
                version=2,
            ),
        )

        quantized_model_reference = copy.deepcopy(model)
        quantize_(
            quantized_model_reference,
            Int8DynamicActivationIntxWeightConfig(
                weight_dtype=weight_dtype,
                weight_granularity=weight_granularity,
                weight_mapping_type=weight_mapping_type,
                packing_format=PackingFormat.UNPACKED_TO_INT8,
                compute_target=None,
                version=2,
            ),
        )

        with torch.no_grad():
            result = quantized_model(activations)
            expected_result = quantized_model_reference(activations)

        sqnr = compute_error(result, expected_result)
        self.assertTrue(sqnr > 30, f"Got SQNR of {sqnr}")

    def test_export_compile_aoti(self):
        m = 3
        k0 = 512
        k1 = 256
        k2 = 128
        k3 = 1024
        weight_dtype = torch.int4
        weight_granularity = PerAxis(0)
        weight_mapping_type = MappingType.ASYMMETRIC

        layers = [
            torch.nn.Linear(k0, k1, bias=False),
            torch.nn.Linear(k1, k2, bias=True),
            torch.nn.Linear(k2, k3, bias=False),
        ]
        model = torch.nn.Sequential(*layers)
        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
        dynamic_shapes = {
            "input": {
                0: torch.export.Dim.AUTO,
                1: torch.export.Dim.STATIC,
                2: torch.export.Dim.AUTO,
                3: torch.export.Dim.STATIC,
            }
        }

        quantize_(
            model,
            Int8DynamicActivationIntxWeightConfig(
                weight_dtype=weight_dtype,
                weight_granularity=weight_granularity,
                weight_mapping_type=weight_mapping_type,
                packing_format=PackingFormat.OPAQUE,
                compute_target="torchao_auto",
                version=2,
            ),
        )
        eager_results = model(activations)

        # Export
        exported = torch.export.export(
            model, (activations,), strict=True, dynamic_shapes=dynamic_shapes
        )
        exported_results = exported.module()(activations)
        self.assertTrue(torch.allclose(eager_results, exported_results))

        # Compile
        compiled = torch.compile(model)
        with torch.no_grad():
            compiled_results = compiled(activations)
        self.assertTrue(torch.allclose(eager_results, compiled_results))

        # AOTI
        with tempfile.TemporaryDirectory() as tmpdirname:
            package_path = f"{tmpdirname}/model.pt2"
            torch._inductor.aoti_compile_and_package(
                exported, package_path=package_path
            )
            fn = torch._inductor.aoti_load_package(package_path)
            aoti_results = fn(activations)
            self.assertTrue(torch.allclose(eager_results, aoti_results))

    @parameterized.expand(
        [
            param(packing_format=pf, compute_target=ct)
            for (pf, ct) in [
                (PackingFormat.OPAQUE, "torchao_auto"),
                (PackingFormat.OPAQUE, "aten"),
            ]
        ],
        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
    )
    def test_serialization(self, packing_format, compute_target):
        layers = [
            torch.nn.Linear(512, 256),
        ]
        model = torch.nn.Sequential(*layers)
        model2 = torch.nn.Sequential(*layers)
        activations = torch.randn(1, 512, dtype=torch.float32)

        quantize_(
            model,
            Int8DynamicActivationIntxWeightConfig(
                weight_dtype=torch.int4,
                weight_granularity=PerGroup(64),
                packing_format=packing_format,
                compute_target=compute_target,
                version=2,
            ),
        )
        expected = model(activations)

        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.save(model.state_dict(), f"{tmpdirname}/model.pt")
            state_dict = torch.load(
                f"{tmpdirname}/model.pt", map_location="cpu", weights_only=True
            )

        # Load deserialized weights into model2 and check result
        model2.load_state_dict(state_dict, assign=True)
        actual = model2(activations)
        self.assertTrue(torch.allclose(expected, actual))

    def test_moe_quant_intx(self):
        from torchao.prototype.moe_quant.quantizable_moe_modules import (
            MOEFeedForwardAOQuantizable,
        )
        from torchao.prototype.moe_quant.utils import (
            FakeExtraDimTensor,
            MoEQuantConfig,
            UseFakeExtraDimTensor,
            cond_ffn_filter,
        )
        from torchao.quantization.quant_api import (
            Int8DynamicActivationIntxWeightConfig,
            quantize_,
        )
        from torchao.quantization.utils import compute_error

        with torch.device("cpu"):
            model = MOEFeedForwardAOQuantizable(512, 256, 8, 2, empty_init=False).to(
                torch.float32
            )
            x = torch.randn(8, 512, dtype=torch.float32)

            out = model(x).clone()

            base_config = Int8DynamicActivationIntxWeightConfig(
                packing_format=PackingFormat.OPAQUE,
                compute_target="torchao_auto",
                version=2,
            )
            moe_config = MoEQuantConfig(
                base_config, use_fake_extra_dim_tensor=UseFakeExtraDimTensor.TRUE
            )

            quantize_(model, moe_config, cond_ffn_filter)

            out_q = model(x).clone()
            assert isinstance(model.experts.w1, FakeExtraDimTensor)

            mod_c = torch.compile(model, mode="reduce-overhead")

            # warm up the compiled module before taking the measured run
            mod_c(x)
            mod_c(x)

            out_qc = mod_c(x).clone()

            self.assertTrue(compute_error(out_q, out) > 30)
            self.assertTrue(compute_error(out_qc, out) > 30)


if __name__ == "__main__":
    run_tests()

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -94,6 +94,7 @@
     Int4OpaqueTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxOpaqueTensor,
     IntxUnpackedToInt8Tensor,
 )
 from .smoothquant import (
@@ -163,6 +164,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "IntxOpaqueTensor",
     "IntxUnpackedToInt8Tensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
