diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py new file mode 100644 index 0000000000..19486870ac --- /dev/null +++ b/test/dtypes/test_uint4.py @@ -0,0 +1,229 @@ +import torch +from torchao.dtypes.uint4 import ( + UInt4Tensor, + PerChannelSymmetricWeightUInt4Tensor, +) +import unittest +from unittest import TestCase, main +from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e +from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer + +from torch._export import capture_pre_autograd_graph +from torch._export import dynamic_dim +from torch.testing._internal.common_quantization import ( + NodeSpec as ns, + QuantizationTestCase, +) +from torchao.quantization.utils import ( + compute_error, +) +from torchao.quantization.quant_api import ( + replace_with_custom_fn_if_matches_filter, +) +from torch.ao.quantization.observer import ObserverBase +from torch import nn +from torch.fx import ( + Node, + GraphModule, +) +from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, +) +import copy + +def _apply_weight_only_uint4_quant(model): + def fn(mod): + mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False) + return mod + + replace_with_custom_fn_if_matches_filter( + model, + lambda mod: fn(mod), + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + +class TestUInt4(QuantizationTestCase): + def test_basic_tensor_ops(self): + x = UInt4Tensor(torch.tensor([ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], dtype=torch.uint8)) + self.assertEqual(x.shape, (3, 16)) + # TODO: make sure this returns torch.uint4 + self.assertIs(x.dtype, torch.uint4) + # making sure these works + x.to(torch.uint8) + expected = UInt4Tensor(torch.tensor([ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], dtype=torch.uint8)) + self.assertEqual(x[0:1, :], expected) + expected = UInt4Tensor(torch.tensor([ + [0x23, 0x45], + [0x23, 0x45], + [0x23, 0x45], + ], dtype=torch.uint8)) + self.assertEqual(x[:, 2:6], expected) + torch.save(x, "uint4_tensor.pt") + x = torch.load("uint4_tensor.pt") + self.assertEqual(x[:, 2:6], expected) + # only test locally + # print("x:", x[0]) + + def test_gpu_quant(self): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 16)) + y_ref = m(x) + _apply_weight_only_uint4_quant(m) + y_wo = m(x) + # sqnr = compute_error(y_ref, y_wo) + opt = torch.compile(m, fullgraph=True, mode="max-autotune") + # make sure it runs + opt(x) + + def test_pt2e_quant(self): + from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + OP_TO_ANNOTATOR, + QuantizationConfig, + ) + class Uint4Observer(ObserverBase): + def __init__(self, *args, **kwargs): + # just faking a dtype here + # TODO: make flow work with new dtypes + super().__init__(dtype=torch.int8) + + def forward(self, x): + return x + + def calculate_qparams(self, **kwargs): + pass + + def convert(self, model: GraphModule, observer_node: Node): + with model.graph.inserting_before(observer_node): + q_node = model.graph.call_function( + torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {}) + dq_node = model.graph.call_function( + torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {}) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + + from 
torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( + _is_annotated, + _mark_nodes_as_annotated, + ) + + class Int8ActUint4WeightQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + uint4_qspec = QuantizationSpec( + dtype=torch.uint4, + quant_min=0, + quant_max=2**4 - 1, + qscheme=torch.per_tensor_affine, + is_dynamic=False, + observer_or_fake_quant_ctr=Uint4Observer, + ) + int8_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer, + ) + quantization_config = QuantizationConfig( + input_activation=int8_qspec, + weight=uint4_qspec, + bias=None, + output_activation=int8_qspec, + ) + for n in model.graph.nodes: + if n.op != "call_function" or n.target not in [ + torch.ops.aten.linear.default, + ]: + continue + linear_node = n + + input_qspec_map = {} + input_act = linear_node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation + + weight = linear_node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = quantization_config.weight + + partition = [linear_node, linear_node.args[1]] + + bias = linear_node.args[2] if len(linear_node.args) > 2 else None + if isinstance(bias, Node): + input_qspec_map[bias] = quantization_config.bias + partition.append(bias) + + if _is_annotated(partition): + continue + + linear_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + _mark_nodes_as_annotated(partition) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + return self.linear(x) + + quantizer = Int8ActUint4WeightQuantizer() + node_occurrence = { + # for weight + torch.ops.qtensors.quantize_per_tensor_uint4: 1, + torch.ops.qtensors.dequantize_per_tensor_uint4: 1, + # for activation + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, + torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.qtensors.dequantize_per_tensor_uint4, + torch.ops.aten.linear.default, + torch.ops.quantized_decomposed.quantize_per_tensor.default, + ] + example_inputs = (torch.randn(2, 4),) + + # _test_quantizer in PT2EQuantizationTestCase + # resetting dynamo cache + export_with_dynamic_shape = False + torch._dynamo.reset() + m_eager = M().eval() + + # program capture + m = copy.deepcopy(m_eager) + m = capture_pre_autograd_graph( + m, + example_inputs, + ) + + m = prepare_pt2e(m, quantizer) + # Calibrate + m(*example_inputs) + m = convert_pt2e(m, fold_quantize=False) + pt2_quant_output = m(*example_inputs) + + node_occurrence = { + ns.call_function(k): v for k, v in node_occurrence.items() + } + node_list = [ns.call_function(n) for n in node_list] + self.checkGraphModuleNodes( + m, expected_node_occurrence=node_occurrence, expected_node_list=node_list + ) + +if __name__ == "__main__": + main() diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py new file mode 100644 index 0000000000..5ea6ea3e67 --- /dev/null +++ b/torchao/dtypes/__init__.py @@ -0,0 +1,5 @@ +from .uint4 import UInt4Tensor + +__all__ = [ + "UInt4Tensor" +] diff 
--git a/torchao/dtypes/uint4.py b/torchao/dtypes/uint4.py
new file mode 100644
index 0000000000..9769bebf40
--- /dev/null
+++ b/torchao/dtypes/uint4.py
@@ -0,0 +1,294 @@
+import torch
+import torch._prims_common as utils
+import torch.utils._pytree as pytree
+from torch.library import Library, impl
+
+def down_size(size):
+    assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
+    return (*size[:-1], size[-1] // 2)
+
+def up_size(size):
+    return (*size[:-1], size[-1] * 2)
+
+def fill_defaults(args, n, defaults_tail):
+    """
+    __torch_dispatch__ doesn't guarantee the number of arguments you are
+    passed (e.g., defaulted arguments are not passed); but usually it is
+    convenient to pad out the arguments list with defaults. This function
+    helps you do that.
+    Args:
+        args: the list of positional arguments passed to __torch_dispatch__
+        n: the number of arguments you are expecting to get
+        defaults_tail: default values for the arguments, starting from the
+            end of the list
+    Example:
+        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
+        [1, 2, 3, 4, 5]
+        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
+        [1, 2, 3, None, None]
+    """
+    if n - len(defaults_tail) > len(args):
+        raise RuntimeError("not enough defaults to fill arguments")
+    r = list(args)
+    for i in range(len(args), n):
+        r.append(defaults_tail[i - n + len(defaults_tail)])
+    return r
+
+# from
+# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
+def unpack_uint4(uint8_data) -> torch.Tensor:
+    """Unpack two uint4 values from each uint8 byte into a uint8 tensor with the last dim doubled"""
+    # since we are using uint8 we will decode 2 entries per byte
+    # Shift elements down 4 and select out the bottom 4 bits
+    shape = uint8_data.shape
+    first_elements = (uint8_data >> 4).to(torch.uint8)
+    second_elements = (uint8_data & 0b1111).to(torch.uint8)
+    return torch.stack([first_elements, second_elements], dim=-1).view(up_size(shape))
+
+def pack_uint4(uint8_data) -> torch.Tensor:
+    # converting to uint8 for operations
+    shape = uint8_data.shape
+    assert shape[-1] % 2 == 0
+    uint8_data = uint8_data.contiguous().view(-1)
+    return (uint8_data[::2] << 4 | uint8_data[1::2]).view(down_size(shape))
+
+qtensor_lib = Library("qtensors", "DEF")
+qtensor_lib.define("quantize_per_tensor_uint4(Tensor input, float scale, int zero_point) -> Tensor")
+
+@impl(qtensor_lib, "quantize_per_tensor_uint4", "CompositeExplicitAutograd")
+def quantize_per_tensor_uint4(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+) -> torch.Tensor:
+    inv_scale = 1.0 / scale
+    return pack_uint4(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))
+
+qtensor_lib.define("dequantize_per_tensor_uint4(Tensor input, float scale, int zero_point) -> Tensor")
+@impl(qtensor_lib, "dequantize_per_tensor_uint4", "CompositeExplicitAutograd")
+def dequantize_per_tensor_uint4(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+) -> torch.Tensor:
+    input = unpack_uint4(input)
+    return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale
+
+class UInt4Tensor(torch.Tensor):
+    @staticmethod
+    def __new__(cls, elem, **kwargs):
+        assert elem.dtype is torch.uint8
+        assert not kwargs.get("requires_grad", False)
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs)
+
+    def __init__(self, elem, **kwargs):
+        self.elem = elem
+
+    @classmethod
+    def from_unpacked(cls, unpacked):
+        return UInt4Tensor(pack_uint4(unpacked))
+
+    def tolist(self):
+        return self.to(torch.uint8).tolist()
+
+    def __tensor_flatten__(self):
+        return ["elem"], None
+
+    @staticmethod
+    def __tensor_unflatten__(flattened, meta, outer_size, outer_stride):
+        assert meta is None
+        elem = flattened["elem"]
+        return UInt4Tensor(elem)
+
+    def __hash__(self):
+        return hash(self.elem)
+
+    def __eq__(self, other):
+        return torch.equal(self.elem, other.elem)
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func is torch.ops.aten.view.default:
+            self, size = args
+            size = utils.infer_size(size, self.numel())
+            assert not kwargs
+            # WARNING: views not preserved
+            return UInt4Tensor(self.elem.reshape(down_size(size)))
+        elif func is torch.ops.aten.view.dtype:
+            self, dtype = args
+            if dtype == torch.uint8:
+                return unpack_uint4(self.elem).view(torch.uint8)
+            raise NotImplementedError(f"view {args}")
+        elif func is torch.ops.aten.to.dtype:
+            self, dtype = args
+            if dtype == torch.uint8:
+                return unpack_uint4(self.elem).view(torch.uint8)
+            raise NotImplementedError(f"to {args}")
+        elif func is torch.ops.aten.eq.Tensor:
+            args = pytree.tree_map_only(UInt4Tensor, lambda x: x.elem.view(torch.uint8), args)
+            kwargs = pytree.tree_map_only(UInt4Tensor, lambda x: x.elem.view(torch.uint8), kwargs)
+            return torch.ops.aten.eq.Tensor(*args, **kwargs)
+        elif func is torch.ops.aten._to_copy.default:
+            self, = args
+            if kwargs == {'dtype': torch.uint8}:
+                return unpack_uint4(self.elem).view(self.shape)  # no wrap
+            else:
+                raise NotImplementedError(f"_to_copy {kwargs}")
+        elif func is torch.ops.aten.unbind.int:
+            # This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to
+            # create four tensors containing one element each. But we can't
+            # do this with uint4 because such a tensor's size is not divisible
+            # by bytes.
What I am going to do instead is promote to uint8 + # when this happens + self, dim = fill_defaults(args, 2, [0]) + if dim != self.dim() - 1: + raise NotImplementedError(f"unbind dim={dim}") + else: + # We're unbinding the last dimension, need to promote + return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim) + elif func is torch.ops.aten.select.int: + self, dim, index = args + if dim != self.dim() - 1: + return UInt4Tensor(torch.ops.aten.select.int(self.elem, dim, index)) + else: + raise NotImplementedError(f"select dim={dim}") + elif func is torch.ops.aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == self.dim() - 1: + # hard case + if step != 1: + raise NotImplementedError(f"slice step={step}") + assert start % 2 == 0, start + assert end >= self.shape[dim] or end % 2 == 0, end + return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start // 2, end // 2, 1)) + else: + # easy case + return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step)) + elif func is torch.ops.aten.t.default: + # assert False, "transpose is not properly implemented currently" + self, = args + unpacked = unpack_uint4(self.elem) + transposed = torch.ops.aten.t.default(unpacked) + transposed_and_packed = pack_uint4(transposed) + return UInt4Tensor(transposed_and_packed) + elif func is torch.ops.aten.transpose_copy.int: + self, dim0, dim1 = args + unpacked = unpack_uint4(self.elem).view(self.shape) + transposed = torch.ops.aten.transpose_copy.int(unpacked, dim0, dim1) + transposed_and_packed = pack_uint4(transposed) + return UInt4Tensor(transposed_and_packed) + elif func is torch.ops.aten.as_strided.default: + # size, stride, storage_offset are referring to tensor elements, not physical bytes + self, size, stride, storage_offset = args + size = down_size(size) + + new_stride = [] + for s in stride: + if s != 1: + # since two int4 equals to 1 uint8 + new_stride.append(s // 2) + else: + new_stride.append(s) + stride = new_stride + + storage_offset //= 2 + return UInt4Tensor(torch.ops.aten.as_strided.default(self.elem, size, stride, storage_offset)) + + raise NotImplementedError(f"{func}") + + __torch_function__ = torch._C._disabled_torch_function_impl + + +def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scale is the same dtype as the original tensor + scale = torch.clamp(scale, min=eps).to(x.dtype) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x.transpose(0, 1) / scale + x_round = torch.round(x_div) + x_zp = x_round + zero_point + x_zp = x_zp.transpose(0, 1) + quant = torch.clamp(x_zp, quant_min, quant_max) + if target_dtype 
== torch.uint4: + # TODO: simplify (maybe implement to) + quant = PerChannelSymmetricWeightUInt4Tensor.from_unpacked(quant.to(torch.uint8), scale) + else: + quant = quant.to(target_dtype) + + return quant, scale, zero_point + +class PerChannelSymmetricWeightUInt4Tensor(UInt4Tensor): + @staticmethod + def __new__(cls, elem, scales, **kwargs): + return super().__new__(cls, elem, **kwargs) + + def __init__(self, elem, scales, **kwargs): + super().__init__(elem, **kwargs) + self.scales = scales + + + def __tensor_flatten__(self): + return ["elem", "scales"], None + + @staticmethod + def __tensor_unflatten__(flattened, meta, outer_size, outer_stride): + assert meta is None + elem = flattened["elem"] + scales = flattened["scales"] + return PerChannelSymmetricWeightUInt4Tensor(elem, scales) + + @classmethod + def from_unpacked(cls, unpacked, scales): + return cls(pack_uint4(unpacked), scales) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + if func is torch.ops.aten.addmm.default: + bias, x, weight = args + x_view = x.view(-1, x.shape[-1]) + y = torch.mm(x_view, weight.to(torch.uint8).to(x.dtype)) * weight.scales + y = y.reshape(*x.shape[:-1], -1) + if bias is not None: + y += bias + return y + elif func is torch.ops.aten.t.default: + # TODO: add proper support for transpose + self, = args + unpacked = unpack_uint4(self.elem) + transposed = torch.ops.aten.t.default(unpacked) + return PerChannelSymmetricWeightUInt4Tensor.from_unpacked(transposed, self.scales) + elif func is torch.ops.aten.detach.default: + self, = args + return self + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def from_float(cls, w_fp32): + w_int4, scales, _zp = _dynamically_quantize_per_channel_int4( + w_fp32, 0, 15, torch.uint4 + ) + w_int4 = w_int4.to(device=w_fp32.device) + return w_int4
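
For reference, a minimal round-trip sketch (not part of the diff) of the custom ops registered in torchao/dtypes/uint4.py above; the scale of 0.5 and zero point of 8 are illustrative values, not defaults taken from this PR:

import torch
import torchao.dtypes.uint4  # importing the module registers torch.ops.qtensors.*

x = torch.randn(2, 8)
packed = torch.ops.qtensors.quantize_per_tensor_uint4(x, 0.5, 8)
x_dq = torch.ops.qtensors.dequantize_per_tensor_uint4(packed, 0.5, 8)
print(packed.shape, packed.dtype)  # torch.Size([2, 4]) torch.uint8 -- two uint4 values per byte
print((x - x_dq).abs().max())      # small for values inside the representable range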
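
And a minimal weight-only usage sketch mirroring test_gpu_quant in the test file above (the 4-in/16-out Linear and the input shape are illustrative):

import torch
from torch import nn
from torchao.dtypes.uint4 import PerChannelSymmetricWeightUInt4Tensor

m = nn.Sequential(nn.Linear(4, 16))
x = torch.randn(2, 4)
y_ref = m(x)  # float reference

for mod in m.modules():
    if isinstance(mod, nn.Linear):
        # same swap _apply_weight_only_uint4_quant performs in the test
        mod.weight = nn.Parameter(
            PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight),
            requires_grad=False,
        )

y_uint4 = m(x)  # forward now dispatches through the aten.t / aten.addmm overrides above
print((y_ref - y_uint4).abs().max())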