From 3e7e466e39fbbc8fcb611caa9a8722fb9e5e6bfe Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 22 Nov 2023 17:07:54 -0800 Subject: [PATCH 01/17] Adding uint4 dtype implementation Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 292 +++++++++++++++++++++++++++++++++++++ torchao/dtypes/__init__.py | 5 + torchao/dtypes/int4.py | 124 ++++++++++++++++ 3 files changed, 421 insertions(+) create mode 100644 test/dtypes/test_int4.py create mode 100644 torchao/dtypes/__init__.py create mode 100644 torchao/dtypes/int4.py diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py new file mode 100644 index 0000000000..e5a1ff3cfb --- /dev/null +++ b/test/dtypes/test_int4.py @@ -0,0 +1,292 @@ +import torch +from torchao.dtypes.int4 import UInt4Tensor +import unittest +from unittest import TestCase, main +from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer + +from torch._export import capture_pre_autograd_graph +from torch._export import dynamic_dim +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, +) +from torchao.quantization.utils import ( + compute_error, +) +from torchao.quantization.quant_api import ( + replace_with_custom_fn_if_matches_filter, +) +from torch import nn +import copy + +def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scale is the same dtype as the original tensor + scale = torch.clamp(scale, min=eps).to(x.dtype) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x.transpose(0, 1) / scale + x_round = torch.round(x_div) + x_zp = x_round + zero_point + x_zp = x_zp.transpose(0, 1) + quant = torch.clamp(x_zp, quant_min, quant_max) + if target_dtype == "int4": + quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size()) + else: + quant = quant.to(target_dtype) + + return quant, scale, zero_point + +class _WeightOnlyInt4QuantLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + w_int4 = kwargs.pop("w_int4") + scales = kwargs.pop("scales") + super().__init__(*args, **kwargs) + self.w_int4 = w_int4 + self.scales = scales + + def forward(self, x): + # if len(x.shape)<=2: + # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales + # else: # turn x into 2d tensor, then undo it for y + x_view = x.view(-1, x.shape[-1]) + y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales + y = y.reshape(*x.shape[:-1], -1) + if self.bias is not None: + y += self.bias + return y + + @classmethod + def from_float(cls, mod): + w_fp32 = mod.weight + w_int4, scales, _zp = _dynamically_quantize_per_channel_int4( + w_fp32, 0, 15, "int4" + ) + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, + fake_out_features, + bias=mod.bias is not None, + w_int4=w_int4.t().contiguous(), + scales=scales, + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + del new_mod.weight + new_mod.bias = mod.bias + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod + +def _apply_weight_only_int4_quant(model): + replace_with_custom_fn_if_matches_filter( + model, + _WeightOnlyInt4QuantLinear.from_float, + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + +class TestInt4(QuantizationTestCase): + def test_basic_tensor_ops(self): + x = UInt4Tensor(torch.tensor([ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], dtype=torch.uint8)) + self.assertTrue(x.shape, (3, 8)) + # making sure these works + x.to(torch.uint8) + expected = UInt4Tensor(torch.tensor([ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], dtype=torch.uint8)) + self.assertTrue(x[0:1, :] == expected) + expected = UInt4Tensor(torch.tensor([ + [0x23, 0x45], + [0x23, 0x45], + [0x23, 0x45], + ], dtype=torch.uint8)) + self.assertTrue(x[:, 2:6] == expected) + + def test_gpu_quant(self): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 16)) + y_ref = m(x) + _apply_weight_only_int4_quant(m) + y_wo = m(x) + # sqnr = compute_error(y_ref, y_wo) + opt = torch.compile(m, mode="max-autotune") + # make sure it runs + opt(x) + + def test_aten_ir(self): + from torch.library import Library, impl + test_lib = Library("test_int4", "DEF") + test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") + @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") + def quantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, + ) -> torch.Tensor: + inv_scale = 1.0 / scale + return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8) + + test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") + @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") + def dequantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, + ) -> torch.Tensor: + return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale + + # class QuantizePerTensorUInt4(torch.autograd.Function): + # @staticmethod + # def forward( + # ctx, + # input: torch.Tensor, + # scale: float, + # zero_point: int, + # ) -> torch.Tensor: + # inv_scale = 1.0 / scale + # return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8)) + + # class DeQuantizePerTensorUInt4(torch.autograd.Function): + # @staticmethod + # def forward( + # ctx, + # input: torch.Tensor, + # scale: float, + # zero_point: int, + # ) -> torch.Tensor: + # return (input.to(torch.float32) - zero_point) * scale + + class M(torch.nn.Module): + def forward(self, x, y): + return x + y + + example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),) + m = M().eval() + m = capture_pre_autograd_graph(m, example_inputs) + for n in m.graph.nodes: + if n.target == torch.ops.aten.add.Tensor: + with m.graph.inserting_before(n): + q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {}) + dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {}) + n.replace_input_with(n.args[0], dq) + m.recompile() + + # TODO: need more extension points from quant flow side + @unittest.skip("need more extension points from quant flow side") + def test_pt2e_quant(self): + from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + OP_TO_ANNOTATOR, + QuantizationConfig, + ) + + class Int4ActQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + int4_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-2**3, + quant_max=2**3 - 1, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_observer, + ) + int8_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=observer.default_weight_observer, + ) + quantization_config = QuantizationConfig( + input_activation=int8_qspec, + weight=int4_qspec, + bias=None, + output_activation=int8_qspec, + ) + OP_TO_ANNOTATOR["conv"](model, quantization_config) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv(x) + + quantizer = Int4ActQuantizer() + node_occurrence = { + # one for input of the first conv, one for output for the first conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.aten.conv2d.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + example_inputs = (torch.randn(1, 3, 3, 3),) + + # _test_quantizer in PT2EQuantizationTestCase + # resetting dynamo cache + export_with_dynamic_shape = False + torch._dynamo.reset() + m_eager = M().eval() + + # program capture + m = copy.deepcopy(m_eager) + m = capture_pre_autograd_graph( + m, + example_inputs, + constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [], + ) + + m = prepare_pt2e(m, quantizer) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m, fold_quantize=True) + + pt2_quant_output = m(*example_inputs) + node_occurrence = { + ns.call_function(k): v for k, v in expected_node_occurrence.items() + } + if expected_node_list is None: + expected_node_list = [] + node_list = [ns.call_function(n) for n in expected_node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + +if __name__ == "__main__": + main() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py new file mode 100644 index 0000000000..3dc5b7c2eb --- /dev/null +++ b/torchao/dtypes/__init__.py @@ -0,0 +1,5 @@ +from .int4 import UInt4Tensor + +__all__ = [ + "UInt4Tensor" +] diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py new file mode 100644 index 0000000000..f8d8d796e9 --- /dev/null +++ b/torchao/dtypes/int4.py @@ -0,0 +1,124 @@ +import torch +import torch._prims_common as utils + +def down_size(size): + assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" + return (*size[:-1], size[-1] // 2) + +def up_size(size): + return (*size[:-1], size[-1] * 2) + +def fill_defaults(args, n, defaults_tail): + """ + __torch_dispatch__ doesn't guarantee the number of arguments you are + passed (e.g., defaulted arguments are not passed); but usually it is + convenient to pad out the arguments list with defaults. This function + helps you do that. + Args: + args: the list of positional arguments passed to __torch_dispatch__ + n: the number of arguments you are expecting to get + defaults_tail: default values for the arguments, starting from the + end of the list + Example: + >>> fill_defaults([1, 2, 3], 5, [3, 4, 5]) + [1, 2, 3, 4, 5] + >>> fill_defaults([1, 2, 3], 5, [None, None, None]) + [1, 2, 3, None, None]] + """ + if n - len(defaults_tail) > len(args): + raise RuntimeError("not enough defaults to fill arguments") + r = list(args) + for i in range(len(args), n): + r.append(defaults_tail[i - n + len(defaults_tail)]) + return r + +# from +# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 +def unpack_uint4(quantized_data) -> torch.Tensor: + """Get the original weight from the normalized float weight format""" + # since we are using uint8 we will decode 2 entries per byte + # Shift elements down 4 and select out the bottom 4 bits + first_elements = (quantized_data >> 4).to(torch.uint8) + second_elements = (quantized_data & 0b1111).to(torch.uint8) + return torch.stack([first_elements, second_elements], dim=-1) + +def pack_uint4(uint8_data) -> torch.Tensor: + shape = uint8_data.shape + uint8_data = uint8_data.contiguous().view(-1) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) + +class UInt4Tensor(torch.Tensor): + @staticmethod + def __new__(cls, elem): + # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) + # weird shit will happen + assert elem.dtype is torch.uint8 + return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int64) + + def __init__(self, elem): + self.elem = elem + + @classmethod + def from_unpacked(cls, unpacked): + return UInt4Tensor(pack_uint4(unpacked)) + + def tolist(self): + return self.to(torch.uint8).tolist() + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func is torch.ops.aten.view.default: + self, size = args + size = utils.infer_size(size, self.numel()) + assert not kwargs + # WARNING: views not preserved + return UInt4Tensor(self.elem.reshape(down_size(size))) + elif func is torch.ops.aten._to_copy.default: + self, = args + if kwargs == {'dtype': torch.uint8}: + return unpack_uint4(self.elem).view(self.shape) # no wrap + else: + raise NotImplementedError(f"_to_copy {kwargs}") + elif func is torch.ops.aten.unbind.int: + # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to + # create four tensors containing one element each. But we can't + # do this with uint4 because such a tensor's size is not divisible + # by bytes. What I am going to do instead is promote to uint8 + # when this happens + self, dim = fill_defaults(args, 2, [0]) + if dim != self.dim() - 1: + raise NotImplementedError(f"unbind dim={dim}") + else: + # We're unbinding the last dimension, need to promote + return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) + elif func is torch.ops.aten.select.int: + self, dim, index = args + if dim != self.dim() - 1: + return UInt4Tensor(torch.ops.aten.select.int(self.elem, dim, index)) + else: + raise NotImplementedError(f"select dim={dim}") + elif func is torch.ops.aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == self.dim() - 1: + # hard case + if step != 1: + raise NotImplementedError(f"slice step={step}") + assert start % 2 == 0, start + assert end >= self.shape[dim] or end % 2 == 0, end + return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start // 2, end // 2, 1)) + else: + # easy case + return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step)) + elif func is torch.ops.aten.t.default: + self, = args + unpacked = unpack_uint4(self.elem).view(self.shape) + transposed = torch.ops.aten.t.default(unpacked) + transposed_and_packed = pack_uint4(transposed) + return UInt4Tensor(transposed_and_packed) + + raise NotImplementedError(f"{func}") + + def __eq__(self, other): + return torch.equal(self.elem, other.elem) + + __torch_function__ = torch._C._disabled_torch_function_impl From a11039ba78d381b0d91f67837ebaefca3d31f65e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 29 Nov 2023 14:11:43 -0800 Subject: [PATCH 02/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 1 + torchao/dtypes/int4.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index e5a1ff3cfb..04610f4adf 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -72,6 +72,7 @@ def forward(self, x): # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales # else: # turn x into 2d tensor, then undo it for y x_view = x.view(-1, x.shape[-1]) + mm_res = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales y = y.reshape(*x.shape[:-1], -1) if self.bias is not None: diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index f8d8d796e9..6529280cd5 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -65,6 +65,18 @@ def from_unpacked(cls, unpacked): def tolist(self): return self.to(torch.uint8).tolist() + def __tensor_flatten__(self): + return ["elem"], None + + @staticmethod + def __tensor_unflatten__(flattened, meta): + assert meta is None + elem = flattened["elem"] + return UInt4Tensor(elem) + + def __hash__(self): + return hash(self.elem) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func is torch.ops.aten.view.default: @@ -115,6 +127,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): transposed = torch.ops.aten.t.default(unpacked) transposed_and_packed = pack_uint4(transposed) return UInt4Tensor(transposed_and_packed) + elif func is torch.ops.aten.as_strided.default: + self, size, stride, storage_offset = args + size = up_size(size) + new_stride = [] + for s in stride: + if s != 1: + # since two int4 equals to 1 byte + new_stride.append(s // 2) + else: + # wondering if we need to have 1/2 stride? + new_stride.append(s) + stride = new_stride + return UInt4Tensor(torch.ops.aten.as_strided.default(self.elem, size, stride, storage_offset)) raise NotImplementedError(f"{func}") From 1fc4660163d548ea36fc7d17ea11e99466b21d8d Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 29 Nov 2023 14:21:56 -0800 Subject: [PATCH 03/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index 04610f4adf..e5a1ff3cfb 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -72,7 +72,6 @@ def forward(self, x): # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales # else: # turn x into 2d tensor, then undo it for y x_view = x.view(-1, x.shape[-1]) - mm_res = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales y = y.reshape(*x.shape[:-1], -1) if self.bias is not None: From 19f3d0d63f9390a063e5756ba85d54c67d0d3afc Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 7 Dec 2023 08:33:56 -0800 Subject: [PATCH 04/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torchao/dtypes/int4.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 6529280cd5..542e01443d 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -1,6 +1,9 @@ import torch import torch._prims_common as utils +# TODO: uint8 --> bits8 +# TODO: adding support for pt2e quant, after https://github.com/pytorch/pytorch/pull/115001 is landed + def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" return (*size[:-1], size[-1] // 2) @@ -128,17 +131,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): transposed_and_packed = pack_uint4(transposed) return UInt4Tensor(transposed_and_packed) elif func is torch.ops.aten.as_strided.default: + # size, stride, storage_offset are referring to tensor elements, not physical bytes self, size, stride, storage_offset = args - size = up_size(size) + size = down_size(size) + new_stride = [] for s in stride: if s != 1: - # since two int4 equals to 1 byte + # since two int4 equals to 1 bits8 new_stride.append(s // 2) else: - # wondering if we need to have 1/2 stride? new_stride.append(s) stride = new_stride + + storage_offset //= 2 return UInt4Tensor(torch.ops.aten.as_strided.default(self.elem, size, stride, storage_offset)) raise NotImplementedError(f"{func}") From 734c7051621498886c2f65fff0526901db8757ba Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 13 Dec 2023 13:23:52 -0800 Subject: [PATCH 05/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 91 ++++++++++++++++++++++++++-------------- torchao/dtypes/int4.py | 5 ++- 2 files changed, 64 insertions(+), 32 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index e5a1ff3cfb..903ceaf79a 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -17,6 +17,7 @@ from torchao.quantization.quant_api import ( replace_with_custom_fn_if_matches_filter, ) +from torch.ao.quantization.observer import ObserverBase from torch import nn import copy @@ -108,6 +109,30 @@ def _apply_weight_only_int4_quant(model): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) +from torch.library import Library, impl + +test_lib = Library("test_int4", "DEF") +test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") + +@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") +def quantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, +) -> torch.Tensor: + inv_scale = 1.0 / scale + return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8) + +test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") +@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") +def dequantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, +) -> torch.Tensor: + return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale + + class TestInt4(QuantizationTestCase): def test_basic_tensor_ops(self): x = UInt4Tensor(torch.tensor([ @@ -142,27 +167,6 @@ def test_gpu_quant(self): opt(x) def test_aten_ir(self): - from torch.library import Library, impl - test_lib = Library("test_int4", "DEF") - test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") - @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") - def quantize_per_tensor_int4( - input: torch.Tensor, - scale: float, - zero_point: int, - ) -> torch.Tensor: - inv_scale = 1.0 / scale - return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8) - - test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") - @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") - def dequantize_per_tensor_int4( - input: torch.Tensor, - scale: float, - zero_point: int, - ) -> torch.Tensor: - return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale - # class QuantizePerTensorUInt4(torch.autograd.Function): # @staticmethod # def forward( @@ -199,23 +203,46 @@ def forward(self, x, y): n.replace_input_with(n.args[0], dq) m.recompile() - # TODO: need more extension points from quant flow side - @unittest.skip("need more extension points from quant flow side") def test_pt2e_quant(self): from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( OP_TO_ANNOTATOR, QuantizationConfig, ) + class int4_class(): + pass + + torch.int4 = int4_class() + + class Int4Observer(ObserverBase): + def __init__(self, *args, **kwargs): + # just faking a dtype here + # TODO: make flow work with new dtypes + super().__init__(dtype=torch.int8) - class Int4ActQuantizer(Quantizer): + def forward(self, x): + return x + + def calculate_qparams(self, **kwargs): + pass + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + with model.graph.inserting_before(observer_node): + q_node = model.graph.call_function( + torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) + dq_node = model.graph.call_function( + torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {}) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + + class Int4WeightQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: int4_qspec = QuantizationSpec( - dtype=torch.int8, + dtype=torch.int4, quant_min=-2**3, quant_max=2**3 - 1, qscheme=torch.per_tensor_affine, is_dynamic=False, - observer_or_fake_quant_ctr=observer.default_observer, + observer_or_fake_quant_ctr=Int4Observer, ) int8_qspec = QuantizationSpec( dtype=torch.int8, @@ -244,15 +271,18 @@ def __init__(self): def forward(self, x): return self.conv(x) - quantizer = Int4ActQuantizer() + quantizer = Int4WeightQuantizer() node_occurrence = { - # one for input of the first conv, one for output for the first conv + # for weight + torch.ops.test_int4.quantize_per_tensor_int4: 1, + torch.ops.test_int4.dequantize_per_tensor_int4: 1, + # for activation torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, } node_list = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.test_int4.dequantize_per_tensor_int4, torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ] @@ -269,7 +299,6 @@ def forward(self, x): m = capture_pre_autograd_graph( m, example_inputs, - constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [], ) m = prepare_pt2e(m, quantizer) diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 542e01443d..0968f4fe17 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -1,8 +1,11 @@ import torch import torch._prims_common as utils +# TODO: fix error from symbolic_context +# TODO: adding support for pt2e quant +# module swap --> subclass (for it to be composable with distributed, sparsity etc. subclasses) # TODO: uint8 --> bits8 -# TODO: adding support for pt2e quant, after https://github.com/pytorch/pytorch/pull/115001 is landed + def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" From a9d6cdc71cbbd1c6d8fcf3ef71f8e2b9b8937e4b Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 18 Dec 2023 09:32:14 -0800 Subject: [PATCH 06/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 3 ++- torchao/dtypes/int4.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index 903ceaf79a..dcf7e2c9fd 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -91,7 +91,8 @@ def from_float(cls, mod): fake_in_features, fake_out_features, bias=mod.bias is not None, - w_int4=w_int4.t().contiguous(), + # w_int4=w_int4.t().contiguous(), + w_int4=torch.ops.aten.transpose_copy(w_int4, 0, 1), scales=scales, ) new_mod.in_features = mod.in_features diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 0968f4fe17..83b1c54390 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -133,6 +133,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): transposed = torch.ops.aten.t.default(unpacked) transposed_and_packed = pack_uint4(transposed) return UInt4Tensor(transposed_and_packed) + elif func is torch.ops.aten.transpose_copy.int: + self, dim0, dim1 = args + unpacked = unpack_uint4(self.elem).view(self.shape) + transposed = torch.ops.aten.transpose_copy.int(unpacked, dim0, dim1) + transposed_and_packed = pack_uint4(transposed) + return UInt4Tensor(transposed_and_packed) elif func is torch.ops.aten.as_strided.default: # size, stride, storage_offset are referring to tensor elements, not physical bytes self, size, stride, storage_offset = args From 1b51eb5c84ae3d52a5d8021185ca6e83de6c330a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 18 Dec 2023 09:36:11 -0800 Subject: [PATCH 07/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index dcf7e2c9fd..0d31353caf 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -141,7 +141,7 @@ def test_basic_tensor_ops(self): [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], ], dtype=torch.uint8)) - self.assertTrue(x.shape, (3, 8)) + self.assertEqual(x.shape, (3, 16)) # making sure these works x.to(torch.uint8) expected = UInt4Tensor(torch.tensor([ From b5ce8c669bfd0663de5dccc7119347b685ba4afd Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 18 Dec 2023 16:18:13 -0800 Subject: [PATCH 08/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 80 +++++++++++++++++++++++++++++++++------- torchao/dtypes/int4.py | 33 ++++++++--------- 2 files changed, 82 insertions(+), 31 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index 0d31353caf..4a1e03581d 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -19,6 +19,13 @@ ) from torch.ao.quantization.observer import ObserverBase from torch import nn +from torch.fx import ( + Node, + GraphModule, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, +) import copy def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype): @@ -54,7 +61,7 @@ def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype x_zp = x_zp.transpose(0, 1) quant = torch.clamp(x_zp, quant_min, quant_max) if target_dtype == "int4": - quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size()) + quant = UInt4Tensor.from_unpacked(quant.view(torch.bits8)).view(quant.size()) else: quant = quant.to(target_dtype) @@ -131,7 +138,17 @@ def dequantize_per_tensor_int4( scale: float, zero_point: int, ) -> torch.Tensor: - return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale + print("1") + a = input.to(torch.uint8) + print("2") + a = a.to(torch.float32) + print("3") + a = a - zero_point + print("4") + a = a * scale + print("5") + return a + # return (input.to(torch.uint8).to(torch.float32) - zero_point) * scale class TestInt4(QuantizationTestCase): @@ -140,19 +157,19 @@ def test_basic_tensor_ops(self): [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + ], dtype=torch.bits8)) self.assertEqual(x.shape, (3, 16)) # making sure these works x.to(torch.uint8) expected = UInt4Tensor(torch.tensor([ [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + ], dtype=torch.bits8)) self.assertTrue(x[0:1, :] == expected) expected = UInt4Tensor(torch.tensor([ [0x23, 0x45], [0x23, 0x45], [0x23, 0x45], - ], dtype=torch.uint8)) + ], dtype=torch.bits8)) self.assertTrue(x[:, 2:6] == expected) def test_gpu_quant(self): @@ -177,7 +194,7 @@ def test_aten_ir(self): # zero_point: int, # ) -> torch.Tensor: # inv_scale = 1.0 / scale - # return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8)) + # return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.bits8)) # class DeQuantizePerTensorUInt4(torch.autograd.Function): # @staticmethod @@ -226,7 +243,7 @@ def forward(self, x): def calculate_qparams(self, **kwargs): pass - def convert(self, model: torch.fx.GraphModule, observer_node: Node): + def convert(self, model: GraphModule, observer_node: Node): with model.graph.inserting_before(observer_node): q_node = model.graph.call_function( torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) @@ -235,6 +252,11 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): observer_node.replace_all_uses_with(dq_node) model.graph.erase_node(observer_node) + from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _is_annotated, + _mark_nodes_as_annotated, + ) + class Int4WeightQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: int4_qspec = QuantizationSpec( @@ -251,7 +273,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: quant_max=127, qscheme=torch.per_tensor_symmetric, is_dynamic=False, - observer_or_fake_quant_ctr=observer.default_weight_observer, + observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer, ) quantization_config = QuantizationConfig( input_activation=int8_qspec, @@ -259,7 +281,39 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: bias=None, output_activation=int8_qspec, ) - OP_TO_ANNOTATOR["conv"](model, quantization_config) + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ]: + continue + conv_node = n + + input_qspec_map = {} + input_act = conv_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation + + weight = conv_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = quantization_config.weight + + partition = [conv_node, conv_node.args[1]] + + bias = conv_node.args[2] if len(conv_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = quantization_config.bias + partition.append(bias) + + if _is_annotated(partition): + continue + + conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) def validate(self, model: torch.fx.GraphModule) -> None: pass @@ -305,15 +359,13 @@ def forward(self, x): m = prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) - m = convert_pt2e(m, fold_quantize=True) + m = convert_pt2e(m, fold_quantize=False) pt2_quant_output = m(*example_inputs) node_occurrence = { - ns.call_function(k): v for k, v in expected_node_occurrence.items() + ns.call_function(k): v for k, v in node_occurrence.items() } - if expected_node_list is None: - expected_node_list = [] - node_list = [ns.call_function(n) for n in expected_node_list] + node_list = [ns.call_function(n) for n in node_list] self.checkGraphModuleNodes( m, expected_node_occurrence=node_occurrence, expected_node_list=node_list ) diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 83b1c54390..8c927b7cc2 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -1,11 +1,8 @@ import torch import torch._prims_common as utils -# TODO: fix error from symbolic_context -# TODO: adding support for pt2e quant -# module swap --> subclass (for it to be composable with distributed, sparsity etc. subclasses) # TODO: uint8 --> bits8 - +# module swap --> subclass (for it to be composable with distributed, sparsity etc. subclasses) def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" @@ -42,23 +39,23 @@ def fill_defaults(args, n, defaults_tail): # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 def unpack_uint4(quantized_data) -> torch.Tensor: """Get the original weight from the normalized float weight format""" - # since we are using uint8 we will decode 2 entries per byte + # since we are using bits8 we will decode 2 entries per byte # Shift elements down 4 and select out the bottom 4 bits - first_elements = (quantized_data >> 4).to(torch.uint8) - second_elements = (quantized_data & 0b1111).to(torch.uint8) + first_elements = (quantized_data >> 4).to(torch.bits8) + second_elements = (quantized_data & 0b1111).to(torch.bits8) return torch.stack([first_elements, second_elements], dim=-1) -def pack_uint4(uint8_data) -> torch.Tensor: - shape = uint8_data.shape - uint8_data = uint8_data.contiguous().view(-1) - return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) +def pack_uint4(bits8_data) -> torch.Tensor: + shape = bits8_data.shape + bits8_data = bits8_data.contiguous().view(-1) + return (bits8_data[::2] << 4 | bits8_data[1::2]).view(down_size(shape)) class UInt4Tensor(torch.Tensor): @staticmethod def __new__(cls, elem): # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) # weird shit will happen - assert elem.dtype is torch.uint8 + assert elem.dtype is torch.bits8 return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int64) def __init__(self, elem): @@ -69,13 +66,13 @@ def from_unpacked(cls, unpacked): return UInt4Tensor(pack_uint4(unpacked)) def tolist(self): - return self.to(torch.uint8).tolist() + return self.to(torch.bits8).tolist() def __tensor_flatten__(self): return ["elem"], None @staticmethod - def __tensor_unflatten__(flattened, meta): + def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): assert meta is None elem = flattened["elem"] return UInt4Tensor(elem) @@ -93,7 +90,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return UInt4Tensor(self.elem.reshape(down_size(size))) elif func is torch.ops.aten._to_copy.default: self, = args - if kwargs == {'dtype': torch.uint8}: + if kwargs == {'dtype': torch.bits8}: + print("_to_copy", args) return unpack_uint4(self.elem).view(self.shape) # no wrap else: raise NotImplementedError(f"_to_copy {kwargs}") @@ -101,14 +99,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to # create four tensors containing one element each. But we can't # do this with uint4 because such a tensor's size is not divisible - # by bytes. What I am going to do instead is promote to uint8 + # by bytes. What I am going to do instead is promote to bits8 # when this happens self, dim = fill_defaults(args, 2, [0]) if dim != self.dim() - 1: raise NotImplementedError(f"unbind dim={dim}") else: # We're unbinding the last dimension, need to promote - return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) + return torch.ops.aten._to_copy.default(self, dtype=torch.bits8).unbind(dim) elif func is torch.ops.aten.select.int: self, dim, index = args if dim != self.dim() - 1: @@ -128,6 +126,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # easy case return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step)) elif func is torch.ops.aten.t.default: + assert False, "transpose is not properly implemented currently" self, = args unpacked = unpack_uint4(self.elem).view(self.shape) transposed = torch.ops.aten.t.default(unpacked) From bb483f3ec716e54d406a6cac2f357d080dbba51f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 19 Dec 2023 11:23:13 -0800 Subject: [PATCH 09/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 2 +- torchao/dtypes/int4.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index 4a1e03581d..d0cfd63c96 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -138,7 +138,7 @@ def dequantize_per_tensor_int4( scale: float, zero_point: int, ) -> torch.Tensor: - print("1") + print("1", input.dtype) a = input.to(torch.uint8) print("2") a = a.to(torch.float32) diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 8c927b7cc2..3fd10bfe02 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -2,6 +2,7 @@ import torch._prims_common as utils # TODO: uint8 --> bits8 +# TODO: change int4_tensor.dtype to return torch.int4 # module swap --> subclass (for it to be composable with distributed, sparsity etc. subclasses) def down_size(size): @@ -82,6 +83,7 @@ def __hash__(self): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): + print(f"func called: {func}") if func is torch.ops.aten.view.default: self, size = args size = utils.infer_size(size, self.numel()) @@ -89,9 +91,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # WARNING: views not preserved return UInt4Tensor(self.elem.reshape(down_size(size))) elif func is torch.ops.aten._to_copy.default: + print("_to_copy:", args) self, = args if kwargs == {'dtype': torch.bits8}: - print("_to_copy", args) return unpack_uint4(self.elem).view(self.shape) # no wrap else: raise NotImplementedError(f"_to_copy {kwargs}") From 9c9084ee803912604ee514d27068643562e3b44d Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 19 Dec 2023 15:28:15 -0800 Subject: [PATCH 10/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 56 ++++++++++++++++++++++------------------ torchao/dtypes/int4.py | 42 ++++++++++++++++++------------ 2 files changed, 57 insertions(+), 41 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index d0cfd63c96..c0f5cd7eea 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -61,7 +61,7 @@ def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype x_zp = x_zp.transpose(0, 1) quant = torch.clamp(x_zp, quant_min, quant_max) if target_dtype == "int4": - quant = UInt4Tensor.from_unpacked(quant.view(torch.bits8)).view(quant.size()) + quant = UInt4Tensor.from_unpacked(quant.view(torch.uint8)).view(quant.size()) else: quant = quant.to(target_dtype) @@ -129,7 +129,7 @@ def quantize_per_tensor_int4( zero_point: int, ) -> torch.Tensor: inv_scale = 1.0 / scale - return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8) + return UInt4Tensor.from_unpacked(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.uint8)) test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") @@ -138,18 +138,23 @@ def dequantize_per_tensor_int4( scale: float, zero_point: int, ) -> torch.Tensor: - print("1", input.dtype) - a = input.to(torch.uint8) - print("2") - a = a.to(torch.float32) - print("3") - a = a - zero_point - print("4") - a = a * scale - print("5") - return a - # return (input.to(torch.uint8).to(torch.float32) - zero_point) * scale + return (input.to(torch.uint8).to(torch.float32) - zero_point) * scale +@impl(test_lib, "quantize_per_tensor_int4", "Meta") +def quantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, +) -> torch.Tensor: + return torch.empty_like(input, dtype=torch.uint8) + +@impl(test_lib, "dequantize_per_tensor_int4", "Meta") +def dequantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, +) -> torch.Tensor: + return torch.empty_like(input, dtype=torch.float32) class TestInt4(QuantizationTestCase): def test_basic_tensor_ops(self): @@ -157,19 +162,21 @@ def test_basic_tensor_ops(self): [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.bits8)) + ], dtype=torch.uint8)) self.assertEqual(x.shape, (3, 16)) + # TODO: make sure this returns torch.int4 + print("dtype:", x.dtype) # making sure these works x.to(torch.uint8) expected = UInt4Tensor(torch.tensor([ [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.bits8)) + ], dtype=torch.uint8)) self.assertTrue(x[0:1, :] == expected) expected = UInt4Tensor(torch.tensor([ [0x23, 0x45], [0x23, 0x45], [0x23, 0x45], - ], dtype=torch.bits8)) + ], dtype=torch.uint8)) self.assertTrue(x[:, 2:6] == expected) def test_gpu_quant(self): @@ -194,7 +201,7 @@ def test_aten_ir(self): # zero_point: int, # ) -> torch.Tensor: # inv_scale = 1.0 / scale - # return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.bits8)) + # return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8)) # class DeQuantizePerTensorUInt4(torch.autograd.Function): # @staticmethod @@ -210,7 +217,7 @@ class M(torch.nn.Module): def forward(self, x, y): return x + y - example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),) + example_inputs = (torch.randn(1, 2, 10, 10), torch.randn(1, 2, 10, 10),) m = M().eval() m = capture_pre_autograd_graph(m, example_inputs) for n in m.graph.nodes: @@ -226,11 +233,6 @@ def test_pt2e_quant(self): OP_TO_ANNOTATOR, QuantizationConfig, ) - class int4_class(): - pass - - torch.int4 = int4_class() - class Int4Observer(ObserverBase): def __init__(self, *args, **kwargs): # just faking a dtype here @@ -321,7 +323,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: class M(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(3, 3, 3) + self.conv = torch.nn.Conv2d(4, 4, 4) def forward(self, x): return self.conv(x) @@ -341,13 +343,14 @@ def forward(self, x): torch.ops.aten.conv2d.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ] - example_inputs = (torch.randn(1, 3, 3, 3),) + example_inputs = (torch.randn(1, 4, 8, 8),) # _test_quantizer in PT2EQuantizationTestCase # resetting dynamo cache export_with_dynamic_shape = False torch._dynamo.reset() m_eager = M().eval() + _ = torch._export.export(m_eager, example_inputs) # program capture m = copy.deepcopy(m_eager) @@ -360,8 +363,11 @@ def forward(self, x): # Calibrate m(*example_inputs) m = convert_pt2e(m, fold_quantize=False) + m = torch._export.export(m, example_inputs) + print("m:", m) pt2_quant_output = m(*example_inputs) + print("output:", pt2_quant_output) node_occurrence = { ns.call_function(k): v for k, v in node_occurrence.items() } diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 3fd10bfe02..a3465e67ac 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -1,10 +1,16 @@ import torch import torch._prims_common as utils -# TODO: uint8 --> bits8 +# all test pass +# TODO: uint8 --> bits8 (currently blocked on bits8 supporting copy_ to long, or support >> operation # TODO: change int4_tensor.dtype to return torch.int4 # module swap --> subclass (for it to be composable with distributed, sparsity etc. subclasses) +# class int4_class(torch.dtype): +# pass + +torch.int4 = torch.dtype("int4") + def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" return (*size[:-1], size[-1] // 2) @@ -40,24 +46,28 @@ def fill_defaults(args, n, defaults_tail): # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 def unpack_uint4(quantized_data) -> torch.Tensor: """Get the original weight from the normalized float weight format""" - # since we are using bits8 we will decode 2 entries per byte + # since we are using uint8 we will decode 2 entries per byte # Shift elements down 4 and select out the bottom 4 bits - first_elements = (quantized_data >> 4).to(torch.bits8) - second_elements = (quantized_data & 0b1111).to(torch.bits8) + first_elements = (quantized_data >> 4).to(torch.uint8) + second_elements = (quantized_data & 0b1111).to(torch.uint8) return torch.stack([first_elements, second_elements], dim=-1) -def pack_uint4(bits8_data) -> torch.Tensor: - shape = bits8_data.shape - bits8_data = bits8_data.contiguous().view(-1) - return (bits8_data[::2] << 4 | bits8_data[1::2]).view(down_size(shape)) +def pack_uint4(uint8_data) -> torch.Tensor: + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + print("size 1:", (uint8_data[::2] << 4).size()) + print("size 2:", (uint8_data[1::2] << 4).size()) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) class UInt4Tensor(torch.Tensor): @staticmethod def __new__(cls, elem): # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) # weird shit will happen - assert elem.dtype is torch.bits8 - return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int64) + assert elem.dtype is torch.uint8 + # TODO: right now tensor.dtype still displays bits8 + return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int4) def __init__(self, elem): self.elem = elem @@ -67,7 +77,7 @@ def from_unpacked(cls, unpacked): return UInt4Tensor(pack_uint4(unpacked)) def tolist(self): - return self.to(torch.bits8).tolist() + return self.to(torch.uint8).tolist() def __tensor_flatten__(self): return ["elem"], None @@ -91,9 +101,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # WARNING: views not preserved return UInt4Tensor(self.elem.reshape(down_size(size))) elif func is torch.ops.aten._to_copy.default: - print("_to_copy:", args) + # print("_to_copy:", args) self, = args - if kwargs == {'dtype': torch.bits8}: + if kwargs == {'dtype': torch.uint8}: return unpack_uint4(self.elem).view(self.shape) # no wrap else: raise NotImplementedError(f"_to_copy {kwargs}") @@ -101,14 +111,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to # create four tensors containing one element each. But we can't # do this with uint4 because such a tensor's size is not divisible - # by bytes. What I am going to do instead is promote to bits8 + # by bytes. What I am going to do instead is promote to uint8 # when this happens self, dim = fill_defaults(args, 2, [0]) if dim != self.dim() - 1: raise NotImplementedError(f"unbind dim={dim}") else: # We're unbinding the last dimension, need to promote - return torch.ops.aten._to_copy.default(self, dtype=torch.bits8).unbind(dim) + return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) elif func is torch.ops.aten.select.int: self, dim, index = args if dim != self.dim() - 1: @@ -148,7 +158,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): new_stride = [] for s in stride: if s != 1: - # since two int4 equals to 1 bits8 + # since two int4 equals to 1 uint8 new_stride.append(s // 2) else: new_stride.append(s) From dfe122d22bc4757799366c7c8d328a96865e4d26 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 20 Dec 2023 22:47:54 -0800 Subject: [PATCH 11/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 214 ++++++--------------------------------- torchao/dtypes/int4.py | 202 +++++++++++++++++++++++++++++++----- 2 files changed, 203 insertions(+), 213 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index c0f5cd7eea..3fce6ff921 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -1,5 +1,8 @@ import torch -from torchao.dtypes.int4 import UInt4Tensor +from torchao.dtypes.int4 import ( + UInt4Tensor, + PerChannelSymmetricWeightUInt4Tensor, +) import unittest from unittest import TestCase, main from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e @@ -28,156 +31,39 @@ ) import copy -def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype): - # assumes symmetric quantization - # assumes axis == 0 - # assumes dense memory format - # TODO(future): relax ^ as needed - - # default setup for affine quantization of activations - eps = torch.finfo(torch.float32).eps - - # get min and max - min_val, max_val = torch.aminmax(x, dim=1) - - # calculate scale and zero point based on min and max - # reference: https://fburl.com/code/srbiybme - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - device = min_val_neg.device - - # reference: https://fburl.com/code/4wll53rk - max_val_pos = torch.max(-min_val_neg, max_val_pos) - scale = max_val_pos / (float(quant_max - quant_min) / 2) - # ensure scale is the same dtype as the original tensor - scale = torch.clamp(scale, min=eps).to(x.dtype) - zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) - - # quantize based on qmin/qmax/scale/zp - # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 - x_div = x.transpose(0, 1) / scale - x_round = torch.round(x_div) - x_zp = x_round + zero_point - x_zp = x_zp.transpose(0, 1) - quant = torch.clamp(x_zp, quant_min, quant_max) - if target_dtype == "int4": - quant = UInt4Tensor.from_unpacked(quant.view(torch.uint8)).view(quant.size()) - else: - quant = quant.to(target_dtype) - - return quant, scale, zero_point - -class _WeightOnlyInt4QuantLinear(torch.nn.Linear): - def __init__(self, *args, **kwargs): - w_int4 = kwargs.pop("w_int4") - scales = kwargs.pop("scales") - super().__init__(*args, **kwargs) - self.w_int4 = w_int4 - self.scales = scales - - def forward(self, x): - # if len(x.shape)<=2: - # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales - # else: # turn x into 2d tensor, then undo it for y - x_view = x.view(-1, x.shape[-1]) - y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales - y = y.reshape(*x.shape[:-1], -1) - if self.bias is not None: - y += self.bias - return y - - @classmethod - def from_float(cls, mod): - w_fp32 = mod.weight - w_int4, scales, _zp = _dynamically_quantize_per_channel_int4( - w_fp32, 0, 15, "int4" - ) - # create the new module with a toy size to ensure initialization is fast - fake_in_features, fake_out_features = 8, 8 - new_mod = cls( - fake_in_features, - fake_out_features, - bias=mod.bias is not None, - # w_int4=w_int4.t().contiguous(), - w_int4=torch.ops.aten.transpose_copy(w_int4, 0, 1), - scales=scales, - ) - new_mod.in_features = mod.in_features - new_mod.out_features = mod.out_features - del new_mod.weight - new_mod.bias = mod.bias - device_to_use = next(mod.parameters()).device - new_mod.to(device_to_use) - return new_mod - def _apply_weight_only_int4_quant(model): + def fn(mod): + mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) + return mod + replace_with_custom_fn_if_matches_filter( model, - _WeightOnlyInt4QuantLinear.from_float, + lambda mod: fn(mod), lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -from torch.library import Library, impl - -test_lib = Library("test_int4", "DEF") -test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") - -@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") -def quantize_per_tensor_int4( - input: torch.Tensor, - scale: float, - zero_point: int, -) -> torch.Tensor: - inv_scale = 1.0 / scale - return UInt4Tensor.from_unpacked(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.uint8)) - -test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") -@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") -def dequantize_per_tensor_int4( - input: torch.Tensor, - scale: float, - zero_point: int, -) -> torch.Tensor: - return (input.to(torch.uint8).to(torch.float32) - zero_point) * scale - -@impl(test_lib, "quantize_per_tensor_int4", "Meta") -def quantize_per_tensor_int4( - input: torch.Tensor, - scale: float, - zero_point: int, -) -> torch.Tensor: - return torch.empty_like(input, dtype=torch.uint8) - -@impl(test_lib, "dequantize_per_tensor_int4", "Meta") -def dequantize_per_tensor_int4( - input: torch.Tensor, - scale: float, - zero_point: int, -) -> torch.Tensor: - return torch.empty_like(input, dtype=torch.float32) - class TestInt4(QuantizationTestCase): def test_basic_tensor_ops(self): x = UInt4Tensor(torch.tensor([ [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) + ], dtype=torch.uint8).view(torch.bits8)) self.assertEqual(x.shape, (3, 16)) # TODO: make sure this returns torch.int4 - print("dtype:", x.dtype) + self.assertEqual(x.dtype, torch.bits8) # making sure these works x.to(torch.uint8) expected = UInt4Tensor(torch.tensor([ [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) - self.assertTrue(x[0:1, :] == expected) + ], dtype=torch.uint8).view(torch.bits8)) + self.assertEqual(x[0:1, :], expected) expected = UInt4Tensor(torch.tensor([ [0x23, 0x45], [0x23, 0x45], [0x23, 0x45], - ], dtype=torch.uint8)) - self.assertTrue(x[:, 2:6] == expected) + ], dtype=torch.uint8).view(torch.bits8)) + self.assertEqual(x[:, 2:6], expected) def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: @@ -191,43 +77,6 @@ def test_gpu_quant(self): # make sure it runs opt(x) - def test_aten_ir(self): - # class QuantizePerTensorUInt4(torch.autograd.Function): - # @staticmethod - # def forward( - # ctx, - # input: torch.Tensor, - # scale: float, - # zero_point: int, - # ) -> torch.Tensor: - # inv_scale = 1.0 / scale - # return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8)) - - # class DeQuantizePerTensorUInt4(torch.autograd.Function): - # @staticmethod - # def forward( - # ctx, - # input: torch.Tensor, - # scale: float, - # zero_point: int, - # ) -> torch.Tensor: - # return (input.to(torch.float32) - zero_point) * scale - - class M(torch.nn.Module): - def forward(self, x, y): - return x + y - - example_inputs = (torch.randn(1, 2, 10, 10), torch.randn(1, 2, 10, 10),) - m = M().eval() - m = capture_pre_autograd_graph(m, example_inputs) - for n in m.graph.nodes: - if n.target == torch.ops.aten.add.Tensor: - with m.graph.inserting_before(n): - q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {}) - dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {}) - n.replace_input_with(n.args[0], dq) - m.recompile() - def test_pt2e_quant(self): from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( OP_TO_ANNOTATOR, @@ -259,7 +108,7 @@ def convert(self, model: GraphModule, observer_node: Node): _mark_nodes_as_annotated, ) - class Int4WeightQuantizer(Quantizer): + class Int8ActInt4WeightQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: int4_qspec = QuantizationSpec( dtype=torch.int4, @@ -285,24 +134,23 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ) for n in model.graph.nodes: if n.op != "call_function" or n.target not in [ - torch.ops.aten.conv1d.default, - torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, ]: continue - conv_node = n + linear_node = n input_qspec_map = {} - input_act = conv_node.args[0] + input_act = linear_node.args[0] assert isinstance(input_act, Node) input_qspec_map[input_act] = quantization_config.input_activation - weight = conv_node.args[1] + weight = linear_node.args[1] assert isinstance(weight, Node) input_qspec_map[weight] = quantization_config.weight - partition = [conv_node, conv_node.args[1]] + partition = [linear_node, linear_node.args[1]] - bias = conv_node.args[2] if len(conv_node.args) > 2 else None + bias = linear_node.args[2] if len(linear_node.args) > 2 else None if isinstance(bias, Node): input_qspec_map[bias] = quantization_config.bias partition.append(bias) @@ -310,7 +158,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: if _is_annotated(partition): continue - conv_node.meta["quantization_annotation"] = QuantizationAnnotation( + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map=input_qspec_map, output_qspec=quantization_config.output_activation, _annotated=True, @@ -323,12 +171,12 @@ def validate(self, model: torch.fx.GraphModule) -> None: class M(torch.nn.Module): def __init__(self): super().__init__() - self.conv = torch.nn.Conv2d(4, 4, 4) + self.linear = torch.nn.Linear(4, 4) def forward(self, x): - return self.conv(x) + return self.linear(x) - quantizer = Int4WeightQuantizer() + quantizer = Int8ActInt4WeightQuantizer() node_occurrence = { # for weight torch.ops.test_int4.quantize_per_tensor_int4: 1, @@ -340,17 +188,16 @@ def forward(self, x): node_list = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.test_int4.dequantize_per_tensor_int4, - torch.ops.aten.conv2d.default, + torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ] - example_inputs = (torch.randn(1, 4, 8, 8),) + example_inputs = (torch.randn(2, 4),) # _test_quantizer in PT2EQuantizationTestCase # resetting dynamo cache export_with_dynamic_shape = False torch._dynamo.reset() m_eager = M().eval() - _ = torch._export.export(m_eager, example_inputs) # program capture m = copy.deepcopy(m_eager) @@ -363,11 +210,8 @@ def forward(self, x): # Calibrate m(*example_inputs) m = convert_pt2e(m, fold_quantize=False) - m = torch._export.export(m, example_inputs) - print("m:", m) - pt2_quant_output = m(*example_inputs) - print("output:", pt2_quant_output) + node_occurrence = { ns.call_function(k): v for k, v in node_occurrence.items() } diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index a3465e67ac..6d6edfceaf 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -1,15 +1,11 @@ import torch import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch.library import Library, impl -# all test pass -# TODO: uint8 --> bits8 (currently blocked on bits8 supporting copy_ to long, or support >> operation -# TODO: change int4_tensor.dtype to return torch.int4 -# module swap --> subclass (for it to be composable with distributed, sparsity etc. subclasses) +# TODO: change int4_tensor.dtype to return torch.int4 (currently it's torch.bits8) -# class int4_class(torch.dtype): -# pass - -torch.int4 = torch.dtype("int4") +torch.int4 = torch.dtype(torch.bits8, "int4") def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" @@ -44,30 +40,56 @@ def fill_defaults(args, n, defaults_tail): # from # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 -def unpack_uint4(quantized_data) -> torch.Tensor: +def unpack_uint4(bits8_data) -> torch.Tensor: """Get the original weight from the normalized float weight format""" # since we are using uint8 we will decode 2 entries per byte # Shift elements down 4 and select out the bottom 4 bits - first_elements = (quantized_data >> 4).to(torch.uint8) - second_elements = (quantized_data & 0b1111).to(torch.uint8) - return torch.stack([first_elements, second_elements], dim=-1) + uint8_data = bits8_data.view(torch.uint8) + shape = uint8_data.shape + first_elements = (uint8_data >> 4).to(torch.uint8) + second_elements = (uint8_data & 0b1111).to(torch.uint8) + return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape)).view(torch.bits8) -def pack_uint4(uint8_data) -> torch.Tensor: +def pack_uint4(bits8_data) -> torch.Tensor: + # converting to uint8 for operations + uint8_data = bits8_data.view(torch.uint8) shape = uint8_data.shape assert shape[-1] % 2 == 0 uint8_data = uint8_data.contiguous().view(-1) - print("size 1:", (uint8_data[::2] << 4).size()) - print("size 2:", (uint8_data[1::2] << 4).size()) - return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)).view(torch.bits8) + +test_lib = Library("test_int4", "DEF") +test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") + +@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") +def quantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, +) -> torch.Tensor: + inv_scale = 1.0 / scale + return pack_uint4(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)) + +test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") +@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") +def dequantize_per_tensor_int4( + input: torch.Tensor, + scale: float, + zero_point: int, +) -> torch.Tensor: + input = unpack_uint4(input) + return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale class UInt4Tensor(torch.Tensor): @staticmethod - def __new__(cls, elem): + def __new__(cls, elem, **kwargs): # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) # weird shit will happen - assert elem.dtype is torch.uint8 + assert elem.dtype is torch.bits8 + assert not kwargs.get("requires_grad", False) + kwargs["requires_grad"] = False # TODO: right now tensor.dtype still displays bits8 - return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int4) + return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int4, **kwargs) def __init__(self, elem): self.elem = elem @@ -77,7 +99,7 @@ def from_unpacked(cls, unpacked): return UInt4Tensor(pack_uint4(unpacked)) def tolist(self): - return self.to(torch.uint8).tolist() + return self.to(torch.bits8).tolist() def __tensor_flatten__(self): return ["elem"], None @@ -93,32 +115,44 @@ def __hash__(self): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - print(f"func called: {func}") if func is torch.ops.aten.view.default: self, size = args size = utils.infer_size(size, self.numel()) assert not kwargs # WARNING: views not preserved return UInt4Tensor(self.elem.reshape(down_size(size))) + elif func is torch.ops.aten.view.dtype: + self, dtype = args + if dtype == torch.uint8: + return unpack_uint4(self.elem).view(torch.uint8) + return NotImplementedError(f"view {args}") + elif func is torch.ops.aten.to.dtype: + self, dtype = args + if dtype == torch.uint8: + return unpack_uint4(self.elem).view(torch.uint8) + return NotImplementedError(f"to {args}") + elif func is torch.ops.aten.eq.Tensor: + args = pytree.tree_map_only(UInt4Tensor, lambda x: x.elem.view(torch.uint8), args) + kwargs = pytree.tree_map_only(UInt4Tensor, lambda x: x.elem.view(torch.uint8), kwargs) + return torch.ops.aten.eq.Tensor(*args, **kwargs) elif func is torch.ops.aten._to_copy.default: - # print("_to_copy:", args) self, = args if kwargs == {'dtype': torch.uint8}: - return unpack_uint4(self.elem).view(self.shape) # no wrap + return unpack_uint4(self.elem).view(self.shape).view(torch.uint8) # no wrap else: raise NotImplementedError(f"_to_copy {kwargs}") elif func is torch.ops.aten.unbind.int: # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to # create four tensors containing one element each. But we can't # do this with uint4 because such a tensor's size is not divisible - # by bytes. What I am going to do instead is promote to uint8 + # by bytes. What I am going to do instead is promote to bits8 # when this happens self, dim = fill_defaults(args, 2, [0]) if dim != self.dim() - 1: raise NotImplementedError(f"unbind dim={dim}") else: # We're unbinding the last dimension, need to promote - return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) + return torch.ops.aten._to_copy.default(self, dtype=torch.bits8).unbind(dim) elif func is torch.ops.aten.select.int: self, dim, index = args if dim != self.dim() - 1: @@ -138,9 +172,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # easy case return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step)) elif func is torch.ops.aten.t.default: - assert False, "transpose is not properly implemented currently" + # assert False, "transpose is not properly implemented currently" self, = args - unpacked = unpack_uint4(self.elem).view(self.shape) + unpacked = unpack_uint4(self.elem) transposed = torch.ops.aten.t.default(unpacked) transposed_and_packed = pack_uint4(transposed) return UInt4Tensor(transposed_and_packed) @@ -158,7 +192,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): new_stride = [] for s in stride: if s != 1: - # since two int4 equals to 1 uint8 + # since two int4 equals to 1 bits8 new_stride.append(s // 2) else: new_stride.append(s) @@ -173,3 +207,115 @@ def __eq__(self, other): return torch.equal(self.elem, other.elem) __torch_function__ = torch._C._disabled_torch_function_impl + + +def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scale is the same dtype as the original tensor + scale = torch.clamp(scale, min=eps).to(x.dtype) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x.transpose(0, 1) / scale + x_round = torch.round(x_div) + x_zp = x_round + zero_point + x_zp = x_zp.transpose(0, 1) + quant = torch.clamp(x_zp, quant_min, quant_max) + if target_dtype == "int4": + quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(quant.to(torch.uint8).view(torch.bits8), scale) + else: + quant = quant.to(target_dtype) + + return quant, scale, zero_point + +class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor): + @staticmethod + def __new__(cls, elem, scales): + return super().__new__(cls, elem) + + def __init__(self, elem, scales): + super().__init__(elem) + self.scales = scales + + @classmethod + def from_unpacked(cls, unpacked, scales): + return cls(pack_uint4(unpacked), scales) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + x, weight, bias = ( + args[0], + args[1], + args[2] if len(args)>2 else None + ) + x_view = x.view(-1, x.shape[-1]) + weight_scales = weight.scales + weight = weight.to(torch.uint8).to(x.dtype) + out = torch.mm(x_view, weight.t()) + out = out * weight_scales + out = out.reshape(*x.shape[:-1], -1) + if bias is not None: + out += bias + return out + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # didin't hit this, there is a mysterious error if we try to go through this path + # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: + # Exception: Please convert all Tensors to FakeTensors first or + # instantiate FakeTensorMode with 'allow_non_fake_inputs' + if func is torch.ops.aten.addmm.default: + bias, x, weight = args + x_view = x.view(-1, x.shape[-1]) + y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) * weight.scales + y = y.reshape(*x.shape[:-1], -1) + if bias is not None: + y += bias + return y + elif func is torch.ops.aten.t.default: + # assert False, "transpose is not properly implemented currently" + self, = args + unpacked = unpack_uint4(self.elem) + transposed = torch.ops.aten.t.default(unpacked) + transposed_and_packed = pack_uint4(transposed) + return cls(transposed_and_packed, self.scales) + elif func is torch.ops.aten.detach.default: + self, = args + return self + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def from_float(cls, w_fp32): + w_int4, scales, _zp = _dynamically_quantize_per_channel_int4( + w_fp32, 0, 15, "int4" + ) + w_int4.to(device=w_fp32.device) + return w_int4 From 9a22440109b22167a3c1d0303a46f7efe71d2c09 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 21 Dec 2023 10:08:02 -0800 Subject: [PATCH 12/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torchao/dtypes/int4.py | 48 +++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 6d6edfceaf..5b64f296b5 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -261,30 +261,30 @@ def __init__(self, elem, scales): def from_unpacked(cls, unpacked, scales): return cls(pack_uint4(unpacked), scales) - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func is torch.nn.functional.linear: - x, weight, bias = ( - args[0], - args[1], - args[2] if len(args)>2 else None - ) - x_view = x.view(-1, x.shape[-1]) - weight_scales = weight.scales - weight = weight.to(torch.uint8).to(x.dtype) - out = torch.mm(x_view, weight.t()) - out = out * weight_scales - out = out.reshape(*x.shape[:-1], -1) - if bias is not None: - out += bias - return out - try: - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - except: - print(f"ERR: subclass doesn't implement {func}") + # @classmethod + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # kwargs = {} if kwargs is None else kwargs + + # if func is torch.nn.functional.linear: + # x, weight, bias = ( + # args[0], + # args[1], + # args[2] if len(args)>2 else None + # ) + # x_view = x.view(-1, x.shape[-1]) + # weight_scales = weight.scales + # weight = weight.to(torch.uint8).to(x.dtype) + # out = torch.mm(x_view, weight.t()) + # out = out * weight_scales + # out = out.reshape(*x.shape[:-1], -1) + # if bias is not None: + # out += bias + # return out + # try: + # with torch._C.DisableTorchFunctionSubclass(): + # return func(*args, **kwargs) + # except: + # print(f"ERR: subclass doesn't implement {func}") @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From c69d7530aa038884eaa1607224710a4b98b3d77a Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 2 Jan 2024 10:23:58 -0800 Subject: [PATCH 13/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 10 +++++----- torchao/dtypes/int4.py | 17 +++++++++++------ 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index 3fce6ff921..d0091662fe 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -50,8 +50,8 @@ def test_basic_tensor_ops(self): [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], ], dtype=torch.uint8).view(torch.bits8)) self.assertEqual(x.shape, (3, 16)) - # TODO: make sure this returns torch.int4 - self.assertEqual(x.dtype, torch.bits8) + # TODO: make sure this returns torch.uint4 + self.assertEqual(x.dtype, torch.uint4) # making sure these works x.to(torch.uint8) expected = UInt4Tensor(torch.tensor([ @@ -111,9 +111,9 @@ def convert(self, model: GraphModule, observer_node: Node): class Int8ActInt4WeightQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: int4_qspec = QuantizationSpec( - dtype=torch.int4, - quant_min=-2**3, - quant_max=2**3 - 1, + dtype=torch.uint4, + quant_min=0, + quant_max=2**4 - 1, qscheme=torch.per_tensor_affine, is_dynamic=False, observer_or_fake_quant_ctr=Int4Observer, diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index 5b64f296b5..f3d93af46c 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -3,9 +3,9 @@ import torch.utils._pytree as pytree from torch.library import Library, impl -# TODO: change int4_tensor.dtype to return torch.int4 (currently it's torch.bits8) +# TODO: make test_gpu_quant work with __torch_dispatch__ -torch.int4 = torch.dtype(torch.bits8, "int4") +torch.uint4 = torch.dtype(torch.bits8, "uint4") def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" @@ -88,8 +88,7 @@ def __new__(cls, elem, **kwargs): assert elem.dtype is torch.bits8 assert not kwargs.get("requires_grad", False) kwargs["requires_grad"] = False - # TODO: right now tensor.dtype still displays bits8 - return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int4, **kwargs) + return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs) def __init__(self, elem): self.elem = elem @@ -113,6 +112,11 @@ def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): def __hash__(self): return hash(self.elem) + def __getattribute__(self, name): + if name == "dtype": + return torch.uint4 + return super().__getattribute__(name) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func is torch.ops.aten.view.default: @@ -241,7 +245,8 @@ def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype x_zp = x_round + zero_point x_zp = x_zp.transpose(0, 1) quant = torch.clamp(x_zp, quant_min, quant_max) - if target_dtype == "int4": + if target_dtype == torch.uint4: + # TODO: simplify (maybe implement to) quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(quant.to(torch.uint8).view(torch.bits8), scale) else: quant = quant.to(target_dtype) @@ -315,7 +320,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): @classmethod def from_float(cls, w_fp32): w_int4, scales, _zp = _dynamically_quantize_per_channel_int4( - w_fp32, 0, 15, "int4" + w_fp32, 0, 15, torch.uint4 ) w_int4.to(device=w_fp32.device) return w_int4 From 6f0aa7d48172c753f0ae0b98dd42d134d92afc99 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 16 Jan 2024 13:37:38 -0800 Subject: [PATCH 14/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/test_int4.py | 22 ++++++++++++-------- torchao/dtypes/int4.py | 45 ++++++++++++++++++---------------------- 2 files changed, 33 insertions(+), 34 deletions(-) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_int4.py index d0091662fe..7fe85b75b7 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_int4.py @@ -48,22 +48,26 @@ def test_basic_tensor_ops(self): [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8).view(torch.bits8)) + ], dtype=torch.uint8)) self.assertEqual(x.shape, (3, 16)) # TODO: make sure this returns torch.uint4 - self.assertEqual(x.dtype, torch.uint4) + self.assertIs(x.dtype, torch.uint4) # making sure these works x.to(torch.uint8) expected = UInt4Tensor(torch.tensor([ [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8).view(torch.bits8)) + ], dtype=torch.uint8)) self.assertEqual(x[0:1, :], expected) expected = UInt4Tensor(torch.tensor([ [0x23, 0x45], [0x23, 0x45], [0x23, 0x45], - ], dtype=torch.uint8).view(torch.bits8)) + ], dtype=torch.uint8)) self.assertEqual(x[:, 2:6], expected) + torch.save(x, "uint4_tensor.pt") + x = torch.load("uint4_tensor.pt") + self.assertEqual(x[:, 2:6], expected) + print("x:", x[0]) def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: @@ -97,9 +101,9 @@ def calculate_qparams(self, **kwargs): def convert(self, model: GraphModule, observer_node: Node): with model.graph.inserting_before(observer_node): q_node = model.graph.call_function( - torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) + torch.ops.qtensors.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) dq_node = model.graph.call_function( - torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {}) + torch.ops.qtensors.dequantize_per_tensor_int4, (q_node, 1.0, 0), {}) observer_node.replace_all_uses_with(dq_node) model.graph.erase_node(observer_node) @@ -179,15 +183,15 @@ def forward(self, x): quantizer = Int8ActInt4WeightQuantizer() node_occurrence = { # for weight - torch.ops.test_int4.quantize_per_tensor_int4: 1, - torch.ops.test_int4.dequantize_per_tensor_int4: 1, + torch.ops.qtensors.quantize_per_tensor_int4: 1, + torch.ops.qtensors.dequantize_per_tensor_int4: 1, # for activation torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, } node_list = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.test_int4.dequantize_per_tensor_int4, + torch.ops.qtensors.dequantize_per_tensor_int4, torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ] diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/int4.py index f3d93af46c..8a7b72bb0b 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/int4.py @@ -5,8 +5,6 @@ # TODO: make test_gpu_quant work with __torch_dispatch__ -torch.uint4 = torch.dtype(torch.bits8, "uint4") - def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" return (*size[:-1], size[-1] // 2) @@ -40,38 +38,36 @@ def fill_defaults(args, n, defaults_tail): # from # https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233 -def unpack_uint4(bits8_data) -> torch.Tensor: +def unpack_uint4(uint8_data) -> torch.Tensor: """Get the original weight from the normalized float weight format""" # since we are using uint8 we will decode 2 entries per byte # Shift elements down 4 and select out the bottom 4 bits - uint8_data = bits8_data.view(torch.uint8) shape = uint8_data.shape first_elements = (uint8_data >> 4).to(torch.uint8) second_elements = (uint8_data & 0b1111).to(torch.uint8) - return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape)).view(torch.bits8) + return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape)) -def pack_uint4(bits8_data) -> torch.Tensor: +def pack_uint4(uint8_data) -> torch.Tensor: # converting to uint8 for operations - uint8_data = bits8_data.view(torch.uint8) shape = uint8_data.shape assert shape[-1] % 2 == 0 uint8_data = uint8_data.contiguous().view(-1) - return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)).view(torch.bits8) + return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) -test_lib = Library("test_int4", "DEF") -test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") +qtensor_lib = Library("qtensors", "DEF") +qtensor_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") -@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") +@impl(qtensor_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") def quantize_per_tensor_int4( input: torch.Tensor, scale: float, zero_point: int, ) -> torch.Tensor: inv_scale = 1.0 / scale - return pack_uint4(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)) + return pack_uint4(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8)) -test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") -@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") +qtensor_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") +@impl(qtensor_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") def dequantize_per_tensor_int4( input: torch.Tensor, scale: float, @@ -85,7 +81,7 @@ class UInt4Tensor(torch.Tensor): def __new__(cls, elem, **kwargs): # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) # weird shit will happen - assert elem.dtype is torch.bits8 + assert elem.dtype is torch.uint8 assert not kwargs.get("requires_grad", False) kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs) @@ -98,7 +94,7 @@ def from_unpacked(cls, unpacked): return UInt4Tensor(pack_uint4(unpacked)) def tolist(self): - return self.to(torch.bits8).tolist() + return self.to(torch.uint8).tolist() def __tensor_flatten__(self): return ["elem"], None @@ -112,6 +108,9 @@ def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): def __hash__(self): return hash(self.elem) + def __eq__(self, other): + return torch.equal(self.elem, other.elem) + def __getattribute__(self, name): if name == "dtype": return torch.uint4 @@ -142,21 +141,21 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): elif func is torch.ops.aten._to_copy.default: self, = args if kwargs == {'dtype': torch.uint8}: - return unpack_uint4(self.elem).view(self.shape).view(torch.uint8) # no wrap + return unpack_uint4(self.elem).view(self.shape) # no wrap else: raise NotImplementedError(f"_to_copy {kwargs}") elif func is torch.ops.aten.unbind.int: # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to # create four tensors containing one element each. But we can't # do this with uint4 because such a tensor's size is not divisible - # by bytes. What I am going to do instead is promote to bits8 + # by bytes. What I am going to do instead is promote to uint8 # when this happens self, dim = fill_defaults(args, 2, [0]) if dim != self.dim() - 1: raise NotImplementedError(f"unbind dim={dim}") else: # We're unbinding the last dimension, need to promote - return torch.ops.aten._to_copy.default(self, dtype=torch.bits8).unbind(dim) + return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) elif func is torch.ops.aten.select.int: self, dim, index = args if dim != self.dim() - 1: @@ -196,7 +195,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): new_stride = [] for s in stride: if s != 1: - # since two int4 equals to 1 bits8 + # since two int4 equals to 1 uint8 new_stride.append(s // 2) else: new_stride.append(s) @@ -207,9 +206,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): raise NotImplementedError(f"{func}") - def __eq__(self, other): - return torch.equal(self.elem, other.elem) - __torch_function__ = torch._C._disabled_torch_function_impl @@ -247,7 +243,7 @@ def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype quant = torch.clamp(x_zp, quant_min, quant_max) if target_dtype == torch.uint4: # TODO: simplify (maybe implement to) - quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(quant.to(torch.uint8).view(torch.bits8), scale) + quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(quant.to(torch.uint8), scale) else: quant = quant.to(target_dtype) @@ -266,7 +262,6 @@ def __init__(self, elem, scales): def from_unpacked(cls, unpacked, scales): return cls(pack_uint4(unpacked), scales) - # @classmethod # def __torch_function__(cls, func, types, args=(), kwargs=None): # kwargs = {} if kwargs is None else kwargs From c52b1235a32bc3c4091a897240b449639fe81560 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 17 Jan 2024 15:45:55 -0800 Subject: [PATCH 15/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- test/dtypes/{test_int4.py => test_uint4.py} | 35 +++++------ torchao/dtypes/__init__.py | 2 +- torchao/dtypes/{int4.py => uint4.py} | 67 ++++++++------------- 3 files changed, 44 insertions(+), 60 deletions(-) rename test/dtypes/{test_int4.py => test_uint4.py} (89%) rename torchao/dtypes/{int4.py => uint4.py} (86%) diff --git a/test/dtypes/test_int4.py b/test/dtypes/test_uint4.py similarity index 89% rename from test/dtypes/test_int4.py rename to test/dtypes/test_uint4.py index 7fe85b75b7..19486870ac 100644 --- a/test/dtypes/test_int4.py +++ b/test/dtypes/test_uint4.py @@ -1,5 +1,5 @@ import torch -from torchao.dtypes.int4 import ( +from torchao.dtypes.uint4 import ( UInt4Tensor, PerChannelSymmetricWeightUInt4Tensor, ) @@ -31,7 +31,7 @@ ) import copy -def _apply_weight_only_int4_quant(model): +def _apply_weight_only_uint4_quant(model): def fn(mod): mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) return mod @@ -42,7 +42,7 @@ def fn(mod): lambda mod, fqn: isinstance(mod, torch.nn.Linear), ) -class TestInt4(QuantizationTestCase): +class TestUInt4(QuantizationTestCase): def test_basic_tensor_ops(self): x = UInt4Tensor(torch.tensor([ [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], @@ -67,17 +67,18 @@ def test_basic_tensor_ops(self): torch.save(x, "uint4_tensor.pt") x = torch.load("uint4_tensor.pt") self.assertEqual(x[:, 2:6], expected) - print("x:", x[0]) + # only test locally + # print("x:", x[0]) def test_gpu_quant(self): for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: x = torch.randn(*x_shape) m = nn.Sequential(nn.Linear(4, 16)) y_ref = m(x) - _apply_weight_only_int4_quant(m) + _apply_weight_only_uint4_quant(m) y_wo = m(x) # sqnr = compute_error(y_ref, y_wo) - opt = torch.compile(m, mode="max-autotune") + opt = torch.compile(m, fullgraph=True, mode="max-autotune") # make sure it runs opt(x) @@ -86,7 +87,7 @@ def test_pt2e_quant(self): OP_TO_ANNOTATOR, QuantizationConfig, ) - class Int4Observer(ObserverBase): + class Uint4Observer(ObserverBase): def __init__(self, *args, **kwargs): # just faking a dtype here # TODO: make flow work with new dtypes @@ -101,9 +102,9 @@ def calculate_qparams(self, **kwargs): def convert(self, model: GraphModule, observer_node: Node): with model.graph.inserting_before(observer_node): q_node = model.graph.call_function( - torch.ops.qtensors.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {}) + torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {}) dq_node = model.graph.call_function( - torch.ops.qtensors.dequantize_per_tensor_int4, (q_node, 1.0, 0), {}) + torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {}) observer_node.replace_all_uses_with(dq_node) model.graph.erase_node(observer_node) @@ -112,15 +113,15 @@ def convert(self, model: GraphModule, observer_node: Node): _mark_nodes_as_annotated, ) - class Int8ActInt4WeightQuantizer(Quantizer): + class Int8ActUint4WeightQuantizer(Quantizer): def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: - int4_qspec = QuantizationSpec( + uint4_qspec = QuantizationSpec( dtype=torch.uint4, quant_min=0, quant_max=2**4 - 1, qscheme=torch.per_tensor_affine, is_dynamic=False, - observer_or_fake_quant_ctr=Int4Observer, + observer_or_fake_quant_ctr=Uint4Observer, ) int8_qspec = QuantizationSpec( dtype=torch.int8, @@ -132,7 +133,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: ) quantization_config = QuantizationConfig( input_activation=int8_qspec, - weight=int4_qspec, + weight=uint4_qspec, bias=None, output_activation=int8_qspec, ) @@ -180,18 +181,18 @@ def __init__(self): def forward(self, x): return self.linear(x) - quantizer = Int8ActInt4WeightQuantizer() + quantizer = Int8ActUint4WeightQuantizer() node_occurrence = { # for weight - torch.ops.qtensors.quantize_per_tensor_int4: 1, - torch.ops.qtensors.dequantize_per_tensor_int4: 1, + torch.ops.qtensors.quantize_per_tensor_uint4: 1, + torch.ops.qtensors.dequantize_per_tensor_uint4: 1, # for activation torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, } node_list = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, - torch.ops.qtensors.dequantize_per_tensor_int4, + torch.ops.qtensors.dequantize_per_tensor_uint4, torch.ops.aten.linear.default, torch.ops.quantized_decomposed.quantize_per_tensor.default, ] diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 3dc5b7c2eb..5ea6ea3e67 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,4 +1,4 @@ -from .int4 import UInt4Tensor +from .uint4 import UInt4Tensor __all__ = [ "UInt4Tensor" diff --git a/torchao/dtypes/int4.py b/torchao/dtypes/uint4.py similarity index 86% rename from torchao/dtypes/int4.py rename to torchao/dtypes/uint4.py index 8a7b72bb0b..07548e42e9 100644 --- a/torchao/dtypes/int4.py +++ b/torchao/dtypes/uint4.py @@ -3,8 +3,6 @@ import torch.utils._pytree as pytree from torch.library import Library, impl -# TODO: make test_gpu_quant work with __torch_dispatch__ - def down_size(size): assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" return (*size[:-1], size[-1] // 2) @@ -55,10 +53,10 @@ def pack_uint4(uint8_data) -> torch.Tensor: return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape)) qtensor_lib = Library("qtensors", "DEF") -qtensor_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") +qtensor_lib.define("quantize_per_tensor_uint4(Tensor input, float scale, int zero_point) -> Tensor") -@impl(qtensor_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd") -def quantize_per_tensor_int4( +@impl(qtensor_lib, "quantize_per_tensor_uint4", "CompositeExplicitAutograd") +def quantize_per_tensor_uint4( input: torch.Tensor, scale: float, zero_point: int, @@ -66,9 +64,9 @@ def quantize_per_tensor_int4( inv_scale = 1.0 / scale return pack_uint4(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8)) -qtensor_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor") -@impl(qtensor_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd") -def dequantize_per_tensor_int4( +qtensor_lib.define("dequantize_per_tensor_uint4(Tensor input, float scale, int zero_point) -> Tensor") +@impl(qtensor_lib, "dequantize_per_tensor_uint4", "CompositeExplicitAutograd") +def dequantize_per_tensor_uint4( input: torch.Tensor, scale: float, zero_point: int, @@ -79,8 +77,6 @@ def dequantize_per_tensor_int4( class UInt4Tensor(torch.Tensor): @staticmethod def __new__(cls, elem, **kwargs): - # TODO: uint64 here is wrong, need a real dtype. Don't try to(int64) - # weird shit will happen assert elem.dtype is torch.uint8 assert not kwargs.get("requires_grad", False) kwargs["requires_grad"] = False @@ -111,10 +107,10 @@ def __hash__(self): def __eq__(self, other): return torch.equal(self.elem, other.elem) - def __getattribute__(self, name): - if name == "dtype": - return torch.uint4 - return super().__getattribute__(name) + # def __getattribute__(self, name): + # if name == "dtype": + # return torch.uint4 + # return super().__getattribute__(name) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -255,37 +251,25 @@ def __new__(cls, elem, scales): return super().__new__(cls, elem) def __init__(self, elem, scales): - super().__init__(elem) + # super().__init__(elem) + self.elem = elem self.scales = scales + + def __tensor_flatten__(self): + return ["elem", "scales"], None + + @staticmethod + def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): + assert meta is None + elem = flattened["elem"] + scales = flattened["scales"] + return PerChannelSymmetricWeightUInt4Tensor(elem, scales) + @classmethod def from_unpacked(cls, unpacked, scales): return cls(pack_uint4(unpacked), scales) - # def __torch_function__(cls, func, types, args=(), kwargs=None): - # kwargs = {} if kwargs is None else kwargs - - # if func is torch.nn.functional.linear: - # x, weight, bias = ( - # args[0], - # args[1], - # args[2] if len(args)>2 else None - # ) - # x_view = x.view(-1, x.shape[-1]) - # weight_scales = weight.scales - # weight = weight.to(torch.uint8).to(x.dtype) - # out = torch.mm(x_view, weight.t()) - # out = out * weight_scales - # out = out.reshape(*x.shape[:-1], -1) - # if bias is not None: - # out += bias - # return out - # try: - # with torch._C.DisableTorchFunctionSubclass(): - # return func(*args, **kwargs) - # except: - # print(f"ERR: subclass doesn't implement {func}") - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # didin't hit this, there is a mysterious error if we try to go through this path @@ -305,8 +289,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): self, = args unpacked = unpack_uint4(self.elem) transposed = torch.ops.aten.t.default(unpacked) - transposed_and_packed = pack_uint4(transposed) - return cls(transposed_and_packed, self.scales) + return PerChannelSymmetricWeightUInt4Tensor.from_unpacked(transposed, self.scales) elif func is torch.ops.aten.detach.default: self, = args return self @@ -317,5 +300,5 @@ def from_float(cls, w_fp32): w_int4, scales, _zp = _dynamically_quantize_per_channel_int4( w_fp32, 0, 15, torch.uint4 ) - w_int4.to(device=w_fp32.device) + w_int4 = w_int4.to(device=w_fp32.device) return w_int4 From 49c6a43983f13039cdf6d29734fe7033ae15f817 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 22 Jan 2024 10:04:28 -0800 Subject: [PATCH 16/17] Update on "Adding uint4 dtype implementation" Summary: We have a lot of interest for int4 dtypes, and we'd like to add the dtype out of PyTorch core. This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torchao/dtypes/uint4.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uint4.py index 07548e42e9..88cf5e55b5 100644 --- a/torchao/dtypes/uint4.py +++ b/torchao/dtypes/uint4.py @@ -107,11 +107,6 @@ def __hash__(self): def __eq__(self, other): return torch.equal(self.elem, other.elem) - # def __getattribute__(self, name): - # if name == "dtype": - # return torch.uint4 - # return super().__getattribute__(name) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): if func is torch.ops.aten.view.default: @@ -272,10 +267,6 @@ def from_unpacked(cls, unpacked, scales): @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # didin't hit this, there is a mysterious error if we try to go through this path - # torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: - # Exception: Please convert all Tensors to FakeTensors first or - # instantiate FakeTensorMode with 'allow_non_fake_inputs' if func is torch.ops.aten.addmm.default: bias, x, weight = args x_view = x.view(-1, x.shape[-1]) @@ -285,7 +276,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): y += bias return y elif func is torch.ops.aten.t.default: - # assert False, "transpose is not properly implemented currently" + # TODO: add proper support for transpose self, = args unpacked = unpack_uint4(self.elem) transposed = torch.ops.aten.t.default(unpacked) From 8ab52a7a19d5b1a1da6492bd15d35d124e6f7a1e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 9 Feb 2024 21:50:09 -0800 Subject: [PATCH 17/17] Update on "Adding uint4 dtype implementation" Summary: This PR added some preliminary support for uint4 through tensor subclass and we'll continue to iterate on this we plan to move the uint4 tensor subclass to core after it is more mature Test Plan: python test/dtypes/test_int4.py Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torchao/dtypes/uint4.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uint4.py index 88cf5e55b5..9769bebf40 100644 --- a/torchao/dtypes/uint4.py +++ b/torchao/dtypes/uint4.py @@ -82,7 +82,7 @@ def __new__(cls, elem, **kwargs): kwargs["requires_grad"] = False return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs) - def __init__(self, elem): + def __init__(self, elem, **kwargs): self.elem = elem @classmethod @@ -242,12 +242,11 @@ def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor): @staticmethod - def __new__(cls, elem, scales): - return super().__new__(cls, elem) + def __new__(cls, elem, scales, **kwargs): + return super().__new__(cls, elem, **kwargs) - def __init__(self, elem, scales): - # super().__init__(elem) - self.elem = elem + def __init__(self, elem, scales, **kwargs): + super().__init__(elem, **kwargs) self.scales = scales