diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
new file mode 100644
index 0000000000..3a9480f675
--- /dev/null
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_tensor.py
@@ -0,0 +1,147 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao.quantization import (
+    IntxWeightOnlyConfig,
+    quantize_,
+)
+from torchao.quantization.granularity import PerGroup
+from torchao.quantization.utils import compute_error
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+)
+
+
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
+class TestIntxUnpackedTensor(TestCase):
+    def setUp(self):
+        self.config = IntxWeightOnlyConfig(
+            weight_dtype=torch.int4,
+            granularity=PerGroup(32),
+            version=2,
+        )
+
+    def test_embedding(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randint(low=0, high=128, size=(10,), device=device)
+        embedding = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
+        original = embedding(input)
+        quantize_(embedding, self.config)
+        quantized = embedding(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
+    def test_linear(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, self.config)
+        quantized = linear(input)
+        error = compute_error(original, quantized)
+        self.assertTrue(error > 20)
+
+    def test_slice(self):
+        dtype = torch.bfloat16
+        device = "cpu"
+        dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
+
+        dummy1 = torch.nn.Linear(256, 64, bias=False, dtype=dtype, device=device)
+        dummy1.weight = torch.nn.Parameter(
+            dummy.weight.narrow(0, 0, 64), requires_grad=False
+        )
+
+        dummy2 = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        dummy2.weight = torch.nn.Parameter(
+            dummy.weight.narrow(1, 0, 128), requires_grad=False
+        )
+
+        quantize_(dummy, self.config)
+        weight1 = dummy.weight.narrow(0, 0, 64)
+        weight2 = dummy.weight.narrow(1, 0, 128)
+
+        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, 64))
+        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, 64))
+
+        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, 128))
+        self.assertEqual(weight2.scale, dummy.weight.scale.narrow(1, 0, 4))
+
+        # check that results from the sliced weight before and after
+        # intx quantization do not differ too much
+        input = torch.randn(2, 256, dtype=dtype, device=device)
+        res_ref = dummy1(input)
+        dummy.weight = torch.nn.Parameter(weight1, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 20
+
+        input = torch.randn(2, 128, dtype=dtype, device=device)
+        res_ref = dummy2(input)
+        dummy.weight = torch.nn.Parameter(weight2, requires_grad=False)
+        res = dummy(input)
+        assert compute_error(res, res_ref) > 15
+
+    def test_slice_and_copy_(self):
+        device = "cpu"
+        l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
+        l.weight = torch.nn.Parameter(
+            torch.zeros(1024, 1024, dtype=torch.bfloat16, device=device)
+        )
+        quantize_(l, self.config)
+        param = l.weight
+        param_data = param.data
+        param_data = param_data.narrow(0, 0, 512)
+        assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr()
+        assert param.data.scale.data_ptr() == param_data.scale.data_ptr()
+        assert param.data.zero_point.data_ptr() == param_data.zero_point.data_ptr()
+        orig_value = param.data.qdata[0][0].item()
+
+        # dummy_l has random weight (shouldn't be 0)
+        dummy_l = torch.nn.Linear(1024, 1024).to(device).to(torch.bfloat16)
+        quantize_(dummy_l, self.config)
+        quantized = dummy_l.weight
+        quantized = quantized.narrow(0, 0, 512)
+
+        param_data.copy_(quantized)
+
+        # making sure param.data is updated
+        assert param.data.qdata[0][0] != orig_value
+
+    def test_to_dtype(self):
+        activations_bf16 = torch.randn(1, 128, dtype=torch.bfloat16)
+        activations_fp32 = torch.randn(1, 128, dtype=torch.float32)
+        activations_fp16 = torch.randn(1, 128, dtype=torch.float16)
+
+        linear = torch.nn.Linear(128, 256)
+        quantize_(linear, self.config)
+
+        linear.to(dtype=torch.float16)
+        linear(activations_fp16)
+
+        linear.to(dtype=torch.float32)
+        linear(activations_fp32)
+
+        linear.to(dtype=torch.bfloat16)
+        linear(activations_bf16)
+
+    def test_export(self):
+        linear = torch.nn.Linear(128, 256)
+        quantize_(linear, self.config)
+        ep = torch.export.export(linear, (torch.randn(1, 128),))
+        assert "torch.ops.torchao.dequantize_affine.default" in ep.graph_module.code
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 8e98e55178..3c541deb83 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -93,6 +93,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
 )
 from .smoothquant import (
     SmoothFakeDynamicallyQuantizedLinear,
@@ -161,6 +162,7 @@
     "Int4Tensor",
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
+    "IntxUnpackedTensor",
     "Float8Tensor",
     # smooth quant - subject to change
     "get_scale",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index ed5abb7333..7759665c6c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -75,6 +75,7 @@
     Int4MarlinSparseTensor,
     Int4PreshuffledTensor,
     Int4Tensor,
+    IntxUnpackedTensor,
     QuantizeTensorToFloat8Kwargs,
 )
 from torchao.quantization.transform_module import (
@@ -454,6 +455,10 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
+def _embedding_extra_repr(self):
+    return f"num_embeddings={self.weight.shape[0]}, embedding_dim={self.weight.shape[1]}, weight={_quantization_type(self.weight)}"
+
+
 def _get_linear_subclass_inserter(
     constructor, *, allow_requires_grad=False, propagate_bias=False, **kwargs
 ):
@@ -1987,6 +1992,8 @@ class IntxWeightOnlyConfig(AOBaseConfig):
     mapping_type: MappingType = MappingType.SYMMETRIC
     scale_dtype: Optional[torch.dtype] = None
     layout: Layout = QDQLayout()
+    packing_format: PackingFormat = PackingFormat.UNPACKED_TO_INT8
+    version: int = 1
 
     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig")
@@ -2005,16 +2012,13 @@ def __post_init__(self):
         )
 
 
