import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch.ao.quantization.observer import ObserverBase
from torch import nn
from torch.fx import (
    Node,
    GraphModule,
)
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
)
import copy

def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
    if target_dtype == "int4":
        # convert the clamped float values to uint8 before reinterpreting them
        # as bits8 for packing into a UInt4Tensor
        quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8).view(torch.bits8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

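# Sketch (illustrative only, not exercised by the tests below): round-trip a
# weight through the per-channel int4 quantizer above and measure the SQNR.
# The dequant step mirrors what _WeightOnlyInt4QuantLinear.forward does below
# (uint8 view of the packed values times the per-output-channel scale).
def _example_int4_roundtrip():
    w = torch.randn(8, 4)
    w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(w, 0, 15, "int4")
    # dequantize: per-row (axis 0) scales broadcast over the columns
    w_dq = w_int4.to(torch.uint8).to(torch.float32) * scales.unsqueeze(-1)
    return compute_error(w, w_dq)  # SQNR between original and round-tripped weight
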
class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # flatten x to 2d for the matmul, then restore the original leading dims
        x_view = x.view(-1, x.shape[-1])
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            # store the weight transposed (in_features, out_features) so forward
            # can call torch.mm(x, w_int4) without a runtime transpose
            w_int4=torch.ops.aten.transpose_copy(w_int4, 0, 1),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

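# Sketch (illustrative only, not exercised by the tests): convert a single
# nn.Linear by hand with from_float; _apply_weight_only_int4_quant below does
# this for every Linear submodule of a model.
def _example_convert_single_linear():
    lin = torch.nn.Linear(4, 16)
    lin_int4 = _WeightOnlyInt4QuantLinear.from_float(lin)
    # the converted module keeps the original in/out features and bias
    return lin_int4(torch.randn(2, 4))
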
def _apply_weight_only_int4_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

from torch.library import Library, impl

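# Register toy "test_int4" quantize/dequantize custom ops so they can appear as
# call_function nodes in the exported graphs used by the tests below.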
test_lib = Library("test_int4", "DEF")
test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

@impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
def quantize_per_tensor_int4(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
) -> torch.Tensor:
    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")

@impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
def dequantize_per_tensor_int4(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
) -> torch.Tensor:
    return (input.to(torch.uint8).to(torch.float32) - zero_point) * scale

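# Sketch (illustrative only): the two custom ops above compose into an int4
# fake-quant round trip; the default scale here is arbitrary.
def _example_fake_quant_int4(x: torch.Tensor, scale: float = 0.5) -> torch.Tensor:
    q = torch.ops.test_int4.quantize_per_tensor_int4(x, scale, 0)
    return torch.ops.test_int4.dequantize_per_tensor_int4(q, scale, 0)
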

class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.bits8))
        self.assertEqual(x.shape, (3, 16))
        # make sure this works
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.bits8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.bits8))
        self.assertTrue(x[:, 2:6] == expected)

    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

    def test_aten_ir(self):
        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.bits8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
        # rewrite the graph: wrap the first input of add in quantize/dequantize
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()

    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        # placeholder object standing in for a real torch.int4 dtype
        class int4_class():
            pass

        torch.int4 = int4_class()

        class Int4Observer(ObserverBase):
            def __init__(self, *args, **kwargs):
                # just faking a dtype here
                # TODO: make flow work with new dtypes
                super().__init__(dtype=torch.int8)

            def forward(self, x):
                return x

            def calculate_qparams(self, **kwargs):
                pass

            def convert(self, model: GraphModule, observer_node: Node):
                with model.graph.inserting_before(observer_node):
                    q_node = model.graph.call_function(
                        torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {})
                    dq_node = model.graph.call_function(
                        torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {})
                    observer_node.replace_all_uses_with(dq_node)
                    model.graph.erase_node(observer_node)

        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            _is_annotated,
            _mark_nodes_as_annotated,
        )

        class Int4WeightQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int4,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=Int4Observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=torch.ao.quantization.observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                for n in model.graph.nodes:
                    if n.op != "call_function" or n.target not in [
                        torch.ops.aten.conv1d.default,
                        torch.ops.aten.conv2d.default,
                    ]:
                        continue
                    conv_node = n

                    input_qspec_map = {}
                    input_act = conv_node.args[0]
                    assert isinstance(input_act, Node)
                    input_qspec_map[input_act] = quantization_config.input_activation

                    weight = conv_node.args[1]
                    assert isinstance(weight, Node)
                    input_qspec_map[weight] = quantization_config.weight

                    partition = [conv_node, conv_node.args[1]]

                    bias = conv_node.args[2] if len(conv_node.args) > 2 else None
                    if isinstance(bias, Node):
                        input_qspec_map[bias] = quantization_config.bias
                        partition.append(bias)

                    if _is_annotated(partition):
                        continue

                    conv_node.meta["quantization_annotation"] = QuantizationAnnotation(
                        input_qspec_map=input_qspec_map,
                        output_qspec=quantization_config.output_activation,
                        _annotated=True,
                    )
                    _mark_nodes_as_annotated(partition)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4WeightQuantizer()
        node_occurrence = {
            # for weight
            torch.ops.test_int4.quantize_per_tensor_int4: 1,
            torch.ops.test_int4.dequantize_per_tensor_int4: 1,
            # for activation
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2,
        }
        node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.test_int4.dequantize_per_tensor_int4,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=False)

        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in node_occurrence.items()
        }
        node_list = [ns.call_function(n) for n in node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

if __name__ == "__main__":
    main()