Commit ac2e283
Add decorator for custom op and inductor decomp registration
Summary: This PR adds a decorator that registers a custom op and also an inductor decomposition for it. The goal is for the torch.export path to see high-level ops like quantize_affine instead of their broken-down form, because some backends such as XNNPACK want to work with these higher-level ops.

Test Plan: regression tests: `python test/quantization/test_quant_api.py` and `python test/integration/test_integration.py`; also check performance with `python tutorials/quantize_vit/run_vit_b_quant.py`

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent bc8599f commit ac2e283
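The registration pattern described in the summary can be sketched roughly as below. This is a hedged illustration only, not the code added by this commit: the "torchao" library namespace, the register_custom_op_and_decomp helper name, and the toy scale_by_two op are assumptions, and the real decorator may derive the schema and choose dispatch keys differently.

import torch
from torch._inductor.decomposition import register_decomposition

# Hypothetical namespace; the commit's actual namespace may differ.
_lib = torch.library.Library("torchao", "FRAGMENT")

def register_custom_op_and_decomp(schema: str):
    """Sketch: register `fn` as a custom op (so torch.export keeps it as a
    single high-level node) and also as an inductor decomposition (so
    torch.compile can still lower it to primitive ops for performance)."""
    def decorator(fn):
        op_name = schema.split("(")[0]
        _lib.define(schema)
        # CompositeExplicitAutograd keeps the op from being implicitly
        # decomposed while tracing for export.
        _lib.impl(op_name, fn, "CompositeExplicitAutograd")
        op = getattr(torch.ops.torchao, op_name)
        # Inductor gets the same Python body as a decomposition.
        register_decomposition([op])(fn)
        return op
    return decorator

# Toy usage: an exported graph would show torchao::scale_by_two as one node,
# which a backend like XNNPACK could pattern-match directly.
@register_custom_op_and_decomp("scale_by_two(Tensor x) -> Tensor")
def scale_by_two(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0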

8 files changed, +106 -111 lines changed

test/integration/test_integration.py

Lines changed: 0 additions & 1 deletion
@@ -37,7 +37,6 @@
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    MappingType,
 )
 from torchao.quantization.utils import (
     dequantize_per_channel,

test/quantization/test_quant_api.py

Lines changed: 0 additions & 4 deletions
@@ -22,10 +22,6 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
 )
-from torchao.quantization.quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
 from torchao.quantization.subclass import (
     LinearActQuantizedTensor,
     Int8WeightOnlyQuantizedLinearWeight,

test/quantization/test_quant_primitives.py

Lines changed: 15 additions & 17 deletions
@@ -12,8 +12,6 @@
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
-    MappingType,
-    ZeroPointDomain,
 )
 # TODO: remove test for utils?
 from torchao.quantization.utils import (
@@ -167,7 +165,7 @@ def test_choose_qparams_group_sym(self):
         we don't include it here. We may just replace it with per block quant
         """
         input = torch.randn(10, 10)
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         dtype = torch.int8
         block_size = (1, 2)
         eps = torch.finfo(torch.float32).eps
@@ -183,7 +181,7 @@ def test_choose_qparams_group_sym(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_token_asym(self):
         input = torch.randn(10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (1, 10)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -198,7 +196,7 @@ def test_choose_qparams_token_asym(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_tensor_asym(self):
         input = torch.randn(10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (10, 10)
         eps = torch.finfo(torch.float32).eps
@@ -217,7 +215,7 @@ def test_choose_qparams_tensor_asym(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_choose_qparams_tensor_sym(self):
         input = torch.randn(10, 10)
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         dtype = torch.int8
         block_size = (10, 10)
         eps = torch.finfo(torch.float32).eps
@@ -237,7 +235,7 @@ def test_quantize_activation_per_token_abs_max(self):
         input = torch.randn(10, 10)
         quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)
 
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         block_size = list(input.shape)
         for i in range(len(block_size) - 1):
             block_size[i] = 1
@@ -278,7 +276,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_group_sym(self):
         input = torch.randn(10, 10)
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         dtype = torch.int8
         block_size = (1, 2)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -303,7 +301,7 @@ def test_quantize_dequantize_group_sym(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_channel_asym(self):
         input = torch.randn(10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (10, 1)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -327,7 +325,7 @@ def test_quantize_dequantize_channel_asym(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_tensor_asym(self):
         input = torch.randn(10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (10, 10)
         output_dtype = torch.float32
@@ -351,7 +349,7 @@ def test_quantize_dequantize_tensor_asym(self):
     @unittest.skipIf(is_fbcode(), "broken in fbcode")
     def test_quantize_dequantize_channel_asym_4d(self):
         input = torch.randn(3, 3, 10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (3, 3, 1, 10)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -373,7 +371,7 @@ def test_quantize_dequantize_channel_asym_4d(self):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
     def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
         input = torch.randn(3, 3, 10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (3, 3, 2, 2)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
@@ -384,7 +382,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
 
     def test_choose_qparams_tensor_asym_eps(self):
         input = torch.zeros(10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (10, 10)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
@@ -406,7 +404,7 @@ def test_raises(self):
         """Make sure some errors are raised when user requested an unsupported type of quantization
         """
         input = torch.randn(10, 10)
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (10, 10)
         scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
@@ -425,7 +423,7 @@ def test_not_preserve_zero_not_supported(self):
         """Making sure preserve_zero == False is not supported for symmetric quant"""
         input = torch.randn(10, 256)
         n_bit = 4
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         dtype = torch.int8
         block_size = (1, 128)
         quant_min = 0
@@ -453,7 +451,7 @@ def test_get_groupwise_affine_qparams(self):
         n_bit = 4
         scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
 
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         dtype = torch.int8
         block_size = (1, 128)
         quant_min = 0
@@ -473,7 +471,7 @@ def test_get_groupwise_affine_qparams(self):
             scale_dtype=scale_dtype,
             zero_point_dtype=zero_point_dtype,
             preserve_zero=False,
-            zero_point_domain=ZeroPointDomain.FLOAT,
+            zero_point_domain="float",
         )
 
         self.assertTrue(torch.equal(scale, scale_ref))
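For reference, the string-based API exercised by these tests can be used roughly as below. This is a hedged sketch: the choose_qparams_affine call mirrors the tests above, while the quantize_affine/dequantize_affine argument order is assumed from the surrounding test file rather than spelled out in this diff.

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

input = torch.randn(10, 10)
mapping_type = "symmetric"          # previously MappingType.SYMMETRIC
block_size = (1, 2)                 # groupwise quantization, group size 2
dtype = torch.int8
eps = torch.finfo(torch.float32).eps

scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
# Assumed signatures: quantize_affine(input, block_size, scale, zero_point, output_dtype)
# and dequantize_affine(q, block_size, scale, zero_point, input_dtype, output_dtype=...)
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)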

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 9 additions & 11 deletions
@@ -6,8 +6,6 @@
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    ZeroPointDomain,
-    MappingType,
     int_scaled_matmul,
 )
 from torchao.quantization.utils import (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
+      zero_point_domain (str): the domain that zero_point is in, should be eitehr "int" or "float"
        if zero_point is in integer domain, zero point is added to the quantized integer value during
        quantization
        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
        value during quantization
-      default is ZeroPointDomain.INT
+      default is "int"
      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
      dtype: dtype for external representation of the tensor, e.g. torch.float32
     """
@@ -116,7 +114,7 @@ def __new__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: str = "int",
         dtype=None,
         strides=None,
     ):
@@ -138,7 +136,7 @@ def __init__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: str = "int",
         dtype=None,
         strides=None,
     ):
@@ -184,7 +182,7 @@ def __tensor_unflatten__(
     def from_float(
         cls,
         input_float: torch.Tensor,
-        mapping_type: MappingType,
+        mapping_type: str,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
@@ -193,7 +191,7 @@ def from_float(
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: str = "int",
         extended_layout: str = "plain",
         # TODO: this is only for "tensor_core_tiled", need to figure out
         # the proper API for this arg
@@ -520,7 +518,7 @@ def get_plain(self):
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "int"
         assert len(block_size) == 2 and block_size[0] == 1
         groupsize = block_size[-1]
         dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero)
@@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             weight_is_uint4 and
             weight_qtensor.dtype == torch.bfloat16 and
             len(weight_qtensor.shape) == 2 and
-            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
+            weight_qtensor.zero_point_domain == "float" and
             weight_qtensor.extended_layout == "tensor_core_tiled"
         ):
             assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}"
@@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
             len(weight_qtensor.block_size) == 2 and
             weight_qtensor.block_size[0] == 1 and
             weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
-            weight_qtensor.zero_point_domain == ZeroPointDomain.INT and
+            weight_qtensor.zero_point_domain == "int" and
             weight_qtensor.extended_layout == "plain"
         ):
             # TODO: enable cpu and mps efficient path
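The same string arguments flow through AffineQuantizedTensor.from_float (exposed as to_affine_quantized). Below is a hedged construction sketch: the keyword names follow the from_float signature shown in the diff above, and the value choices are borrowed from the int4 settings in quant_api.py further down; treat it as an illustration rather than the commit's own test code.

import torch
from torchao.dtypes import to_affine_quantized

weight = torch.randn(128, 256)
aqt = to_affine_quantized(
    weight,
    mapping_type="asymmetric",        # str instead of MappingType.ASYMMETRIC
    block_size=(1, 32),               # groupwise quantization, group size 32
    target_dtype=torch.int32,
    quant_min=0,
    quant_max=15,
    eps=1e-6,
    zero_point_dtype=torch.bfloat16,
    preserve_zero=False,
    zero_point_domain="float",        # str instead of ZeroPointDomain.FLOAT
)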

torchao/quantization/quant_api.py

Lines changed: 9 additions & 14 deletions
@@ -31,10 +31,6 @@
     to_linear_act_quantized,
 )
 
-from .quant_primitives import (
-    MappingType,
-    ZeroPointDomain,
-)
 from .weight_only import WeightOnlyInt8QuantLinear
 from .unified import Quantizer, TwoStepQuantizer
 from .GPTQ import (
@@ -270,15 +266,15 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
 
         # weight settings
         groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         block_size = (1, groupsize)
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         eps = 1e-6
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "float"
 
         apply_weight_quant = lambda x: to_affine_quantized(
             x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
@@ -319,7 +315,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight):
         from torchao.dtypes import to_affine_quantized
 
         # weight settings
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         block_size = (1, group_size)
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -336,7 +332,7 @@ def get_per_token_block_size(x):
             return block_size
 
         # input settings
-        input_mapping_type = MappingType.ASYMMETRIC
+        input_mapping_type = "asymmetric"
         input_target_dtype = torch.int8
         input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
@@ -360,16 +356,15 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     def apply_int4_weight_only_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
-
-        mapping_type = MappingType.ASYMMETRIC
+        mapping_type = "asymmetric"
         block_size = (1, group_size)
         target_dtype = torch.int32
         quant_min = 0
         quant_max = 15
         eps = 1e-6
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+        zero_point_domain = "float"
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles)
 
     return apply_int4_weight_only_quant
@@ -383,7 +378,7 @@ def apply_int8wo_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
 
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
@@ -406,7 +401,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
         # avoid circular dep
         from torchao.dtypes import to_affine_quantized
         # weight settings
-        mapping_type = MappingType.SYMMETRIC
+        mapping_type = "symmetric"
         def get_weight_block_size(x):
             return (1, x.shape[1])
         target_dtype = torch.int8
@@ -420,7 +415,7 @@ def get_per_token_block_size(x):
                 block_size[i] = 1
             return block_size
 
-        input_mapping_type = MappingType.SYMMETRIC
+        input_mapping_type = "symmetric"
         input_target_dtype = torch.int8
         input_eps = 1e-5
         input_quant_min = -127
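End to end, the updated workflows are applied through quantize. A hedged usage sketch follows, assuming the quant_api functions shown above; the exact return/mutation behavior of quantize at this commit is an assumption.

import torch
from torchao.quantization.quant_api import quantize, int4_weight_only

# int4 weight-only quantization targets the tensor_core_tiled layout,
# which expects a bfloat16 model on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
model = quantize(model, int4_weight_only(group_size=32))

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    out = model(x)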
