
Commit 0c72951

Adding uint4 dtype implementation

Summary:
There is a lot of interest in int4 dtypes, and we'd like to develop this dtype outside of PyTorch core. This PR adds preliminary support for uint4 through a tensor subclass; we'll continue to iterate on it.

Test Plan:
python test/dtypes/test_int4.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: aa80ed3
Pull Request resolved: #13
1 parent e16898d commit 0c72951
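The third changed file, torchao/dtypes/int4.py (the module both files shown below import UInt4Tensor from), is not expanded in this view. As rough orientation only, and not the actual torchao implementation, a "uint4 via tensor subclass" design typically packs two 4-bit values into each uint8 byte of storage. A minimal sketch, with illustrative names such as pack_uint4, unpack_uint4, and PackedUInt4:

import torch

def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    # pack pairs of 4-bit values (high nibble first) into single uint8 bytes
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    return ((x[..., ::2] & 0xF) << 4) | (x[..., 1::2] & 0xF)

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    # expand each uint8 byte back into two values in [0, 15]
    high = (packed >> 4) & 0xF
    low = packed & 0xF
    return torch.stack([high, low], dim=-1).flatten(-2)

class PackedUInt4(torch.Tensor):
    # minimal wrapper subclass holding packed uint8 storage; a real
    # implementation would also route ops (view, slicing, .to, mm, ...)
    # through __torch_dispatch__ so the tensor behaves like 4-bit data
    @staticmethod
    def __new__(cls, packed: torch.Tensor):
        shape = packed.shape[:-1] + (packed.shape[-1] * 2,)
        return torch.Tensor._make_wrapper_subclass(cls, shape, dtype=torch.uint8)

    def __init__(self, packed: torch.Tensor):
        self.packed = packed

    @classmethod
    def from_unpacked(cls, unpacked: torch.Tensor):
        return cls(pack_uint4(unpacked))

    def to_unpacked(self) -> torch.Tensor:
        return unpack_uint4(self.packed)

The tests in test/dtypes/test_int4.py below exercise the real UInt4Tensor through the same kind of round trip: from_unpacked, .to(torch.uint8), and slicing.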

File tree

3 files changed: +452 -0 lines changed


test/dtypes/test_int4.py

Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch import nn
import copy

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
    if target_dtype == "int4":
        quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # if len(x.shape)<=2:
        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
        # else: # turn x into 2d tensor, then undo it for y
        x_view = x.view(-1, x.shape[-1])
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            w_int4=w_int4.t().contiguous(),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def _apply_weight_only_int4_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x.shape, (3, 8))
        # making sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertTrue(x[:, 2:6] == expected)

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_aten_ir(self):
        from torch.library import Library, impl
        test_lib = Library("test_int4", "DEF")
        test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
        def quantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            inv_scale = 1.0 / scale
            return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

        test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
        def dequantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale

        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()

    # TODO: need more extension points from quant flow side
    @unittest.skip("need more extension points from quant flow side")
    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )
        from torch.ao.quantization import observer

        class Int4ActQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                OP_TO_ANNOTATOR["conv"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4ActQuantizer()
        node_occurrence = {
            # one for input of the first conv, one for output for the first conv
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
            constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [],
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        pt2_quant_output = m(*example_inputs)
        expected_node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        expected_node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=expected_node_occurrence, expected_node_list=expected_node_list
        )

if __name__ == "__main__":
    main()
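As a quick illustration of the per-channel arithmetic used by _dynamically_quantize_per_channel_int4 above (symmetric scale, zero_point fixed at 0, values clamped to the unsigned 4-bit range), here is a standalone sketch with made-up numbers; it is not part of the diff:

import torch

# illustrative only: per-channel scale/quantize math mirroring the helper above
# (quant_min=0, quant_max=15, zero_point=0, one output channel per row)
w = torch.tensor([[0.0, 1.5, 3.0, 7.5],
                  [0.5, 0.25, 1.0, 2.0]])
max_abs = w.abs().amax(dim=1, keepdim=True)     # per-row max magnitude
scale = max_abs / ((15 - 0) / 2)                # e.g. 7.5 / 7.5 = 1.0 for row 0
q = torch.clamp(torch.round(w / scale), 0, 15)  # 4-bit codes in [0, 15]
dq = q * scale                                  # dequantized approximation of w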

torchao/dtypes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .int4 import UInt4Tensor

__all__ = [
    "UInt4Tensor"
]
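With this export in place, the subclass is importable straight from the package. A usage sketch based on test_basic_tensor_ops above, assuming torchao is installed from this commit:

import torch
from torchao.dtypes import UInt4Tensor

# each uint8 byte packs two 4-bit values, so this row holds 16 uint4 elements
packed = torch.tensor(
    [[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]], dtype=torch.uint8
)
x = UInt4Tensor(packed)
x.to(torch.uint8)   # unpacked uint8 view with values in [0, 15]
x[:, 2:6]           # slicing is expressed in (unpacked) element positions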
