Commit 4531b87

Adding uint4 dtype implementation
Summary: There is a lot of interest in int4 dtypes, and we'd like to add the dtype outside of PyTorch core. This PR adds some preliminary support for uint4 through a tensor subclass, and we'll continue to iterate on it.

Test Plan: python test/dtypes/test_uint4.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 51cf717
Pull Request resolved: #13
1 parent e16898d commit 4531b87
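
The tensor subclass is meant to act as a drop-in replacement for linear weights. Below is a minimal usage sketch based on the _apply_weight_only_uint4_quant helper and the test_gpu_quant test added in this commit; the inlined loop stands in for replace_with_custom_fn_if_matches_filter, so treat it as illustrative rather than the library API.

import torch
from torch import nn
from torchao.dtypes.uint4 import PerChannelSymmetricWeightUInt4Tensor

# Build a small float model, then swap each Linear weight for the uint4
# tensor subclass added in this commit (weight-only quantization).
m = nn.Sequential(nn.Linear(4, 16))
x = torch.randn(2, 4)
y_ref = m(x)

for mod in m.modules():
    if isinstance(mod, nn.Linear):
        mod.weight = nn.Parameter(
            PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight),
            requires_grad=False,
        )

y_uint4 = m(x)  # forward now routes through the uint4 weight subclass

The test below additionally runs the quantized module through torch.compile(m, fullgraph=True, mode="max-autotune") to make sure compilation works.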

File tree: 3 files changed, +528 −0 lines changed


test/dtypes/test_uint4.py

Lines changed: 229 additions & 0 deletions
import torch
from torchao.dtypes.uint4 import (
    UInt4Tensor,
    PerChannelSymmetricWeightUInt4Tensor,
)
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization.observer import ObserverBase
from torch import nn
from torch.fx import (
    Node,
    GraphModule,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
)
import copy

def _apply_weight_only_uint4_quant(model):
    def fn(mod):
        mod.weight = torch.nn.Parameter(PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight), requires_grad=False)
        return mod

    replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: fn(mod),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

class TestUInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x.shape, (3, 16))
        # TODO: make sure this returns torch.uint4
        self.assertIs(x.dtype, torch.uint4)
        # making sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x[0:1, :], expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertEqual(x[:, 2:6], expected)
        torch.save(x, "uint4_tensor.pt")
        x = torch.load("uint4_tensor.pt")
        self.assertEqual(x[:, 2:6], expected)
        # only test locally
        # print("x:", x[0])

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_uint4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, fullgraph=True, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class Uint4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.qtensors.quantize_per_tensor_uint4, (observer_node.args[0], 1.0, 0), {})
                    dq_node = model.graph.call_function(
                        torch.ops.qtensors.dequantize_per_tensor_uint4, (q_node, 1.0, 0), {})
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            _is_annotated,
            _mark_nodes_as_annotated,
        )

        class Int8ActUint4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                uint4_qspec = QuantizationSpec(
                    dtype=torch.uint4,
                    quant_min=0,
                    quant_max=2**4 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Uint4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=uint4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                for n in model.graph.nodes:
                    if n.op != "call_function" or n.target not in [
                        torch.ops.aten.linear.default,
                    ]:
                        continue
                    linear_node = n

                    input_qspec_map = {}
                    input_act = linear_node.args[0]
                    assert isinstance(input_act, Node)
                    input_qspec_map[input_act] = quantization_config.input_activation

                    weight = linear_node.args[1]
                    assert isinstance(weight, Node)
                    input_qspec_map[weight] = quantization_config.weight

                    partition = [linear_node, linear_node.args[1]]

                    bias = linear_node.args[2] if len(linear_node.args) > 2 else None
                    if isinstance(bias, Node):
                        input_qspec_map[bias] = quantization_config.bias
                        partition.append(bias)

                    if _is_annotated(partition):
                        continue

                    linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=quantization_config.output_activation,
                        _annotated=True,
                    )
                    _mark_nodes_as_annotated(partition)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        quantizer = Int8ActUint4WeightQuantizer()
        node_occurrence = {
            # for weight
            torch.ops.qtensors.quantize_per_tensor_uint4: 1,
            torch.ops.qtensors.dequantize_per_tensor_uint4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.qtensors.dequantize_per_tensor_uint4,
            torch.ops.aten.linear.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(2, 4),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=False)
        pt2_quant_output = m(*example_inputs)

        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()
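
The packing logic itself lives in the third changed file (presumably torchao/dtypes/uint4.py, imported above), which is not reproduced on this page. As a rough, hypothetical illustration of the layout test_basic_tensor_ops implies — two 4-bit values per uint8 byte, so a (3, 8) uint8 buffer presents as a (3, 16) uint4 tensor — here is a generic packing/unpacking sketch; the high-nibble-first ordering is an assumption, not taken from the actual implementation.

import torch

def pack_uint4(vals: torch.Tensor) -> torch.Tensor:
    # vals: uint8 tensor holding values 0..15; last dim must be even.
    # Packs two 4-bit values per byte, high nibble first (assumed ordering),
    # e.g. elements [2, 3] become the byte 0x23 as in the slices above.
    assert vals.shape[-1] % 2 == 0
    return ((vals[..., ::2] << 4) | vals[..., 1::2]).to(torch.uint8)

def unpack_uint4(packed: torch.Tensor) -> torch.Tensor:
    # Inverse: expand each byte back into two 4-bit values.
    shape = packed.shape[:-1] + (packed.shape[-1] * 2,)
    out = torch.empty(shape, dtype=torch.uint8)
    out[..., ::2] = (packed >> 4) & 0xF
    out[..., 1::2] = packed & 0xF
    return out

packed = torch.tensor([[0x01, 0x23]], dtype=torch.uint8)
assert unpack_uint4(packed).tolist() == [[0, 1, 2, 3]]
assert torch.equal(pack_uint4(unpack_uint4(packed)), packed)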

torchao/dtypes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
1+
from .uint4 import UInt4Tensor
2+
3+
__all__ = [
4+
"UInt4Tensor"
5+
]
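
With this re-export in place, the subclass can be imported from the package root. A short usage sketch mirroring test_basic_tensor_ops above; the shapes in the comments are assumptions taken from that test.

import torch
from torchao.dtypes import UInt4Tensor

# One row of 8 packed uint8 bytes presents as 16 uint4 elements.
x = UInt4Tensor(torch.tensor([
    [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
], dtype=torch.uint8))
print(x.shape)            # (1, 16), per the test above
x_u8 = x.to(torch.uint8)  # widen back to uint8 for inspection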
