import torch
from torchao.dtypes.uint4 import (
    UInt4Tensor,
    PerChannelSymmetricWeightUInt4Tensor,
)
from unittest import main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization.observer import ObserverBase
from torch import nn
from torch.fx import (
    Node,
    GraphModule,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
)
import copy

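# Replaces the weight of every nn.Linear in `model` with a per-channel,
# symmetrically quantized uint4 tensor subclass; requires_grad=False since
# the packed weight is inference-only.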
def _apply_weight_only_uint4_quant(model):
    def fn(mod):
        mod.weight = torch.nn.Parameter(
            PerChannelSymmetricWeightUInt4Tensor.from_float(mod.weight),
            requires_grad=False,
        )
        return mod

    replace_with_custom_fn_if_matches_filter(
        model,
        fn,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )
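# Minimal usage sketch (shapes are illustrative, mirroring test_gpu_quant below):
#   m = nn.Sequential(nn.Linear(4, 16))
#   _apply_weight_only_uint4_quant(m)  # every Linear weight is now packed uint4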

class TestUInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
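        # each uint8 byte packs two uint4 nibbles, so the (3, 8) uint8 buffer
        # above is viewed as a (3, 16) uint4 tensor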
        self.assertEqual(x.shape, (3, 16))
        # TODO: make sure this returns torch.uint4
        self.assertIs(x.dtype, torch.uint4)
        # make sure this works
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertEqual(x[0:1, :], expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertEqual(x[:, 2:6], expected)
        torch.save(x, "uint4_tensor.pt")
        x = torch.load("uint4_tensor.pt")
        self.assertEqual(x[:, 2:6], expected)
        # only test locally
        # print("x:", x[0])

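    # weight-only uint4 quantization should run both eagerly and under
    # torch.compile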
    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_uint4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, fullgraph=True, mode="max-autotune")
            # make sure it runs
            opt(x)

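    # end-to-end PT2E flow: capture the model, annotate it with a custom
    # quantizer (int8 activations, uint4 weights), insert observers with
    # prepare_pt2e, calibrate, then lower to explicit q/dq ops with convert_pt2e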
    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            QuantizationConfig,
        )

        class Uint4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

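            # called during convert_pt2e: replace the observer node with an
            # explicit quantize/dequantize pair built from the custom uint4 ops
            # (scale/zero_point are hard-coded to 1.0/0 since calculate_qparams
            # is a stub)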
            def convert(self, model: GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.qtensors.quantize_per_tensor_uint4,
                        (observer_node.args[0], 1.0, 0),
                        {},
                    )
                    dq_node = model.graph.call_function(
                        torch.ops.qtensors.dequantize_per_tensor_uint4,
                        (q_node, 1.0, 0),
                        {},
                    )
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            _is_annotated,
            _mark_nodes_as_annotated,
        )

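        # annotates every aten.linear node: int8 qspec on the input activation
        # and output, uint4 qspec on the weight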
        class Int8ActUint4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                uint4_qspec = QuantizationSpec(
                    dtype=torch.uint4,
                    quant_min=0,
                    quant_max=2**4 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Uint4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=uint4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                for n in model.graph.nodes:
                    if n.op != "call_function" or n.target not in [
                        torch.ops.aten.linear.default,
                    ]:
                        continue
                    linear_node = n

                    input_qspec_map = {}
                    input_act = linear_node.args[0]
                    assert isinstance(input_act, Node)
                    input_qspec_map[input_act] = quantization_config.input_activation

                    weight = linear_node.args[1]
                    assert isinstance(weight, Node)
                    input_qspec_map[weight] = quantization_config.weight

                    partition = [linear_node, linear_node.args[1]]

                    bias = linear_node.args[2] if len(linear_node.args) > 2 else None
                    if isinstance(bias, Node):
                        input_qspec_map[bias] = quantization_config.bias
                        partition.append(bias)

                    if _is_annotated(partition):
                        continue

                    linear_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=quantization_config.output_activation,
                        _annotated=True,
                    )
                    _mark_nodes_as_annotated(partition)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        quantizer = Int8ActUint4WeightQuantizer()
        node_occurrence = {
            # for weight
            torch.ops.qtensors.quantize_per_tensor_uint4: 1,
            torch.ops.qtensors.dequantize_per_tensor_uint4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
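        # expected order through the quantized linear: dequantize the int8
        # input activation, dequantize the uint4 weight, run linear, then
        # quantize the output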
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.qtensors.dequantize_per_tensor_uint4,
            torch.ops.aten.linear.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(2, 4),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
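        # fold_quantize=False keeps the quantize ops as explicit graph nodes
        # (rather than folding them into the weights), so they remain visible
        # to the node-occurrence checks below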
        m = convert_pt2e(m, fold_quantize=False)
        pt2_quant_output = m(*example_inputs)

        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )


if __name__ == "__main__":
    main()