
Commit db8b979

Adding uint4 dtype implementation
Summary: There is a lot of interest in int4 dtypes, and we'd like to add them outside of PyTorch core. This PR adds preliminary support for uint4 through a tensor subclass; we'll continue to iterate on it.

Test Plan: python test/dtypes/test_int4.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 9a07ab7
Pull Request resolved: #13
1 parent e16898d commit db8b979
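
For context, here is a minimal sketch of how the new tensor subclass is exercised in test/dtypes/test_int4.py. It assumes the UInt4Tensor class this commit adds in torchao/dtypes/int4.py (the implementation file itself is not reproduced in this excerpt): each torch.bits8 byte packs two 4-bit values, so a (3, 8) byte tensor behaves as a (3, 16) uint4 tensor.

import torch
from torchao.dtypes.int4 import UInt4Tensor

# three rows of eight packed bytes; each byte holds two uint4 values,
# so the logical shape is (3, 16)
x = UInt4Tensor(torch.tensor([
    [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
    [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
    [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.bits8))
assert x.shape == (3, 16)

# unpack to one uint8 value per element; slicing works like a regular tensor
unpacked = x.to(torch.uint8)
first_row = x[0:1, :]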

File tree

3 files changed, +542 -0 lines changed


test/dtypes/test_int4.py

Lines changed: 374 additions & 0 deletions
@@ -0,0 +1,374 @@
import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization.observer import ObserverBase
from torch import nn
from torch.fx import (
    Node,
    GraphModule,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
)
import copy

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
    if target_dtype == "int4":
        quant = UInt4Tensor.from_unpacked(quant.view(torch.bits8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # if len(x.shape)<=2:
        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
        # else: # turn x into 2d tensor, then undo it for y
        x_view = x.view(-1, x.shape[-1])
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            # w_int4=w_int4.t().contiguous(),
            w_int4=torch.ops.aten.transpose_copy(w_int4, 0, 1),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def _apply_weight_only_int4_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

from torch.library import Library, impl

test_lib = Library("test_int4", "DEF")
test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
def quantize_per_tensor_int4(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
) -> torch.Tensor:
    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
def dequantize_per_tensor_int4(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
) -> torch.Tensor:
    # unpack the packed uint4 values to uint8, then dequantize in float32
    return (input.to(torch.uint8).to(torch.float32) - zero_point) * scale


class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.bits8))
        self.assertEqual(x.shape, (3, 16))
        # make sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.bits8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.bits8))
        self.assertTrue(x[:, 2:6] == expected)

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_aten_ir(self):
        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.bits8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()

    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )
        class int4_class():
            pass

        torch.int4 = int4_class()

        class Int4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {})
                    dq_node = model.graph.call_function(
                        torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {})
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            _is_annotated,
            _mark_nodes_as_annotated,
        )

        class Int4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int4,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Int4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                for n in model.graph.nodes:
                    if n.op != "call_function" or n.target not in [
                        torch.ops.aten.conv1d.default,
                        torch.ops.aten.conv2d.default,
                    ]:
                        continue
                    conv_node = n

                    input_qspec_map = {}
                    input_act = conv_node.args[0]
                    assert isinstance(input_act, Node)
                    input_qspec_map[input_act] = quantization_config.input_activation

                    weight = conv_node.args[1]
                    assert isinstance(weight, Node)
                    input_qspec_map[weight] = quantization_config.weight

                    partition = [conv_node, conv_node.args[1]]

                    bias = conv_node.args[2] if len(conv_node.args) > 2 else None
                    if isinstance(bias, Node):
                        input_qspec_map[bias] = quantization_config.bias
                        partition.append(bias)

                    if _is_annotated(partition):
                        continue

                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=quantization_config.output_activation,
                        _annotated=True,
                    )
                    _mark_nodes_as_annotated(partition)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4WeightQuantizer()
        node_occurrence = {
            # for weight
            torch.ops.test_int4.quantize_per_tensor_int4: 1,
            torch.ops.test_int4.dequantize_per_tensor_int4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=False)

        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()

torchao/dtypes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .int4 import UInt4Tensor

__all__ = [
    "UInt4Tensor"
]
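
With this __init__.py change, the subclass is re-exported at the package level, so downstream code should be able to import it as shown in this one-line sketch:

from torchao.dtypes import UInt4Tensor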
