From ac2e283848370885c125fa3160c98a24024f3c02 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 17 Jun 2024 16:23:56 -0700 Subject: [PATCH] Add decorator for custom op and inductor decomp registration Summary: This PR adds a decorator to register custom op and also an inductor dcomposition. The goal is for torch.export path to be able to see high level ops like quantize_affine instead of breaking down the op, this is because some backends like xnnpack wants to work with these higher level ops. Test Plan: regression tests: `python test/quantization/test_quant_api.py` `python test/integration/test_integration.py` also need to check performance with `python tutorials/quantize_vit/run_vit_b_quant.py` Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 1 - test/quantization/test_quant_api.py | 4 - test/quantization/test_quant_primitives.py | 32 +++--- torchao/dtypes/affine_quantized_tensor.py | 20 ++-- torchao/quantization/quant_api.py | 23 ++-- torchao/quantization/quant_primitives.py | 117 ++++++++++++--------- torchao/quantization/subclass.py | 4 - torchao/quantization/utils.py | 16 ++- 8 files changed, 106 insertions(+), 111 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index b4fbcb152a..f76339d65b 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -37,7 +37,6 @@ choose_qparams_affine, quantize_affine, dequantize_affine, - MappingType, ) from torchao.quantization.utils import ( dequantize_per_channel, diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b22a157568..26146f4889 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -22,10 +22,6 @@ from torchao.dtypes import ( AffineQuantizedTensor, ) -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) from torchao.quantization.subclass import ( LinearActQuantizedTensor, Int8WeightOnlyQuantizedLinearWeight, diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 0e5388c301..ea4ed0bfb6 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -12,8 +12,6 @@ quantize_affine, dequantize_affine, choose_qparams_affine, - MappingType, - ZeroPointDomain, ) # TODO: remove test for utils? from torchao.quantization.utils import ( @@ -167,7 +165,7 @@ def test_choose_qparams_group_sym(self): we don't include it here. We may just replace it with per block quant """ input = torch.randn(10, 10) - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" dtype = torch.int8 block_size = (1, 2) eps = torch.finfo(torch.float32).eps @@ -183,7 +181,7 @@ def test_choose_qparams_group_sym(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_token_asym(self): input = torch.randn(10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (1, 10) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) @@ -198,7 +196,7 @@ def test_choose_qparams_token_asym(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_tensor_asym(self): input = torch.randn(10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps @@ -217,7 +215,7 @@ def test_choose_qparams_tensor_asym(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_choose_qparams_tensor_sym(self): input = torch.randn(10, 10) - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" dtype = torch.int8 block_size = (10, 10) eps = torch.finfo(torch.float32).eps @@ -237,7 +235,7 @@ def test_quantize_activation_per_token_abs_max(self): input = torch.randn(10, 10) quantized_ref, scale_ref = quantize_activation_per_token_absmax(input) - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" block_size = list(input.shape) for i in range(len(block_size) - 1): block_size[i] = 1 @@ -278,7 +276,7 @@ def test_quantize_activation_per_token_abs_max_dtype(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_group_sym(self): input = torch.randn(10, 10) - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" dtype = torch.int8 block_size = (1, 2) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) @@ -303,7 +301,7 @@ def test_quantize_dequantize_group_sym(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym(self): input = torch.randn(10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (10, 1) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) @@ -327,7 +325,7 @@ def test_quantize_dequantize_channel_asym(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_tensor_asym(self): input = torch.randn(10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (10, 10) output_dtype = torch.float32 @@ -351,7 +349,7 @@ def test_quantize_dequantize_tensor_asym(self): @unittest.skipIf(is_fbcode(), "broken in fbcode") def test_quantize_dequantize_channel_asym_4d(self): input = torch.randn(3, 3, 10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (3, 3, 1, 10) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) @@ -373,7 +371,7 @@ def test_quantize_dequantize_channel_asym_4d(self): @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower") def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): input = torch.randn(3, 3, 10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (3, 3, 2, 2) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) @@ -384,7 +382,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): def test_choose_qparams_tensor_asym_eps(self): input = torch.zeros(10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (10, 10) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) @@ -406,7 +404,7 @@ def test_raises(self): """Make sure some errors are raised when user requested an unsupported type of quantization """ input = torch.randn(10, 10) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (10, 10) scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype) @@ -425,7 +423,7 @@ def test_not_preserve_zero_not_supported(self): """Making sure preserve_zero == False is not supported for symmetric quant""" input = torch.randn(10, 256) n_bit = 4 - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" dtype = torch.int8 block_size = (1, 128) quant_min = 0 @@ -453,7 +451,7 @@ def test_get_groupwise_affine_qparams(self): n_bit = 4 scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" dtype = torch.int8 block_size = (1, 128) quant_min = 0 @@ -473,7 +471,7 @@ def test_get_groupwise_affine_qparams(self): scale_dtype=scale_dtype, zero_point_dtype=zero_point_dtype, preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT, + zero_point_domain="float", ) self.assertTrue(torch.equal(scale, scale_ref)) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 83c7d22fb4..129a30b947 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -6,8 +6,6 @@ choose_qparams_affine, quantize_affine, dequantize_affine, - ZeroPointDomain, - MappingType, int_scaled_matmul, ) from torchao.quantization.utils import ( @@ -98,12 +96,12 @@ class AffineQuantizedTensor(torch.Tensor): shape (torch.Size): the shape for the Tensor quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (str): the domain that zero_point is in, should be eitehr "int" or "float" if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization - default is ZeroPointDomain.INT + default is "int" input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object dtype: dtype for external representation of the tensor, e.g. torch.float32 """ @@ -116,7 +114,7 @@ def __new__( shape: torch.Size, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", dtype=None, strides=None, ): @@ -138,7 +136,7 @@ def __init__( shape: torch.Size, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", dtype=None, strides=None, ): @@ -184,7 +182,7 @@ def __tensor_unflatten__( def from_float( cls, input_float: torch.Tensor, - mapping_type: MappingType, + mapping_type: str, block_size: Tuple[int, ...], target_dtype: torch.dtype, quant_min: Optional[int] = None, @@ -193,7 +191,7 @@ def from_float( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", extended_layout: str = "plain", # TODO: this is only for "tensor_core_tiled", need to figure out # the proper API for this arg @@ -520,7 +518,7 @@ def get_plain(self): target_dtype = torch.int32 quant_min = 0 quant_max = 15 - zero_point_domain = ZeroPointDomain.FLOAT + zero_point_domain = "int" assert len(block_size) == 2 and block_size[0] == 1 groupsize = block_size[-1] dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) @@ -597,7 +595,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): weight_is_uint4 and weight_qtensor.dtype == torch.bfloat16 and len(weight_qtensor.shape) == 2 and - weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and + weight_qtensor.zero_point_domain == "float" and weight_qtensor.extended_layout == "tensor_core_tiled" ): assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" @@ -640,7 +638,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias): len(weight_qtensor.block_size) == 2 and weight_qtensor.block_size[0] == 1 and weight_qtensor.block_size[1] == weight_qtensor.shape[1] and - weight_qtensor.zero_point_domain == ZeroPointDomain.INT and + weight_qtensor.zero_point_domain == "int" and weight_qtensor.extended_layout == "plain" ): # TODO: enable cpu and mps efficient path diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3a1516d9b5..8559454782 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -31,10 +31,6 @@ to_linear_act_quantized, ) -from .quant_primitives import ( - MappingType, - ZeroPointDomain, -) from .weight_only import WeightOnlyInt8QuantLinear from .unified import Quantizer, TwoStepQuantizer from .GPTQ import ( @@ -270,7 +266,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens # weight settings groupsize = 32 - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" block_size = (1, groupsize) target_dtype = torch.int32 quant_min = 0 @@ -278,7 +274,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens eps = 1e-6 preserve_zero = False zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + zero_point_domain = "float" apply_weight_quant = lambda x: to_affine_quantized( x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, @@ -319,7 +315,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight): from torchao.dtypes import to_affine_quantized # weight settings - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" block_size = (1, group_size) target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps @@ -336,7 +332,7 @@ def get_per_token_block_size(x): return block_size # input settings - input_mapping_type = MappingType.ASYMMETRIC + input_mapping_type = "asymmetric" input_target_dtype = torch.int8 input_quant_func = lambda x: to_affine_quantized(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) @@ -360,8 +356,7 @@ def int4_weight_only(group_size=128, inner_k_tiles=8): def apply_int4_weight_only_quant(weight): # avoid circular dep from torchao.dtypes import to_affine_quantized - - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" block_size = (1, group_size) target_dtype = torch.int32 quant_min = 0 @@ -369,7 +364,7 @@ def apply_int4_weight_only_quant(weight): eps = 1e-6 preserve_zero = False zero_point_dtype = torch.bfloat16 - zero_point_domain = ZeroPointDomain.FLOAT + zero_point_domain = "float" return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled", inner_k_tiles=inner_k_tiles) return apply_int4_weight_only_quant @@ -383,7 +378,7 @@ def apply_int8wo_quant(weight): # avoid circular dep from torchao.dtypes import to_affine_quantized - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 @@ -406,7 +401,7 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight): # avoid circular dep from torchao.dtypes import to_affine_quantized # weight settings - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" def get_weight_block_size(x): return (1, x.shape[1]) target_dtype = torch.int8 @@ -420,7 +415,7 @@ def get_per_token_block_size(x): block_size[i] = 1 return block_size - input_mapping_type = MappingType.SYMMETRIC + input_mapping_type = "symmetric" input_target_dtype = torch.int8 input_eps = 1e-5 input_quant_min = -127 diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index a78c42605a..d49806289d 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -4,13 +4,16 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum +from enum import Enum, auto from typing import List, Optional, Tuple, Dict import torch from torchao.kernel.intmm import int_scaled_matmul from torchao.kernel.intmm import safe_int_mm -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import ( + TORCH_VERSION_AFTER_2_3, + TORCH_VERSION_AFTER_2_5, +) __all__ = [ @@ -21,31 +24,6 @@ "dequantize_affine", ] -class MappingType(Enum): - """How floating point number is mapped to integer number - - symmetric mapping means floating point range is symetrically mapped to integer range - let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) - we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) - e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) - - asymmetric mapping means we just directly map the floating point range to integer range, - for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter - based on this mapping - e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) - """ - SYMMETRIC = 0 - ASYMMETRIC = 1 - -class ZeroPointDomain(Enum): - """Enum that indicate whether zero_point is in integer domain or floating point domain - - integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) - float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale - """ - INT = 0 - FLOAT = 1 - """ Map from dtype to the bound value of integers TODO: maybe can replace this with call to torch.iinfo @@ -130,17 +108,32 @@ def _get_reduction_params(block_size, input_size): cur_dim += 1 return shape_for_reduction, reduction_dims +def register_custom_op(name: str): + from torch._inductor.decomposition import register_decomposition + + def decorator(fn): + if TORCH_VERSION_AFTER_2_5: + opdef = torch.library.custom_op(name, mutates_args=())(fn) + opdef.register_fake(fn) + register_decomposition([opdef._opoverload])(fn) + return opdef + else: + return fn + + return decorator + +@register_custom_op("quant::quantize_affine") def quantize_affine( input: torch.Tensor, - block_size: Tuple[int, ...], + block_size: List[int], scale: torch.Tensor, zero_point: Optional[torch.Tensor], output_dtype: torch.dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, -): + zero_point_domain: str = "int", +) -> torch.Tensor: """ Args: input (torch.Tensor): original float32, float16 or bfloat16 Tensor @@ -151,12 +144,12 @@ def quantize_affine( output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (str): the domain that zero_point is in, should be eitehr "int" for "float" if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization - default is ZeroPointDomain.INT + default is "int" Note: How can block_size represent different granularities? @@ -170,6 +163,11 @@ def quantize_affine( per_group (groupsize=2) | (3, 3, 10, 2) per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10) + Note: + zero_point_domain also affects how the floating point value is quantized: + + integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer) + float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale Output: quantized tensor with requested dtype @@ -188,12 +186,12 @@ def quantize_affine( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) - if zero_point_domain == ZeroPointDomain.INT: + if zero_point_domain == "int": quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max ).to(output_dtype) else: - assert zero_point_domain == ZeroPointDomain.FLOAT + assert zero_point_domain == "float" mid_point = (quant_max + quant_min + 1) / 2 min_val = zero_point - scale * mid_point quant = ( @@ -205,15 +203,16 @@ def quantize_affine( return quant +@register_custom_op("quant::dequantize_affine") def dequantize_affine( input: torch.Tensor, - block_size: Tuple[int, ...], + block_size: List[int], scale: torch.Tensor, zero_point: Optional[torch.Tensor], input_dtype: torch.dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + zero_point_domain: str = "int", *, output_dtype: torch.dtype = torch.float32, ): @@ -228,12 +227,12 @@ def dequantize_affine( quant_min (Optional[int]): minimum quantized value for input Tensor quant_max (Optional[int]): maximum quantized value for input Tensor output_dtype (torch.dtype): dtype for output Tensor, default is fp32 - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (str): the domain that zero_point is in, should be eitehr "int" or "float" if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization - default is ZeroPointDomain.INT + default is "int" Output: dequantized Tensor, with requested dtype or fp32 @@ -255,7 +254,7 @@ def dequantize_affine( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) - if zero_point_domain == ZeroPointDomain.INT: + if zero_point_domain == "int": # Force a copy to avoid input modification due # to upcoming in-place operations. dequant = input.to(torch.int32, copy=True) @@ -264,7 +263,7 @@ def dequantize_affine( dequant = dequant.to(output_dtype) dequant *= scale else: - assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}" + assert zero_point_domain == "float", f"Unexpected zero point domain: {zero_point_domain}" mid_point = (quant_max + quant_min + 1) / 2 # This should allocate new memory and avoid input modification dequant = input - mid_point @@ -275,10 +274,11 @@ def dequantize_affine( return dequant.view(original_shape).to(output_dtype) +@register_custom_op("quant::choose_qparams_affine") def choose_qparams_affine( input: torch.Tensor, - mapping_type: MappingType, - block_size: Tuple[int, ...], + mapping_type: str, + block_size: List[int], target_dtype: torch.dtype, quant_min: Optional[int] = None, quant_max: Optional[int] = None, @@ -286,13 +286,13 @@ def choose_qparams_affine( scale_dtype: Optional[torch.dtype] = None, zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, - zero_point_domain = ZeroPointDomain.INT, + zero_point_domain: str = "int", ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: input (torch.Tensor): fp32, bf16, fp16 input Tensor - mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric - block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + mapping_type (str): determines how the qparams are calculated, "symmetric" or "asymmetric" + block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam e.g. when size is the same as the input tensor dimension, we are using per tensor quantization target_dtype (torch.dtype): dtype for target quantized Tensor quant_min (Optional[int]): minimum quantized value for target quantized Tensor @@ -310,18 +310,32 @@ def choose_qparams_affine( If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + zero_point_domain (str): the domain that zero_point is in, should be eitehr "int" or "float" if zero_point is in integer domain, zero point is added to the quantized integer value during quantization if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) value during quantization - default is ZeroPointDomain.INT + default is "int" + + + Note: + How floating point number is mapped to integer number? + + symmetric mapping means floating point range is symetrically mapped to integer range + let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4) + we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7) + e.g. scale = (10.2 - (-10.2)) / (7 - (-8)) + + asymmetric mapping means we just directly map the floating point range to integer range, + for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter + based on this mapping + e.g. scale = (10.2 - (-3.5)) / (7 - (-8)) Output: Tuple of scales and zero_points Tensor with requested dtype """ quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) - assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}" + assert mapping_type in ["symmetric", "asymmetric"], f"Unsupported mapping type: {mapping_type}" if scale_dtype is None: scale_dtype = input.dtype @@ -342,21 +356,22 @@ def choose_qparams_affine( min_val_neg = min_val max_val_pos = max_val - if mapping_type == MappingType.SYMMETRIC: + if mapping_type == "symmetric": max_val_pos = torch.max(-min_val_neg, max_val_pos) scale = max_val_pos / (float(quant_max - quant_min) / 2) if not preserve_zero: raise ValueError("preserve_zero == False is not supported for symmetric quantization") - if zero_point_domain != ZeroPointDomain.INT: + if zero_point_domain != "int": raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization") zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: + assert mapping_type == "asymmetric" scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) if preserve_zero: zero_point = quant_min - torch.round(min_val_neg / scale) zero_point = torch.clamp(zero_point, quant_min, quant_max) else: - assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain" + assert zero_point_domain == "float", "if not preserve_zero, zero_point must be in FLOAT domain" mid_point = (quant_max + quant_min + 1) / 2 zero_point = min_val_neg + scale * mid_point diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index a2801a622f..1eeb03b591 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -9,10 +9,6 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from .quant_primitives import ( - MappingType, -) - from .utils import ( find_multiple, dequantize_per_channel, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3e3943c93c..f1ab56f0e5 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -10,8 +10,6 @@ import torch.nn.utils.parametrize as parametrize from torchao.utils import find_multiple from .quant_primitives import ( - MappingType, - ZeroPointDomain, choose_qparams_affine, quantize_affine, dequantize_affine, @@ -132,7 +130,7 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None): # and slightly modified def quantize_activation_per_token_absmax(t): # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" block_size = list(t.shape) for i in range(len(block_size) - 1): block_size[i] = 1 @@ -241,7 +239,7 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): block_size = (1, x.shape[1]) zero_point_dtype = torch.int64 - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype) quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max) return quant, scale, zero_point @@ -278,7 +276,7 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16 assert w.dim() == 2 assert n_bit <= 8, f"only n_bit smaller than 8 is supported, got: {n_bit}" - mapping_type = MappingType.ASYMMETRIC + mapping_type = "asymmetric" target_dtype = torch.int32 block_size = (1, groupsize) quant_min = 0 @@ -298,7 +296,7 @@ def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16 scale_dtype=scale_dtype, zero_point_dtype=zero_point_dtype, preserve_zero=False, - zero_point_domain=ZeroPointDomain.FLOAT + zero_point_domain="float", ) return scale.to(dtype=dtype).reshape(w.shape[0], -1), zero_point.to( @@ -347,7 +345,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_min = 0 quant_max = 2 ** n_bit - 1 - return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) + return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain="float") def groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, @@ -367,7 +365,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( input_dtype = torch.int32 quant_min = 0 quant_max = 2**n_bit - 1 - return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype) + return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain="float", output_dtype=scales.dtype) def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): @@ -401,7 +399,7 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float assert w.dim() == 2 assert n_bit <= 8, f"unsupported n_bit: {n_bit}" - mapping_type = MappingType.SYMMETRIC + mapping_type = "symmetric" block_size = (1, groupsize) eps = torch.finfo(torch.float32).eps ranges = {}