From f08f6de3617efe57dc87a48ea61cac2abd178c99 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 24 May 2024 22:32:08 -0700
Subject: [PATCH] Add support for `AQTLayout` and `PlainAQTLayout`

Summary:
Today `AffineQuantizedTensor` hardcodes its storage format as `int_data`, `scale` and `zero_point` tensors, which does not work when we want to support packed weights.

This PR hides the storage details of `AffineQuantizedTensor` behind a family of tensor subclasses, all of which inherit from the base layout type `AQTLayout` (affine quantized tensor layout).

Two layouts are added: `PlainAQTLayout`, which stores the `int_data`, `scale` and `zero_point` tensors directly as plain tensors, and `TensorCoreTiledAQTLayout`, which stores the packed weight produced by `torch.ops.aten._convert_weight_to_int4pack` together with the packed scale/zero_point expected by the int4 tinygemm kernel.

`AffineQuantizedTensor` now has:
- layout_tensor: AQTLayout (holds the quantized data in one of the supported layouts)
- layout: str (a string identifying the layout of layout_tensor, used in dispatch)

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_api.py |   8 +-
 torchao/dtypes/aqt.py               | 302 +++++++++++++++++++++++-----
 torchao/quantization/quant_api.py   |   2 +-
 3 files changed, 259 insertions(+), 53 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 9aae14dd83..35b0107836 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -110,8 +110,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

-    def example_inputs(self, batch_size=1):
-        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
+        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)

     def forward(self, x):
         x = self.linear1(x)
@@ -450,7 +450,7 @@ def test_quantized_tensor_subclass_int4(self):
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+        example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")

         groupsize = 32
         m = quantize(m, get_apply_int4wo_quant(groupsize=groupsize))
@@ -496,7 +496,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+        example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

         m = quantize(m, get_apply_int8dyn_quant())
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py
index 7619545f52..f4b758ddca 100644
--- a/torchao/dtypes/aqt.py
+++ b/torchao/dtypes/aqt.py
@@ -18,14 +18,14 @@ def _aqt_is_int8(aqt):
     """Check if an AffineQuantizedTensor is int8 quantized Tensor"""
     return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.layout_tensor.dtype == torch.int8 and
         aqt.quant_min is None or aqt.quant_min == -128 and
         aqt.quant_max is None or aqt.quant_max == 127
     )

 def _aqt_is_int8_reduced_range(aqt):
     return (
-        aqt.int_data.dtype == torch.int8 and
+        aqt.layout_tensor.dtype == torch.int8 and
         aqt.quant_min == -127 and
         aqt.quant_max is None or aqt.quant_max == 127
     )
@@ -34,7 +34,7 @@ def _aqt_is_uint4(aqt):
     """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
     # TODO: use torch.uint4
     return (
-        aqt.int_data.dtype == torch.int32 and
+        aqt.layout_tensor.dtype == torch.int32 and
         aqt.quant_min is None or aqt.quant_min == 0 and
         aqt.quant_max is None or aqt.quant_max == 15
     )
@@ -69,6 +69,218 @@ def implements_aqt_aten_ops(aten_ops):
 def implements_aqt_torch_function(torch_function):
     return implements_torch_function(AffineQuantizedTensor, torch_function)

+_EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS: Dict[str, Callable] = {}
+
+def register_aqt_layout_cls(extended_layout: str):
+    def decorator(layout_cls):
+        layout_cls.extended_layout = extended_layout
+        _EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS[extended_layout] = layout_cls
+        return layout_cls
+    return decorator
+
+def get_aqt_layout_cls(extended_layout: str) -> Callable:
+    if extended_layout not in _EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS:
+        raise ValueError(f"extended_layout: {extended_layout} is not supported yet")
+    return _EXTENDED_LAYOUT_TO_AQT_LAYOUT_CLS[extended_layout]
+
+class AQTLayout(torch.Tensor):
+    """
+    Base class for the layout tensor of `AffineQuantizedTensor`
+    """
+    # this should be set for each layout class during registration
+    extended_layout: Optional[str] = None
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        pass
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        pass
+
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+@register_aqt_layout_cls("plain")
+class PlainAQTLayout(AQTLayout):
+    """
+    Layout class for the plain layout of affine quantized tensor; it stores the int_data, scale and zero_point
+    tensors directly as plain tensors.
+
+    fields:
+      int_data (torch.Tensor): the quantized integer data Tensor
+      scale (torch.Tensor): the scale Tensor used to map between the floating point tensor and the quantized tensor
+      zero_point (torch.Tensor): the zero_point Tensor used to map between the floating point tensor and the quantized tensor
+    """
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        self.int_data = int_data
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        return cls(int_data, scale, zero_point)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.int_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.zero_point.to(kwargs["device"]),
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self):
+        return self.int_data, self.scale, self.zero_point
+
+
+@register_aqt_layout_cls("tensor_core_tiled")
+class TensorCoreTiledAQTLayout(AQTLayout):
+    """
+    Layout class for the tensor_core_tiled layout of affine quantized tensor. This layout is int4 only:
+    it stores the original tensor of dimension [n][k] (int32 dtype) as a packed weight, a 4-d tensor of
+    dimension [n / 8][k / (innerKTiles * 16)][32][innerKTiles / 2]
+    TODO: innerKTiles is hardcoded to 8 for now, we'll make it an argument once the API is decided
+
+    fields:
+      packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout
+      scale_and_zero (torch.Tensor): the scale and zero_point Tensors packed together into a single Tensor in the format expected by the int4 tinygemm kernel
+    """
+
+    def __new__(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        kwargs = {}
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        kwargs["dtype"] = int_data.dtype
+        kwargs["requires_grad"] = False
+        shape = int_data.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        # TODO: expose the arg
+        innerKTiles = 8
+        self.packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), innerKTiles)
+        self.scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+
+    @classmethod
+    def _from_packed(cls, packed_weight, scale_and_zero):
+        # construct the layout tensor directly from already-packed tensors, bypassing
+        # __init__ (which expects plain int_data, scale and zero_point and packs them)
+        layout_tensor = torch.Tensor._make_wrapper_subclass(
+            cls,
+            packed_weight.shape,
+            dtype=packed_weight.dtype,
+            device=packed_weight.device,
+            requires_grad=False,
+        )
+        layout_tensor.packed_weight = packed_weight
+        layout_tensor.scale_and_zero = scale_and_zero
+        return layout_tensor
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
+        return cls._from_packed(packed_weight, scale_and_zero)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        if not (device == "cuda" or (isinstance(device, torch.device) and device.type == "cuda")):
+            raise ValueError("TensorCoreTiledAQTLayout is only available for cuda device")
+        return self._from_packed(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+        )
+
+    def _apply_fn_to_data(self, fn):
+        self.packed_weight = fn(self.packed_weight)
+        self.scale_and_zero = fn(self.scale_and_zero)
+        return self
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        raise NotImplementedError(
+            f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self):
+        raise NotImplementedError(
+            "Unpacking for the tensor_core_tiled layout is not yet implemented"
+        )

 class AffineQuantizedTensor(torch.Tensor):
     """
@@ -82,9 +294,9 @@ class AffineQuantizedTensor(torch.Tensor):
       quantized_tensor = float_tensor / scale + zero_point

     fields:
-      int_data (torch.Tensor): the quantized integer data Tensor
-      scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor
-      zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
+      layout_tensor (AQTLayout): tensor that serves as a general layout storage for the quantized data,
+         e.g. storing plain tensors (int_data, scale, zero_point) or packed formats depending on device
+         and operator/kernel
       block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
         e.g.
when size is the same as the input tensor dimension, we are using per tensor quantization shape (torch.Size): the shape for the Tensor @@ -103,9 +315,7 @@ class AffineQuantizedTensor(torch.Tensor): @staticmethod def __new__( cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, + layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[int] = None, @@ -115,9 +325,9 @@ def __new__( strides=None, ): kwargs = {} - kwargs["device"] = int_data.device + kwargs["device"] = layout_tensor.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout ) if dtype is None: dtype = scale.dtype @@ -129,9 +339,7 @@ def __new__( def __init__( self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, + layout_tensor: AQTLayout, block_size: Tuple[int, ...], shape: torch.Size, quant_min: Optional[int] = None, @@ -140,9 +348,7 @@ def __init__( dtype=None, strides=None, ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point + self.layout_tensor = layout_tensor self.block_size = block_size self.quant_min = quant_min self.quant_max = quant_max @@ -157,21 +363,20 @@ def __repr__(self): def dequantize(self, output_dtype=None): if output_dtype is None: output_dtype = self.dtype - return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) + int_data, scale, zero_point = self.layout_tensor.get_plain() + return dequantize_affine(int_data, self.block_size, scale, zero_point, int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] + layout_tensor = tensor_data_dict["layout_tensor"] block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes return cls( - int_data, - scale, - zero_point, + layout_tensor, block_size, shape if outer_size is None else outer_size, quant_min, @@ -195,13 +400,15 @@ def from_float( zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + extended_layout: str = "plain", ): scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + + layout_cls = get_aqt_layout_cls(extended_layout) + layout_tensor = layout_cls(int_data, scale, zero_point) return cls( - int_data, - scale, - zero_point, + layout_tensor, block_size, input_float.shape, quant_min, @@ -210,6 +417,10 @@ def from_float( dtype=input_float.dtype ) + @property + def layout(self) -> str: + return self.layout_tensor.extended_layout + @classmethod def 
__torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} if kwargs is None else kwargs @@ -238,9 +449,7 @@ def _get_to_kwargs(self, *args, **kwargs): def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) return self.__class__( - self.int_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), + self.layout_tensor.to(kwargs["device"]), self.block_size, self.shape, self.quant_min, @@ -251,9 +460,7 @@ def to(self, *args, **kwargs): def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.int_data), - fn(self.scale), - fn(self.zero_point), + fn(self.layout_tensor), self.block_size, self.shape, self.quant_min, @@ -308,7 +515,9 @@ def functional_linear(*args, **kwargs): if ( is_cuda and input_is_int8 and - input_tensor_dtype_is_expected + input_tensor_dtype_is_expected and + input_tensor.layout == "plain" and + weight_qtensor.layout == "plain" ): # # 1. do the matrix form of dot(X_i, W_j) @@ -321,10 +530,10 @@ def functional_linear(*args, **kwargs): # value of a float 16, (which results in a value of inf even if multiplying # by the other scale would bring it within the expected range) - x_vals_int8 = input_tensor.int_data - x_scales = input_tensor.scale - w_vals_int8_t = weight_qtensor.int_data.contiguous().t() - w_scales = weight_qtensor.scale + x_vals_int8 = input_tensor.layout_tensor.int_data + x_scales = input_tensor.layout_tensor.scale + w_vals_int8_t = weight_qtensor.layout_tensor.int_data.contiguous().t() + w_scales = weight_qtensor.layout_tensor.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) @@ -344,35 +553,32 @@ def functional_linear(*args, **kwargs): # weight only quantization # TODO: enable cpu and mps path as well # TODO: make sure weight dimension matches the expectation of the int4mm kernel - # TODO: move this to TinygemmAffineQuantizedTensor if ( is_cuda and weight_is_uint4 and weight_qtensor.dtype == torch.bfloat16 and len(weight_qtensor.shape) == 2 and weight_qtensor.block_size[0] == 1 and - weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT + weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and + weight_qtensor.layout == "tensor_core_tiled" ): # groupwise int4 quantization - # TODO: currently doing packing on the fly, we'll need to figure out - # the API to do packing before hand - # TODO: expose the arg - innerKTiles = 8 - packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) - scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) groupsize = weight_qtensor.block_size[-1] - return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) + packed_weight = weight_qtensor.layout_tensor.packed_weight + scale_and_zero = weight_qtensor.layout_tensor.scale_and_zero + return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scale_and_zero) elif ( is_cpu and weight_is_int8 and len(weight_qtensor.shape) == 2 and len(weight_qtensor.block_size) == 2 and weight_qtensor.block_size[0] == 1 and - weight_qtensor.block_size[1] == weight_qtensor.shape[1] + weight_qtensor.block_size[1] == weight_qtensor.shape[1] and + weight_qtensor.layout == "plain" ): # TODO: enable mps path as well # per channel int8 weight only quantizated mm - return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), 
weight_qtensor.int_data, weight_qtensor.scale) + return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale) else: weight_tensor = weight_qtensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 02678ab2cd..7ec88c7498 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -348,7 +348,7 @@ def apply_int4wo_quant(weight): preserve_zero = False zero_point_dtype = torch.bfloat16 zero_point_domain = ZeroPointDomain.FLOAT - return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain) + return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, extended_layout="tensor_core_tiled") return apply_int4wo_quant
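
Note for reviewers (not part of the patch): the sketch below shows how the layout registration and the new `extended_layout` argument are expected to fit together. It is illustrative only; the import paths and parameter values are assumptions based on this patch and the surrounding torchao code (the quantization settings mirror the int4 weight-only config in `apply_int4wo_quant`), and `my_layout` / `MyAQTLayout` are hypothetical names for a user-defined layout.

    import torch
    from torchao.dtypes.aqt import (
        AffineQuantizedTensor,
        AQTLayout,
        register_aqt_layout_cls,
    )
    from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

    # Pick the storage format at quantization time via extended_layout.
    # "tensor_core_tiled" packs the weight for the int4 tinygemm kernel, so it
    # needs a bfloat16 weight on a CUDA device.
    weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    groupsize = 32
    aq_weight = AffineQuantizedTensor.from_float(
        weight,
        MappingType.ASYMMETRIC,
        (1, groupsize),              # block_size: one qparam per group of 32 elements
        torch.int32,                 # target_dtype
        quant_min=0,
        quant_max=15,
        eps=1e-6,                    # illustrative value
        zero_point_dtype=torch.bfloat16,
        preserve_zero=False,
        zero_point_domain=ZeroPointDomain.FLOAT,
        extended_layout="tensor_core_tiled",
    )
    assert aq_weight.layout == "tensor_core_tiled"   # the layout string used in dispatch

    # A new storage format only needs an AQTLayout subclass plus registration;
    # from_float then finds it through get_aqt_layout_cls.
    @register_aqt_layout_cls("my_layout")            # hypothetical layout name
    class MyAQTLayout(AQTLayout):
        ...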