You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add decorator for custom op and inductor decomp registration
Summary:
This PR adds a decorator to register custom op and also an inductor dcomposition.
The goal is for torch.export path to be able to see high level ops like quantize_affine instead of breaking down the op,
this is because some backends like xnnpack wants to work with these higher level ops.
Test Plan:
regression tests:
`python test/quantization/test_quant_api.py`
`python test/integration/test_integration.py`
also need to check performance with `python tutorials/quantize_vit/run_vit_b_quant.py`
Reviewers:
Subscribers:
Tasks:
Tags:
Copy file name to clipboardExpand all lines: torchao/dtypes/affine_quantized_tensor.py
+9-11Lines changed: 9 additions & 11 deletions
Original file line number
Diff line number
Diff line change
@@ -6,8 +6,6 @@
6
6
choose_qparams_affine,
7
7
quantize_affine,
8
8
dequantize_affine,
9
-
ZeroPointDomain,
10
-
MappingType,
11
9
int_scaled_matmul,
12
10
)
13
11
fromtorchao.quantization.utilsimport (
@@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor):
98
96
shape (torch.Size): the shape for the Tensor
99
97
quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
100
98
quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
101
-
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
99
+
zero_point_domain (str): the domain that zero_point is in, should be eitehr "int" or "float"
102
100
if zero_point is in integer domain, zero point is added to the quantized integer value during
103
101
quantization
104
102
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
105
103
value during quantization
106
-
default is ZeroPointDomain.INT
104
+
default is "int"
107
105
input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
108
106
dtype: dtype for external representation of the tensor, e.g. torch.float32
0 commit comments