
Commit 50859a0

Adding uint4 dtype implementation
Summary: We have a lot of interest in int4 dtypes, and we'd like to add the dtype outside of PyTorch core. This PR adds some preliminary support for uint4 through a tensor subclass, and we'll continue to iterate on this.

Test Plan: python test/dtypes/test_int4.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 05f1c94
Pull Request resolved: #13
1 parent e16898d commit 50859a0
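
For orientation: the subclass stores two 4-bit values per uint8 byte (the tests below construct a UInt4Tensor from a (3, 8) uint8 tensor and get shape (3, 16)). Below is a minimal, hypothetical sketch of that packing idea in plain PyTorch; the actual nibble order and helper names used by UInt4Tensor in torchao/dtypes/int4.py may differ.

import torch

def pack_uint4(x: torch.Tensor) -> torch.Tensor:
    # x holds values 0..15 as uint8; pack each adjacent pair into one byte
    # (high/low nibble order here is an assumption, not taken from this PR)
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    return (x[..., ::2] << 4 | x[..., 1::2]).to(torch.uint8)

def unpack_uint4(x: torch.Tensor) -> torch.Tensor:
    # inverse of pack_uint4: split each byte back into its two nibbles
    return torch.stack([(x >> 4) & 0xF, x & 0xF], dim=-1).reshape(*x.shape[:-1], -1)

vals = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_uint4(pack_uint4(vals)), vals)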

File tree

3 files changed: +491 -0 lines changed


test/dtypes/test_int4.py

Lines changed: 322 additions & 0 deletions
@@ -0,0 +1,322 @@
import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.fx import Node
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization import observer
from torch.ao.quantization.observer import ObserverBase
from torch import nn
import copy

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
    if target_dtype == "int4":
        quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # if len(x.shape)<=2:
        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
        # else: # turn x into 2d tensor, then undo it for y
        x_view = x.view(-1, x.shape[-1])
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            # w_int4=w_int4.t().contiguous(),
            w_int4=torch.ops.aten.transpose_copy(w_int4, 0, 1),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def _apply_weight_only_int4_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

from torch.library import Library, impl

test_lib = Library("test_int4", "DEF")
test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
def quantize_per_tensor_int4(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
) -> torch.Tensor:
    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
def dequantize_per_tensor_int4(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
) -> torch.Tensor:
    return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale


class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x.shape, (3, 16))
        # make sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertTrue(x[:, 2:6] == expected)

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_aten_ir(self):
        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()

    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class int4_class():
            pass

        torch.int4 = int4_class()

        class Int4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: torch.fx.GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {})
                    dq_node = model.graph.call_function(
                        torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {})
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        class Int4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int4,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Int4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                OP_TO_ANNOTATOR["conv"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4WeightQuantizer()
        expected_node_occurrence = {
            # for weight
            torch.ops.test_int4.quantize_per_tensor_int4: 1,
            torch.ops.test_int4.dequantize_per_tensor_int4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
        expected_node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        if expected_node_list is None:
            expected_node_list = []
        node_list = [ns.call_function(n) for n in expected_node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )


if __name__ == "__main__":
    main()

torchao/dtypes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .int4 import UInt4Tensor

__all__ = [
    "UInt4Tensor"
]
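
With this __init__.py, UInt4Tensor becomes importable from the package root. A minimal usage sketch, based only on what test_basic_tensor_ops above exercises; the exact layout returned by .to(torch.uint8) is an assumption, not something this diff documents.

import torch
from torchao.dtypes import UInt4Tensor

# one row of 8 packed uint8 bytes -> 16 logical uint4 elements
x = UInt4Tensor(torch.tensor(
    [[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF]], dtype=torch.uint8
))
print(x.shape)         # torch.Size([1, 16]), as asserted in the test
y = x.to(torch.uint8)  # conversion back to uint8, also exercised by the test
z = x[:, 2:6]          # slicing, compared against an expected UInt4Tensor in the test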
