import torch
from torchao.dtypes.int4 import UInt4Tensor
import unittest
from unittest import TestCase, main
from torch.ao.quantization import observer
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

from torch._export import capture_pre_autograd_graph
from torch._export import dynamic_dim
from torch.testing._internal.common_quantization import (
    NodeSpec as ns,
    QuantizationTestCase,
)
from torchao.quantization.utils import (
    compute_error,
)
from torchao.quantization.quant_api import (
    replace_with_custom_fn_if_matches_filter,
)
from torch import nn
import copy

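# Helper: symmetric per-channel (axis 0) quantization of a 2D weight tensor into
# [quant_min, quant_max]. Returns (quant, scale, zero_point); when target_dtype is
# "int4" the quantized values are packed into a UInt4Tensor.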
def _dynamically_quantize_per_channel_int4(x, quant_min, quant_max, target_dtype):
    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    # get min and max
    min_val, max_val = torch.aminmax(x, dim=1)

    # calculate scale and zero point based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scale is the same dtype as the original tensor
    scale = torch.clamp(scale, min=eps).to(x.dtype)
    zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scale/zp
    # reference: torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x.transpose(0, 1) / scale
    x_round = torch.round(x_div)
    x_zp = x_round + zero_point
    x_zp = x_zp.transpose(0, 1)
    quant = torch.clamp(x_zp, quant_min, quant_max)
    if target_dtype == "int4":
        quant = UInt4Tensor.from_unpacked(quant.to(torch.uint8)).view(quant.size())
    else:
        quant = quant.to(target_dtype)

    return quant, scale, zero_point

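# nn.Linear variant that stores its weight as a packed UInt4Tensor plus per-channel
# scales and dequantizes on the fly in forward().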
class _WeightOnlyInt4QuantLinear(torch.nn.Linear):
    def __init__(self, *args, **kwargs):
        w_int4 = kwargs.pop("w_int4")
        scales = kwargs.pop("scales")
        super().__init__(*args, **kwargs)
        self.w_int4 = w_int4
        self.scales = scales

    def forward(self, x):
        # if len(x.shape) <= 2:
        #     y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales
        # else:  # turn x into 2d tensor, then undo it for y
        x_view = x.view(-1, x.shape[-1])
        y = torch.mm(x_view, self.w_int4.to(torch.uint8).to(x.dtype)) * self.scales
        y = y.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            y += self.bias
        return y

    @classmethod
    def from_float(cls, mod):
        w_fp32 = mod.weight
        w_int4, scales, _zp = _dynamically_quantize_per_channel_int4(
            w_fp32, 0, 15, "int4"
        )
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features,
            fake_out_features,
            bias=mod.bias is not None,
            w_int4=w_int4.t().contiguous(),
            scales=scales,
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        del new_mod.weight
        new_mod.bias = mod.bias
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

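# Swap every nn.Linear in `model` for _WeightOnlyInt4QuantLinear, in place.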
def _apply_weight_only_int4_quant(model):
    replace_with_custom_fn_if_matches_filter(
        model,
        _WeightOnlyInt4QuantLinear.from_float,
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

class TestInt4(QuantizationTestCase):
    def test_basic_tensor_ops(self):
        x = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        # each uint8 byte packs two 4-bit elements, so the logical shape is (3, 16)
        self.assertEqual(x.shape, (3, 16))
        # making sure these work
        x.to(torch.uint8)
        expected = UInt4Tensor(torch.tensor([
            [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF],
        ], dtype=torch.uint8))
        self.assertTrue(x[0:1, :] == expected)
        expected = UInt4Tensor(torch.tensor([
            [0x23, 0x45],
            [0x23, 0x45],
            [0x23, 0x45],
        ], dtype=torch.uint8))
        self.assertTrue(x[:, 2:6] == expected)

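    # Smoke test: quantize the Linear weight to int4 and check that both the eager
    # and the torch.compile'd forward passes run for a few input shapes.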
    def test_gpu_quant(self):
        for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]:
            x = torch.randn(*x_shape)
            m = nn.Sequential(nn.Linear(4, 16))
            y_ref = m(x)
            _apply_weight_only_int4_quant(m)
            y_wo = m(x)
            # sqnr = compute_error(y_ref, y_wo)
            opt = torch.compile(m, mode="max-autotune")
            # make sure it runs
            opt(x)

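    # Register toy int4 quantize/dequantize ops in a test library, then manually
    # insert a quantize -> dequantize pair into a captured ATen graph.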
    def test_aten_ir(self):
        from torch.library import Library, impl
        test_lib = Library("test_int4", "DEF")
        test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
        def quantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            inv_scale = 1.0 / scale
            return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)

        test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
        def dequantize_per_tensor_int4(
            input: torch.Tensor,
            scale: float,
            zero_point: int,
        ) -> torch.Tensor:
            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale

        # class QuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         inv_scale = 1.0 / scale
        #         return UInt4Tensor(torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8))

        # class DeQuantizePerTensorUInt4(torch.autograd.Function):
        #     @staticmethod
        #     def forward(
        #         ctx,
        #         input: torch.Tensor,
        #         scale: float,
        #         zero_point: int,
        #     ) -> torch.Tensor:
        #         return (input.to(torch.float32) - zero_point) * scale

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        example_inputs = (torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3),)
        m = M().eval()
        m = capture_pre_autograd_graph(m, example_inputs)
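        # Rewrite the captured graph: quantize then dequantize the first input of
        # the aten.add node using the custom test_int4 ops defined above.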
        for n in m.graph.nodes:
            if n.target == torch.ops.aten.add.Tensor:
                with m.graph.inserting_before(n):
                    q = m.graph.call_function(torch.ops.test_int4.quantize_per_tensor_int4, (n.args[0], 1.0, 0), {})
                    dq = m.graph.call_function(torch.ops.test_int4.dequantize_per_tensor_int4, (q, 1.0, 0), {})
                    n.replace_input_with(n.args[0], dq)
        m.recompile()

    # TODO: need more extension points from quant flow side
    @unittest.skip("need more extension points from quant flow side")
    def test_pt2e_quant(self):
        from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
            OP_TO_ANNOTATOR,
            QuantizationConfig,
        )

        class Int4ActQuantizer(Quantizer):
            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
                int4_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-2**3,
                    quant_max=2**3 - 1,
                    qscheme=torch.per_tensor_affine,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_observer,
                )
                int8_qspec = QuantizationSpec(
                    dtype=torch.int8,
                    quant_min=-128,
                    quant_max=127,
                    qscheme=torch.per_tensor_symmetric,
                    is_dynamic=False,
                    observer_or_fake_quant_ctr=observer.default_weight_observer,
                )
                quantization_config = QuantizationConfig(
                    input_activation=int8_qspec,
                    weight=int4_qspec,
                    bias=None,
                    output_activation=int8_qspec,
                )
                OP_TO_ANNOTATOR["conv"](model, quantization_config)

            def validate(self, model: torch.fx.GraphModule) -> None:
                pass

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        quantizer = Int4ActQuantizer()
        expected_node_occurrence = {
            # one for input of the first conv, one for output for the first conv
            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
        }
        expected_node_list = [
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.aten.conv2d.default,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
        ]
        example_inputs = (torch.randn(1, 3, 3, 3),)

        # _test_quantizer in PT2EQuantizationTestCase
        # resetting dynamo cache
        export_with_dynamic_shape = False
        torch._dynamo.reset()
        m_eager = M().eval()

        # program capture
        m = copy.deepcopy(m_eager)
        m = capture_pre_autograd_graph(
            m,
            example_inputs,
            constraints=[dynamic_dim(example_inputs[0], 0)] if export_with_dynamic_shape else [],
        )

        m = prepare_pt2e(m, quantizer)
        # Calibrate
        m(*example_inputs)
        m = convert_pt2e(m, fold_quantize=True)

        pt2_quant_output = m(*example_inputs)
        node_occurrence = {
            ns.call_function(k): v for k, v in expected_node_occurrence.items()
        }
        if expected_node_list is None:
            expected_node_list = []
        node_list = [ns.call_function(n) for n in expected_node_list]
        self.checkGraphModuleNodes(
            m, expected_node_occurrence=node_occurrence, expected_node_list=node_list
        )

| 291 | +if __name__ == "__main__": |
| 292 | + main() |