From 75d839791407f092f2374e785c17a2f1407097a4 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 30 Aug 2024 13:22:04 -0700 Subject: [PATCH 1/5] Allow quantized linear registration in a different file Summary: Previously there was some ordering that we need to maintain for quantized linear dispatch table in AffineQuantizedTensor, the reason is there is a fallback entry that dequantizes the input: https://github.com/pytorch/ao/blob/ba2d3b1333b90ccd0186216649a1c58c6a17ce56/torchao/dtypes/affine_quantized_tensor.py#L1195 so the dispatches with two inputs quantized (static or dynamic quantization) must come before this entry and dispatches with weight only quantization, however the fallback is not really used/needed in practice, since people typically just want to call into a very specific kernel. From offline discussions with @drisspg and @HDCharles, it might be useful to have a "quantized_linear_impl" for `LayoutType`, this allows people to specify and check which quantized_linear_impl they want to use to make sure they can call into the specific kernel, when this field is set, we'll not run the fallback path for quantized linear either (dequantize all activation and weight tensors and run the floating point linear op) I think this can be added for a specific layout type if people want to and we don't have to enforce this in the base `LayoutType` Test Plan: python test/dtypes/test_affine_quantized.py -k test_register_new_dispatch Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_affine_quantized.py | 33 +++++++++++ torchao/dtypes/affine_quantized_tensor.py | 69 +++++++++++++---------- torchao/dtypes/utils.py | 8 ++- torchao/quantization/quant_api.py | 2 +- 4 files changed, 80 insertions(+), 32 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index a4f5010981..a8a6a117eb 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -87,6 +87,39 @@ def test_to_device(self, apply_quant): ql = apply_quant(l) ql.cuda() + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_register_new_dispatch(self): + from torchao.dtypes.affine_quantized_tensor import _register_aqt_quantized_linear_dispatch + from torchao.dtypes import to_affine_quantized_intx + from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization.quant_primitives import MappingType + + def dispatch_condition(input_tensor, weight_tensor, bias): + return ( + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.quant_min == 0 and + weight_tensor.quant_max == 2**6-1 + ) + + def impl(input_tensor, weight_tensor, bias): + # this is just for testing, normally people will call into uint6 weight only + # quantized linear operator here + assert False, "dispatching to my impl for uint6 weight only quant" + + _register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + + def apply_uint6_weight_only_quant(linear): + linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False) + return linear + + l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda") + apply_uint6_weight_only_quant(l) + + example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda") + with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"): + l(example_input) + + common_utils.instantiate_parametrized_tests(TestAffineQuantized) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 06bb8aeff9..39974cdbff 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -88,9 +88,22 @@ class QuantizedLinearNotImplementedError(NotImplementedError): pass -_QLINEAR_DISPATCH_TABLE = {} -def _register_quantized_linear_dispatch(dispatch_condition, impl): - _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl +_AQT_QLINEAR_DISPATCH_TABLE = {} +def _register_aqt_quantized_linear_dispatch(dispatch_condition, impl): + """Register a dispatch for quantized linear op with dispatch_condition function and impl function + both takes three arguments: + input_tensor: dimension is (M1, M2, ..., in_features) + weight_tensor: dimension is (out_features, in_features) + bias: dimension is (out_features,) + so that these can be shared by F.linear, aten.mm, aten.addmm dispatches + + Args: + `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]: the dispatch + condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight + `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: the specialized + quantized linear implementation + """ + _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl class AffineQuantizedTensor(TorchAOBaseTensor): """ @@ -189,7 +202,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor @staticmethod def _quantized_linear_op(input_tensor, weight_tensor, bias): - for dispatch_condition, impl in _QLINEAR_DISPATCH_TABLE.items(): + for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): if dispatch_condition(input_tensor, weight_tensor, bias): return impl(input_tensor, weight_tensor, bias) raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") @@ -440,7 +453,7 @@ def extra_repr(self): @dataclass(frozen=True) class Float8LayoutType(LayoutType): - mm_config: Optional[ScaledMMConfig] + mm_config: Optional[ScaledMMConfig] = None @register_layout_cls(PlainLayoutType) @@ -598,13 +611,13 @@ def from_plain( @register_layout_cls(Float8LayoutType) class Float8AQTLayout(AQTLayout): - """ + """ Layout storage class for float8 layout for affine quantized tensor """ float8_data: torch.Tensor scale: torch.Tensor transposed: bool - + def __new__( cls, float8_data: torch.Tensor, @@ -639,7 +652,7 @@ def _apply_fn_to_data(self, fn): fn(self.float8_data) fn(self.scale) return self - + def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) return self.__class__( @@ -976,21 +989,6 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y -# this is for the case when linear activation is quantized, but is not caught by the previous -# conditions that expects a quantized activation, we just dequantize the activation so that -# it can continue with the weight only quantization dispatches -# NOTE: this is a fallback path that must be registered after all the implementations that expects -# input tensor to be quantized -def _linear_quantized_act_fallback_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) - ) - -def _linear_quantized_act_fallback_impl(input_tensor, weight_tensor, bias): - input_tensor = input_tensor.dequantize() - # dequantize activation and redispatch to F.linear - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor @@ -1187,19 +1185,18 @@ def _linear_fp_act_fp8_weight_impl( ).reshape(out_shape) -def _register_quantized_linear_dispatches(): +def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), - (_linear_quantized_act_fallback_check, _linear_quantized_act_fallback_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), ]: - _register_quantized_linear_dispatch(dispatch_condition, impl) + _register_aqt_quantized_linear_dispatch(dispatch_condition, impl) -_register_quantized_linear_dispatches() +_register_aqt_quantized_linear_dispatches() @implements(torch.nn.functional.linear) def _(func, types, args, kwargs): @@ -1216,7 +1213,11 @@ def _(func, types, args, kwargs): # make the branches easier to understand in `_quantized_linear_op` try: return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError: + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfiic quantized linear implementation + if isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.layout_type.quantized_linear_impl is not None: + raise e + if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() if isinstance(weight_tensor, AffineQuantizedTensor): @@ -1239,7 +1240,11 @@ def _(func, types, args, kwargs): try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError: + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfiic quantized linear implementation + if isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.layout_type.quantized_linear_impl is not None: + raise e + if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() if isinstance(weight_tensor, AffineQuantizedTensor): @@ -1259,7 +1264,11 @@ def _(func, types, args, kwargs): try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError: + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfiic quantized linear implementation + if isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.layout_type.quantized_linear_impl is not None: + raise e + if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() if isinstance(weight_tensor, AffineQuantizedTensor): diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 036a5ca929..3a197da05e 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Callable, Union, Tuple +from typing import Dict, Callable, Union, Tuple, Optional from collections import defaultdict import functools from dataclasses import dataclass @@ -73,6 +73,12 @@ class MyTensor(torch.Tensor): """ Base class for different LayoutType, should not be instantiated directly +used to allow users to pass around configurations for the layout tensor, e.g. inner_k_tiles +for int4 tensor core tiled layout + +Note: layout is an abstraction not only for custom data representation, it is also used for how the +layout interacts with different operators, e.g. the same data representation can have different +behaviors when running the same operator, e.g. transpose, quantized_linear. """ @dataclass(frozen=True) class LayoutType: diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index aa2d3b3f93..ba670f23b7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -498,7 +498,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight(): def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn): """ Applies float8 weight-only symmetric per-channel quantization to linear layers. - + Args: weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn. From f43ef3d6b1d51c7d91e4b720b8eb2bd7d1f2e933 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 30 Aug 2024 13:57:00 -0700 Subject: [PATCH 2/5] fix error --- torchao/dtypes/affine_quantized_tensor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 39974cdbff..b22d8d923a 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1214,8 +1214,8 @@ def _(func, types, args, kwargs): try: return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfiic quantized linear implementation - if isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.layout_type.quantized_linear_impl is not None: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1241,8 +1241,8 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfiic quantized linear implementation - if isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.layout_type.quantized_linear_impl is not None: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1265,8 +1265,8 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfiic quantized linear implementation - if isinstance(weight_tensor, AffineQuantizedTensor) and weight_tensor.layout_type.quantized_linear_impl is not None: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: raise e if isinstance(input_tensor, AffineQuantizedTensor): From ae7dd160ac72a7fef2a542fcd7c4dc3fb28f747d Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 30 Aug 2024 15:20:33 -0700 Subject: [PATCH 3/5] de-register dispatch --- test/dtypes/test_affine_quantized.py | 7 ++++++- torchao/dtypes/affine_quantized_tensor.py | 9 +++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index a8a6a117eb..e47bafe9a4 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -89,7 +89,10 @@ def test_to_device(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_register_new_dispatch(self): - from torchao.dtypes.affine_quantized_tensor import _register_aqt_quantized_linear_dispatch + from torchao.dtypes.affine_quantized_tensor import ( + _register_aqt_quantized_linear_dispatch, + _deregister_aqt_quantized_linear_dispatch, + ) from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes import AffineQuantizedTensor from torchao.quantization.quant_primitives import MappingType @@ -119,6 +122,8 @@ def apply_uint6_weight_only_quant(linear): with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"): l(example_input) + _deregister_aqt_quantized_linear_dispatch(dispatch_condition) + common_utils.instantiate_parametrized_tests(TestAffineQuantized) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index b22d8d923a..5d9f3916fc 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -39,6 +39,9 @@ TORCH_VERSION_AT_LEAST_2_5, _is_float8_type ) +import logging + +logger = logging.getLogger(__name__) from torchao.float8.float8_tensor import ScaledMMConfig aten = torch.ops.aten @@ -105,6 +108,12 @@ def _register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """ _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl +def _deregister_aqt_quantized_linear_dispatch(dispatch_condition): + if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: + del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] + else: + logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") + class AffineQuantizedTensor(TorchAOBaseTensor): """ Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: From c1acd5f63091e0029c34cd92e4bb5fce84842a10 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 2 Sep 2024 10:40:28 -0700 Subject: [PATCH 4/5] make register/deregister fn public --- test/dtypes/test_affine_quantized.py | 8 ++++---- torchao/dtypes/affine_quantized_tensor.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index e47bafe9a4..5f2d1153df 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -90,8 +90,8 @@ def test_to_device(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_register_new_dispatch(self): from torchao.dtypes.affine_quantized_tensor import ( - _register_aqt_quantized_linear_dispatch, - _deregister_aqt_quantized_linear_dispatch, + register_aqt_quantized_linear_dispatch, + deregister_aqt_quantized_linear_dispatch, ) from torchao.dtypes import to_affine_quantized_intx from torchao.dtypes import AffineQuantizedTensor @@ -109,7 +109,7 @@ def impl(input_tensor, weight_tensor, bias): # quantized linear operator here assert False, "dispatching to my impl for uint6 weight only quant" - _register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + register_aqt_quantized_linear_dispatch(dispatch_condition, impl) def apply_uint6_weight_only_quant(linear): linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False) @@ -122,7 +122,7 @@ def apply_uint6_weight_only_quant(linear): with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"): l(example_input) - _deregister_aqt_quantized_linear_dispatch(dispatch_condition) + deregister_aqt_quantized_linear_dispatch(dispatch_condition) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 5d9f3916fc..359e04e4a1 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -92,7 +92,7 @@ class QuantizedLinearNotImplementedError(NotImplementedError): _AQT_QLINEAR_DISPATCH_TABLE = {} -def _register_aqt_quantized_linear_dispatch(dispatch_condition, impl): +def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """Register a dispatch for quantized linear op with dispatch_condition function and impl function both takes three arguments: input_tensor: dimension is (M1, M2, ..., in_features) @@ -108,7 +108,7 @@ def _register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """ _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl -def _deregister_aqt_quantized_linear_dispatch(dispatch_condition): +def deregister_aqt_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] else: From b7c15123ecf610c0e707cdfd9b6b77b87afe4a52 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 2 Sep 2024 10:54:44 -0700 Subject: [PATCH 5/5] rebase and fix error --- torchao/dtypes/affine_quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 359e04e4a1..11b9356adf 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1203,7 +1203,7 @@ def _register_aqt_quantized_linear_dispatches(): (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), ]: - _register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + register_aqt_quantized_linear_dispatch(dispatch_condition, impl) _register_aqt_quantized_linear_dispatches()