Adding uint4 dtype implementation #10

Closed
wants to merge 1 commit into from
195 changes: 195 additions & 0 deletions test/dtypes/test_int4.py
@@ -0,0 +1,195 @@
import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
)
import copy


class TestInt4(QuantizationTestCase):
def test_basic_tensor_ops(self):
x = UInt4Tensor(torch.tensor([
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.uint8))
        self.assertEqual(x.shape, (3, 8))
        # make sure conversion to uint8 works
        x.to(torch.uint8)
expected = UInt4Tensor(torch.tensor([
[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.uint8))
self.assertTrue(x[0:1, :] == expected)
expected = UInt4Tensor(torch.tensor([
[0x23, 0x45],
[0x23, 0x45],
[0x23, 0x45],
], dtype=torch.uint8))
self.assertTrue(x[:, 2:6] == expected)
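        # hedged sanity check (not in the original test): each uint8 byte packs
        # two uint4 nibbles, so the first row unpacks to the values 0 through 15
        self.assertEqual(x[0:1, :].tolist(), [list(range(16))])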

def test_gpu_quant(self):
pass

def test_aten_ir(self):
# from torch.library import Library, impl
# test_lib = Library("test_int4", "DEF")
# test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
# @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
# def quantize_per_tensor_int4(
# input: torch.Tensor,
# scale: float,
# zero_point: int,
# ) -> torch.Tensor:
# inv_scale = 1.0 / scale
# return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

# test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
# @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
# def dequantize_per_tensor_int4(
# input: torch.Tensor,
# scale: float,
# zero_point: int,
# ) -> torch.Tensor:
# return (input.to(torch.float32) - zero_point) * scale

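        # forward-only autograd.Functions standing in for the custom ops sketched
        # above; .apply gives a callable that can be used as a graph call target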
class QuantizePerTensorUInt4(torch.autograd.Function):
@staticmethod
def forward(
ctx,
input: torch.Tensor,
scale: float,
zero_point: int,
) -> torch.Tensor:
inv_scale = 1.0 / scale
return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

class DeQuantizePerTensorUInt4(torch.autograd.Function):
@staticmethod
def forward(
ctx,
input: torch.Tensor,
scale: float,
zero_point: int,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale

class M(torch.nn.Module):
def forward(self, x, y):
return x + y

example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
m = M().eval()
m = capture_pre_autograd_graph(m, example_inputs)
qop = QuantizePerTensorUInt4.apply
dqop = DeQuantizePerTensorUInt4.apply
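        # rewrite the captured graph in place: before each aten.add, insert a
        # quantize -> dequantize pair on its first input so the add consumes
        # values that have round-tripped through uint4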
for n in m.graph.nodes:
if n.target == torch.ops.aten.add.Tensor:
with m.graph.inserting_before(n):
# q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
# dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
q = m.graph.call_function(qop, (n.args[0], 1.0, 0), {})
dq = m.graph.call_function(dqop, (q, 1.0, 0), {})
n.replace_input_with(n.args[0], dq)
m.recompile()
print("m:", m)
print(m(*example_inputs))

# TODO: need more extension points from quant flow side
@unittest.skip("need more extension points from quant flow side")
def test_pt2e_quant(self):
        from torch.ao.quantization import observer
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

class Int4ActQuantizer(Quantizer):
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
int4_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-2**3,
quant_max=2**3 - 1,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer,
)
int8_qspec = QuantizationSpec(
dtype=torch.int8,
quant_min=-128,
quant_max=127,
qscheme=torch.per_tensor_symmetric,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_weight_observer,
)
quantization_config = QuantizationConfig(
input_activation=int8_qspec,
weight=int4_qspec,
bias=None,
output_activation=int8_qspec,
)
OP_TO_ANNOTATOR["conv"](model, quantization_config)

def validate(self, model: torch.fx.GraphModule) -> None:
pass

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)

def forward(self, x):
return self.conv(x)

quantizer = Int4ActQuantizer()
        expected_node_occurrence = {
            # one quantize for the input of the first conv, one for its output
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        expected_node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
example_inputs = (torch.randn(1, 3, 3, 3),)

# _test_quantizer in PT2EQuantizationTestCase
# resetting dynamo cache
export_with_dynamic_shape = False
torch._dynamo.reset()
m_eager = M().eval()

# program capture
m = copy.deepcopy(m_eager)
m = capture_pre_autograd_graph(
m,
example_inputs,
constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [],
)

m = prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=True)

pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in expected_node_list]
self.checkGraphModuleNodes(
m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
)

if __name__ == "__main__":
main()
5 changes: 5 additions & 0 deletions torchao/dtypes/__init__.py
@@ -0,0 +1,5 @@
from .int4 import UInt4Tensor

__all__ = [
"UInt4Tensor"
]
108 changes: 108 additions & 0 deletions torchao/dtypes/int4.py
@@ -0,0 +1,108 @@
import torch
import torch._prims_common as utils

def down_size(size):
assert size[-1] % 2 == 0, f"{size} last dim not divisible by two"
return (*size[:-1], size[-1] // 2)

def up_size(size):
return (*size[:-1], size[-1] * 2)
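# e.g. down_size((3, 16)) == (3, 8): two uint4 values are packed per uint8 byte,
# and up_size((3, 8)) == (3, 16) recovers the logical uint4 shape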

def fill_defaults(args, n, defaults_tail):
"""
__torch_dispatch__ doesn't guarantee the number of arguments you are
passed (e.g., defaulted arguments are not passed); but usually it is
convenient to pad out the arguments list with defaults. This function
helps you do that.
Args:
args: the list of positional arguments passed to __torch_dispatch__
n: the number of arguments you are expecting to get
defaults_tail: default values for the arguments, starting from the
end of the list
Example:
>>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
[1, 2, 3, 4, 5]
>>> fill_defaults([1, 2, 3], 5, [None, None, None])
    [1, 2, 3, None, None]
"""
if n - len(defaults_tail) > len(args):
raise RuntimeError("not enough defaults to fill arguments")
r = list(args)
for i in range(len(args), n):
r.append(defaults_tail[i - n + len(defaults_tail)])
return r

# from
# https://github.com/drisspg/transformer_nuggets/blob/9ad3a7fc552a954eb702ade0e276b8d8e09c3db6/transformer_nuggets/quant/qlora.py#L233
def unpack_uint4(quantized_data) -> torch.Tensor:
"""Get the original weight from the normalized float weight format"""
# since we are using uint8 we will decode 2 entries per byte
# Shift elements down 4 and select out the bottom 4 bits
first_elements = (quantized_data >> 4).to(torch.uint8)
second_elements = (quantized_data & 0b1111).to(torch.uint8)
return torch.stack([first_elements, second_elements], dim=-1)
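# A hypothetical inverse helper, sketched here for illustration only (it is not
# part of this diff and nothing below calls it): pack pairs of uint4 values,
# stored one-per-uint8, back into single bytes, matching unpack_uint4's layout
# (even index -> high nibble, odd index -> low nibble).
def pack_uint4(unpacked_data: torch.Tensor) -> torch.Tensor:
    assert unpacked_data.dtype is torch.uint8
    assert unpacked_data.shape[-1] % 2 == 0, "last dim must be even to pack pairs"
    pairs = unpacked_data.contiguous().view(*unpacked_data.shape[:-1], -1, 2)
    return ((pairs[..., 0] << 4) | (pairs[..., 1] & 0b1111)).to(torch.uint8)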

class UInt4Tensor(torch.Tensor):
@staticmethod
def __new__(cls, elem):
        # TODO: int64 here is a stand-in until there is a real uint4 dtype; don't
        # call .to(torch.int64) on this tensor, weird things will happen
assert elem.dtype is torch.uint8
return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.int64)

def __init__(self, elem):
self.elem = elem

def tolist(self):
return self.to(torch.uint8).tolist()

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
if func is torch.ops.aten.view.default:
self, size = args
size = utils.infer_size(size, self.numel())
assert not kwargs
# WARNING: views not preserved
return UInt4Tensor(self.elem.reshape(down_size(size)))
elif func is torch.ops.aten._to_copy.default:
self, = args
if kwargs == {'dtype': torch.uint8}:
return unpack_uint4(self.elem).view(self.shape) # no wrap
else:
raise NotImplementedError(f"_to_copy {kwargs}")
elif func is torch.ops.aten.unbind.int:
# This is tricky. Given torch.tensor([0, 1, 2, 3]) we want to
# create four tensors containing one element each. But we can't
# do this with uint4 because such a tensor's size is not divisible
# by bytes. What I am going to do instead is promote to uint8
# when this happens
self, dim = fill_defaults(args, 2, [0])
if dim != self.dim() - 1:
raise NotImplementedError(f"unbind dim={dim}")
else:
# We're unbinding the last dimension, need to promote
return torch.ops.aten._to_copy.default(self, dtype=torch.uint8).unbind(dim)
elif func is torch.ops.aten.select.int:
self, dim, index = args
if dim != self.dim() - 1:
return UInt4Tensor(torch.ops.aten.select.int(self.elem, dim, index))
else:
raise NotImplementedError(f"select dim={dim}")
elif func is torch.ops.aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == self.dim() - 1:
# hard case
if step != 1:
raise NotImplementedError(f"slice step={step}")
assert start % 2 == 0, start
assert end >= self.shape[dim] or end % 2 == 0, end
return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start // 2, end // 2, 1))
else:
# easy case
return UInt4Tensor(torch.ops.aten.slice.Tensor(self.elem, dim, start, end, step))
raise NotImplementedError(f"{func}")

def __eq__(self, other):
return torch.equal(self.elem, other.elem)

__torch_function__ = torch._C._disabled_torch_function_impl
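
# Minimal usage sketch (illustrative only, assuming just the ops handled above):
#
#     packed = torch.tensor([[0x12, 0x34]], dtype=torch.uint8)
#     t = UInt4Tensor(packed)   # logical shape (1, 4) holding 1, 2, 3, 4
#     t.to(torch.uint8)         # unpacks to a plain uint8 tensor of shape (1, 4)
#     t[:, 0:2]                 # slicing the packed last dim stays a UInt4Tensor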