-@register_quantize_module_handler(IntxWeightOnlyConfig)
-def _intx_weight_only_transform(
-    module: torch.nn.Module, config: IntxWeightOnlyConfig
-) -> torch.nn.Module:
-    weight = module.weight
+def _intx_weight_only_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mapping_type = config.mapping_type
     scale_dtype = config.scale_dtype
     layout = config.layout
+    packing_format = config.packing_format
 
     assert weight.dim() == 2, (
         f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2029,11 +2033,28 @@ def _intx_weight_only_transform(
     else:
         raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}")
 
+    block_size = (1, group_size)
+
+    if config.version == 2:
+        if config.packing_format == PackingFormat.UNPACKED_TO_INT8:
+            new_weight = IntxUnpackedTensor.from_hp(
+                weight,
+                block_size,
+                weight_dtype,
+                mapping_type=mapping_type,
+            )
+            if scale_dtype is not None and scale_dtype != weight.dtype:
+                new_weight.scale = new_weight.scale.to(scale_dtype).to(weight.dtype)
+            return new_weight
+        else:
+            raise ValueError(f"Unsupported packing format: {packing_format}")
+
+    # Version 1
     quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype]
     weight = to_affine_quantized_intx(
         input_float=weight,
         mapping_type=mapping_type,
-        block_size=(1, group_size),
+        block_size=block_size,
         target_dtype=torch.int8,
         quant_min=quant_min,
         quant_max=quant_max,
@@ -2043,7 +2064,25 @@ def _intx_weight_only_transform(
         zero_point_domain=ZeroPointDomain.INT,
         _layout=layout,
     )
-    module.weight = torch.nn.Parameter(weight, requires_grad=False)
+    return weight
+
+
+@register_quantize_module_handler(IntxWeightOnlyConfig)
+def _intx_weight_only_transform(
+    module: torch.nn.Module, config: IntxWeightOnlyConfig
+) -> torch.nn.Module:
+    assert hasattr(module, "weight"), (
+        "applying intx weight only quant requires module to have weight attribute"
+        + f" but {module} does not have one"
+    )
+    new_weight = _intx_weight_only_quantize_tensor(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+
+    if isinstance(module, nn.Linear):
+        module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    elif isinstance(module, nn.Embedding):
+        module.extra_repr = types.MethodType(_embedding_extra_repr, module)
+    return module
diff --git a/torchao/quantization/quantize_/common/packing_format.py b/torchao/quantization/quantize_/common/packing_format.py
index 96a29d2990..89acf4eff3 100644
--- a/torchao/quantization/quantize_/common/packing_format.py
+++ b/torchao/quantization/quantize_/common/packing_format.py
@@ -35,3 +35,8 @@ class PackingFormat(str, Enum):
     marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
     """
     MARLIN_SPARSE = "marlin_sparse"
+
+    """
+    unpacked_to_int8 means the sub-byte quantized data is stored unpacked, one value per int8
+    """
+    UNPACKED_TO_INT8 = "unpacked_to_int8"
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 8441382243..9eeb0e7dc5 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -11,6 +11,9 @@
 from .int4.int4_tensor import (
     Int4Tensor,
 )
+from .intx.intx_unpacked_tensor import (
+    IntxUnpackedTensor,
+)
 
 __all__ = [
     "Int4Tensor",
@@ -18,4 +21,5 @@
     "Int4MarlinSparseTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
+    "IntxUnpackedTensor",
 ]
diff --git a/torchao/quantization/quantize_/workflows/intx/__init__.py b/torchao/quantization/quantize_/workflows/intx/__init__.py
new file mode 100644
index 0000000000..c0f1f807a5
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/intx/__init__.py
@@ -0,0 +1,5 @@
+from .intx_unpacked_tensor import IntxUnpackedTensor
+
+__all__ = [
+    "IntxUnpackedTensor",
+]
diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
new file mode 100644
index 0000000000..bd6d08b998
--- /dev/null
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_tensor.py
@@ -0,0 +1,279 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List, Tuple
+
+import torch
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+from torchao.quantization.quant_primitives import (
+    _DTYPE_TO_QVALUE_BOUNDS,
+    MappingType,
+    choose_qparams_affine,
+    dequantize_affine,
+    quantize_affine,
+)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TorchAOBaseTensor,
+    fill_defaults,
+)
+
+__all__ = [
+    "IntxUnpackedTensor",
+]
+
+aten = torch.ops.aten
+
+_FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]
+
+
+class IntxUnpackedTensor(TorchAOBaseTensor):
+    """
+    intx quantization with unpacked format. Sub-byte quantized data is represented as int8.
+    The range of the quantized values is restricted to the quant_min and quant_max of the target_dtype, e.g.,
+    if target_dtype=torch.int4, qdata will be an int8 tensor with values in [-8, 7].
+    Quantization is represented in a decomposed way.
+    This format is intended for torch.export use cases.
+
+    Tensor Attributes:
+        qdata: int data for quantization.
+               dtype is int8, but the range of the qdata is determined by target_dtype
+               Shape is the same as original Tensor: (n, k) for 2D tensor
+        scale: block scales for quantization
+               dtype is the same as the original Tensor dtype.
+               Shape is (n // block_size[0], k // block_size[1]) for 2D tensor
+        zero_point: block zero points for quantization
+                    dtype is the same as the original Tensor dtype or int8
+                    Shape is (n // block_size[0], k // block_size[1]) for 2D tensor
+
+    Non-Tensor Attributes:
+        target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
+        block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
+        dtype: the dtype of the dequantized Tensor
+    """
+
+    tensor_data_names = ["qdata", "scale", "zero_point"]
+    tensor_attribute_names = ["target_dtype", "block_size", "dtype"]
+
+    def __new__(cls, qdata, scale, zero_point, target_dtype, block_size, dtype):
+        kwargs = {}
+        kwargs["device"] = qdata.device
+        kwargs["dtype"] = dtype
+        kwargs["requires_grad"] = False
+        shape = qdata.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        qdata,
+        scale,
+        zero_point,
+        target_dtype,
+        block_size,
+        dtype,
+    ):
+        assert qdata.dtype == torch.int8, (
+            f"qdata dtype must be int8, but got {qdata.dtype}"
+        )
+        assert scale.dtype in _FLOAT_TYPES, (
+            f"scale dtype must be one of {_FLOAT_TYPES}, but got {scale.dtype}"
+        )
+        assert zero_point.dtype in _FLOAT_TYPES or zero_point.dtype == torch.int8, (
+            f"zero_point dtype must be {torch.int8} or one of {_FLOAT_TYPES}, but got {zero_point.dtype}"
+        )
+
+        assert target_dtype in [
+            getattr(torch, f"int{bit_width}") for bit_width in range(1, 9)
+        ]
+
+        assert len(block_size) == qdata.ndim
+        n_blocks = []
+        for i in range(len(block_size)):
+            assert qdata.shape[i] % block_size[i] == 0
+            n_blocks.append(qdata.shape[i] // block_size[i])
+        scale = scale.reshape(*n_blocks)
+        zero_point = zero_point.reshape(*n_blocks)
+
+        assert dtype in _FLOAT_TYPES, (
+            f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"
+        )
+
+        self.qdata = qdata
+        self.scale = scale
+        self.zero_point = zero_point
+
+        self.target_dtype = target_dtype
+        self.block_size = block_size
+
+    def _quantization_type(self):
+        return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}"
+
+    def _has_float_zero_point(self) -> bool:
+        return self.zero_point.dtype in _FLOAT_TYPES
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs.pop("device")
+        dtype = kwargs.pop("dtype")
+        assert dtype in _FLOAT_TYPES
+        return IntxUnpackedTensor(
+            self.qdata.to(device),
+            self.scale.to(device=device, dtype=dtype),
+            self.zero_point.to(device=device, dtype=dtype)
+            if self._has_float_zero_point()
+            else self.zero_point.to(device),
+            self.target_dtype,
+            self.block_size,
+            dtype,
+        )
+
+    @classmethod
+    def from_hp(
+        cls,
+        hp_tensor: torch.Tensor,
+        block_size: Tuple[int],
+        target_dtype: torch.dtype,
+        *,
+        mapping_type: MappingType = MappingType.SYMMETRIC,
+    ):
+        """
+        Create an IntxUnpackedTensor from a high-precision tensor
+        """
+        qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
+        scale, zero_point = choose_qparams_affine(
+            hp_tensor,
+            mapping_type,
+            block_size,
+            target_dtype=torch.int8,
+            quant_min=qmin,
+            quant_max=qmax,
+        )
+        if zero_point.dtype == torch.int32:
+            int8_min, int8_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
+            assert zero_point.min().item() >= int8_min
+            assert zero_point.max().item() <= int8_max
+            zero_point = zero_point.to(torch.int8)
+        qdata = quantize_affine(
+            hp_tensor,
+            block_size,
+ scale, + zero_point, + output_dtype=torch.int8, + quant_min=qmin, + quant_max=qmax, + ) + return IntxUnpackedTensor( + qdata=qdata, + scale=scale, + zero_point=zero_point, + target_dtype=target_dtype, + block_size=block_size, + dtype=hp_tensor.dtype, + ) + + def dequantize(self): + qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[self.target_dtype] + return dequantize_affine( + self.qdata, + self.block_size, + self.scale, + self.zero_point, + torch.int8, + qmin, + qmax, + output_dtype=self.dtype, + ) + + +implements = IntxUnpackedTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if isinstance(input_tensor, IntxUnpackedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, IntxUnpackedTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements([torch.nn.functional.embedding, aten.embedding.default]) +def _(func, types, args, kwargs): + assert len(args) == 2 + indices, weight_tensor = ( + args[0], + args[1], + ) + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.embedding(indices, weight_tensor, **kwargs) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + + # Slicing must be compatible with the block size to make sense on the quantized tensor + # In particular both start and end must be a multiple of block_size[dim] + # Otherwise the sliced tensor cannot be represented as a IntxUnpackedTensor + # For example, if block_size = 4, we might have: + # + # qdata: i i i i | i i i i + # scale: s s + # + # If we set start = 2 and end = 8, then the qdata slice is: + # + # qdata_slice: i i (i i | i i i i) + # + # But then the block_size for the first two qdata in the slice is 2 + # and remaining blocks have size 4. This cannot be represented + # with the metadata we store in an IntxUnpackedTensor, which requires uniform blocking + + assert start % self.block_size[dim] == 0, ( + f"slice args are incompatible with blocking: start={start} must be divisible by block_size[dim]={self.block_size[dim]}" + ) + start_scale = start // self.block_size[dim] + + assert end % self.block_size[dim] == 0, ( + f"slice args are incompatible with blocking: end={end} must be divisible by block_size[dim]={self.block_size[dim]}" + ) + end_scale = end // self.block_size[dim] + + qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) + scale = aten.slice.Tensor(self.scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor(self.zero_point, dim, start_scale, end_scale, step) + + new_block_size = [] + for i in range(qdata.ndim): + assert scale.shape[i] == zero_point.shape[i] + n_blocks = scale.shape[i] + assert qdata.shape[i] % n_blocks == 0 + new_block_size.append(qdata.shape[i] // n_blocks) + new_block_size = tuple(new_block_size) + + new = IntxUnpackedTensor( + qdata, + scale, + zero_point, + self.target_dtype, + new_block_size, + self.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +IntxUnpackedTensor.__module__ = "torchao.quantization" + +if TORCH_VERSION_AT_LEAST_2_5: + # Allow a model with IntxUnpackedTensor weights to be loaded with `weights_only=True` + torch.serialization.add_safe_globals([IntxUnpackedTensor])