diff --git a/ao/quantization/__init__.py b/ao/quantization/__init__.py new file mode 100644 index 0000000000..a18cbaa43a --- /dev/null +++ b/ao/quantization/__init__.py @@ -0,0 +1,39 @@ +from smoothquant import * # noqa: F403 +from quant_api import * # noqa: F403 +from subclass import * # noqa: F403 +from quant_primitives import * # noqa: F403 +from utils import * # noqa: F403 +from weight_only import * # noqa: F403 + +__all__ = [ + "DynamicallyPerAxisQuantizedLinear", + "replace_with_custom_fn_if_matches_filter", + "apply_weight_only_int8_quant", + "apply_dynamic_quant", + "change_linear_weights_to_dqtensors", + "insert_subclass", + "safe_int_mm", + "dynamically_quantize_per_tensor", + "quantize_activation_per_token_absmax", + "dynamically_quantize_per_channel", + "dequantize_per_tensor", + "dequantize_per_channel", + "quant_int8_dynamic_linear", + "quant_int8_matmul", + "quant_int8_dynamic_per_token_linear", + "quant_int8_per_token_matmul", + "get_scale", + "SmoothFakeDynQuantMixin", + "SmoothFakeDynamicallyQuantizedLinear", + "swap_linear_with_smooth_fq_linear", + "smooth_fq_linear_to_inference", + "set_smooth_fq_attribute", + "DynamicallyQuantizedLinearWeight", + "log_with_rank", + "clear_logs", + "compute_error", + "forward_hook", + "apply_logging_hook", + "get_model_size_in_bytes", + "WeightOnlyInt8QuantLinear", +] diff --git a/ao/quantization/dynamic_quant.py b/ao/quantization/dynamic_quant.py new file mode 100644 index 0000000000..dd28a0afa9 --- /dev/null +++ b/ao/quantization/dynamic_quant.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn +from quant_primitives import ( + dynamically_quantize_per_channel, +
quant_int8_dynamic_per_token_linear, +) + +__all__ = ["DynamicallyPerAxisQuantizedLinear"] + + +class DynamicallyPerAxisQuantizedLinear(torch.nn.Linear): + """ + This class is a replacement for `torch.nn.Linear`, implementing dynamic quantization on + the input across all axes except for the last axis. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + use_fused_int_mm=False, + ) -> None: + super().__init__(in_features, out_features, bias) + self.use_fused_int_mm = use_fused_int_mm + # note: enabling use_fused_int_mm = True has best perf when additionally setting + # torch._inductor.config.force_fuse_int_mm_with_mul = True + + def forward(self, X: torch.Tensor) -> torch.Tensor: + """ + Performs the forward pass of the quantized linear layer. + + This method applies dynamic quantization to the input tensor across all axes except + the last axis using the `quant_int8_dynamic_per_token_linear` function. + + Args: + X (torch.Tensor): The input tensor to the quantized linear layer. + + Returns: + torch.Tensor: The output tensor after the quantized matmul and rescale. + + """ + # The following line mimics the behavior of SmoothFakeDynamicallyQuantizedLinear + if not self.use_fused_int_mm: + X = X / self.fake_rescale + # somehow the inductor fusion that occurs for most transformer models + # when this module has an additional div op is faster than when it doesn't + # have it although the memory usage is slightly higher. fake_rescale is scalar 1 + # so it doesn't affect accuracy + Y = quant_int8_dynamic_per_token_linear( + X, self.W_int_repr_t, self.W_scales, self.bias, X.dtype + ) + return Y + + @classmethod + def from_float( + cls, mod: torch.nn.Linear, use_fused_int_mm=False + ) -> "DynamicallyPerAxisQuantizedLinear": + """ + Converts a `mod` of class `torch.nn.Linear` to the dynamically quantized version of it. + + Note: this class does not require calibration. + + Args: + mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert. + + Returns: + DynamicallyPerAxisQuantizedLinear: The converted quantized linear module. + + """ + + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, + fake_out_features, + bias=mod.bias is not None, + use_fused_int_mm=use_fused_int_mm, + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel( + mod.weight, -128, 127, torch.int8 + ) + new_mod.register_buffer("W_int_repr_t", W_int_repr.contiguous().t()) + new_mod.W_scales = nn.Parameter(W_scales) + new_mod.bias = mod.bias + if not use_fused_int_mm: + new_mod.fake_rescale = torch.tensor( + [1.0], dtype=mod.weight.dtype, device=mod.weight.device + ) + del new_mod.weight + + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod diff --git a/ao/quantization/quant_api.py b/ao/quantization/quant_api.py new file mode 100644 index 0000000000..20b09b5cb3 --- /dev/null +++ b/ao/quantization/quant_api.py @@ -0,0 +1,74 @@ +""" +Quantization API stuff which is not specific to SmoothQuant + +Note: this is throwaway code for fast results on Blueberry, this is not +intended to be the actual long term quantization API for server GPUs. 
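+
+Example (an illustrative sketch, not a stable API; it only uses the helpers
+defined below in this file and assumes a small float model):
+
+    import torch.nn as nn
+
+    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 8))
+    # swap every nn.Linear for DynamicallyPerAxisQuantizedLinear
+    apply_dynamic_quant(model)
+    # or keep nn.Linear and only replace the weight with a quantized tensor
+    # subclass:
+    # change_linear_weights_to_dqtensors(model)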
+""" + +import torch +from dynamic_quant import ( + DynamicallyPerAxisQuantizedLinear, +) +from subclass import ( + DynamicallyQuantizedLinearWeight, +) +from weight_only import ( + WeightOnlyInt8QuantLinear, +) + +__all__ = [ + "replace_with_custom_fn_if_matches_filter", + "apply_weight_only_int8_quant", + "apply_dynamic_quant", + "change_linear_weights_to_dqtensors", +] + + +def replace_with_custom_fn_if_matches_filter( + model, replacement_fn, filter_fn, cur_fqn="" +) -> None: + """ + For each `child` in `model`, replaces it with `replacement_fn(child)` + if `filter_fn(child)` is `True` + """ + name_to_child = dict(model.named_children()) + for name, child in name_to_child.items(): + if cur_fqn == "": + new_fqn = name + else: + new_fqn = f"{cur_fqn}.{name}" + if filter_fn(child, new_fqn): + new_child = replacement_fn(child) + setattr(model, name, new_child) + else: + replace_with_custom_fn_if_matches_filter( + child, replacement_fn, filter_fn, new_fqn + ) + + +def apply_weight_only_int8_quant(model): + replace_with_custom_fn_if_matches_filter( + model, + WeightOnlyInt8QuantLinear.from_float, + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + + +def apply_dynamic_quant(model, use_fused_int_mm=0): + replace_with_custom_fn_if_matches_filter( + model, + lambda mod: DynamicallyPerAxisQuantizedLinear.from_float(mod, use_fused_int_mm), + lambda mod, fqn: isinstance(mod, torch.nn.Linear), + ) + + +def change_linear_weights_to_dqtensors(model): + def insert_subclass(lin): + lin.weight = torch.nn.Parameter( + DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False + ) + return lin + + replace_with_custom_fn_if_matches_filter( + model, insert_subclass, lambda mod, fqn: isinstance(mod, torch.nn.Linear) + ) diff --git a/ao/quantization/quant_primitives.py b/ao/quantization/quant_primitives.py new file mode 100644 index 0000000000..e8a53e3135 --- /dev/null +++ b/ao/quantization/quant_primitives.py @@ -0,0 +1,384 @@ +import torch +from torch._dynamo import is_compiling as dynamo_is_compiling +from torch._higher_order_ops.out_dtype import out_dtype + +__all__ = [ + "safe_int_mm", + "dynamically_quantize_per_tensor", + "quantize_activation_per_token_absmax", + "dynamically_quantize_per_channel", + "dequantize_per_tensor", + "dequantize_per_channel", + "quant_int8_dynamic_linear", + "quant_int8_matmul", + "quant_int8_dynamic_per_token_linear", + "quant_int8_per_token_matmul", +] + + +def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + r""" + This function wraps torch._int_mm and avoids several undesirable behaviors of the function for certain inputs while still + returning correct results and being torch.compiled in a performant way. + + Assumes both tensors have dimension of 2. + + Note: no error checking for torch.compiled path, if input.shape = [i, j] and j<=16 then the triton kernel + will error. 
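+
+    Example (an illustrative sketch; assumes a CUDA device and int8 operands
+    whose shapes satisfy the constraints above):
+
+        a = torch.randint(-128, 127, (32, 64), dtype=torch.int8, device="cuda")
+        b = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="cuda")
+        c = safe_int_mm(a, b)  # int32 result on the same device as the inputs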
+ + Args: + input (Tensor, int8): the first tensor to be multiplied + mat2 (Tensor, int8): the second tensor to be multiplied + + Return: + out (Tensor, int32): the result of the matmul with device matching that of the inputs + """ + + # torch.compile path + if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + # error checking for cublas path + assert ( + mat2.device == input.device + ), f"need both tensors to be on the same device but got {mat2.device} and {input.device}" + device_cpu = "cpu" in [mat2.device.type, input.device.type] + # with input.shape = [i,j] and mat2.shape = [j,k] + i_is_strictly_greater_than_16 = input.shape[0] > 16 + j_is_nonzero_multiple_of_8 = (input.shape[1] % 8 == 0) and (input.shape[1] > 0) + k_is_nonzero_multiple_of_8 = (mat2.shape[1] % 8 == 0) and (mat2.shape[1] > 0) + bad_dimensions_for_cublas = not ( + i_is_strictly_greater_than_16 + and j_is_nonzero_multiple_of_8 + and k_is_nonzero_multiple_of_8 + ) + + if device_cpu or bad_dimensions_for_cublas: + # fallback path + return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( + input.device.type + ) + + # cublas paths + if not mat2.is_contiguous(): # silently gives incorrect result without this + mat2 = mat2.contiguous() + if (not input.is_contiguous()) and ( + input.shape[0] % 8 != 0 + ): # gives cryptic error without this + input = ( + input.contiguous() + ) # (it seems the transpose makes cublas check the above j constraint on i) + return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) + + +# copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736 +def dynamically_quantize_per_tensor( + x, + quant_min, + quant_max, + target_dtype, + qscheme=torch.per_tensor_affine, # for now, reuse existing qscheme enum +): + # assumes affine quantization + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + if qscheme == torch.per_tensor_affine: + # get min and max + # TODO(future): make torch.aminmax work on cpu-half + # min_val, max_val = torch.aminmax(x) + min_val = torch.min(x) + max_val = torch.max(x) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + # TODO(future): make torch.clamp with scalar work on cpu-half + scale = torch.clamp(scale, min=eps).reshape(1) + zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) + zero_point = torch.clamp(zero_point, quant_min, quant_max) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + quant = torch.clamp( + torch.round(x / scale) + zero_point, quant_min, quant_max + ).to(target_dtype) + + else: + assert qscheme == torch.per_tensor_symmetric, f"unsupported qscheme {qscheme}" + # assert quant_min == -1 * quant_max, "unsupported quant_min/quant_max" + amax = torch.max(torch.abs(x)) + scale = amax / (float(quant_max - quant_min) / 2) + scale = torch.clamp(scale, min=eps).reshape(1) + quant = torch.clamp(torch.round(x / scale), quant_min, quant_max).to( + target_dtype + ) + # do not create a tensor for zero_point as this is expensive + zero_point = None + + return quant, scale, zero_point + + +# taken from +# 
https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26 +# and slightly modified +def quantize_activation_per_token_absmax(t): + n_bits = 8 + # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1] + + scales = t.abs().amax(dim=-1, keepdim=True) + if scales.dtype == torch.float16: + scales = ( + scales.float() + ) # want float scales to avoid overflows for fp16, (bf16 has wide enough range) + q_max = 2 ** (n_bits - 1) - 1 + scales = scales.clamp(min=1e-5).div(q_max) + # Note: the original smoothquant does not clamp to qmin/qmax here, + # but some of the tests with bfloat16 ended up with a flipped sign + # if we don't clamp. TODO(future) look into this further. + t = torch.round(t / scales).clamp(-127, 127).to(torch.int8) + return t, scales + + +def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): + # assumes symmetric quantization + # assumes axis == 0 + # assumes dense memory format + # TODO(future): relax ^ as needed + + # default setup for affine quantization of activations + eps = torch.finfo(torch.float32).eps + + # get min and max + min_val, max_val = torch.aminmax(x, dim=1) + + # calculate scale and zero point based on min and max + # reference: https://fburl.com/code/srbiybme + min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) + max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) + device = min_val_neg.device + + # reference: https://fburl.com/code/4wll53rk + max_val_pos = torch.max(-min_val_neg, max_val_pos) + scale = max_val_pos / (float(quant_max - quant_min) / 2) + # ensure scale is the same dtype as the original tensor + scale = torch.clamp(scale, min=eps).to(x.dtype) + zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + + # quantize based on qmin/qmax/scale/zp + # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 + x_div = x.transpose(0, 1) / scale + x_round = torch.round(x_div) + x_zp = x_round + zero_point + x_zp = x_zp.transpose(0, 1) + quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + + return quant, scale, zero_point + + +# reference: https://fburl.com/code/vfsygwd0 +def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32): + y = int_repr.to(out_dtype) + if zero_point is not None: + y -= zero_point + return y * scale + + +# reference: https://fburl.com/code/org0fmi3 +def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32): + # assumes axis is 0 + y = int_repr.transpose(0, 1) + y = y.to(out_dtype) + y = y - zero_points + y = y * scales + y = y.transpose(0, 1) + return y + + +def quant_int8_dynamic_linear( + x, + x_quant_min, + x_quant_max, + x_q_dtype, + w_vals_int8_t, + w_scales, + w_vals_int8_t_sums_int64, + bias, + out_dtype=torch.float32, +): + # like F.linear, but with int8 dynamic quantization of activation, + # and a quantized weight + x_vals_int8, x_scale, x_zp = dynamically_quantize_per_tensor( + x, x_quant_min, x_quant_max, x_q_dtype + ) + # w_vals_int8_t_sums_int64 = w_vals_int8_t.sum(dim=0) + mm_out = quant_int8_matmul( + x_vals_int8, + x_scale, + x_zp, + w_vals_int8_t, + w_vals_int8_t_sums_int64, + w_scales, + out_dtype, + ) + if bias is not None: + mm_out += bias + return mm_out + + +def quant_int8_matmul( + x_vals_int8, + x_scale, + x_zp, + w_vals_int8_t, + w_vals_int8_t_sums_int64, + w_scales, + out_dtype=torch.float32, +): + # Quantized matmul of int8 operands that accumulates to int32 and returns + # out_dtype. 
For now, this is written for approximate numerical + # correctness, and things like aligning accumulation behaviors and + # performance optimizations are left for a future PR. + # Assumes that weight quantization is symmetric, i.e. w_zp is 0. + # Assumes that weight quantization is per-channel. + + # see + # https://github.com/google/gemmlowp/blob/master/doc/quantization.md + # for an overview of quantized matmul compute + + # in scalar form, assuming out_dtype is fp32 and zw == 0: + # + # Y_i_j_fp32 = sx * sw (dot(X_i, W_j) - zx * sum(W_j)) + # + + assert x_vals_int8.dtype in ( + torch.uint8, + torch.int8, + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8_t.dtype == torch.int8 + ), f"w dtype {w_vals_int8_t.dtype} not yet supported" + assert w_scales.dtype == out_dtype, f"{w_scales.dtype} does not match {out_dtype}" + + # + # 1. do the matrix form of dot(X_i, W_j) + # + + # TODO(before land): add test case for input with bsz + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + y_dot_int32 = safe_int_mm(tmp, w_vals_int8_t) + y_dot_int32 = y_dot_int32.reshape(*x_vals_int8.shape[:-1], -1) + + # TODO(future): consider using integer arithmetic throughout, although + # TBD if that is actually faster on GPUs + # need to use 32 bits here to prevent overflow for large shapes, + # 16 bits is not enough + y_dot_float32 = y_dot_int32.to(torch.float32) + + # + # 2. connect it all together + # + + # mm_unscaled has to stay in float32 for the next two lines to prevent overflow + mm_unscaled_float32 = y_dot_float32 - (x_zp * w_vals_int8_t_sums_int64) + y = x_scale * w_scales * mm_unscaled_float32 + # can downcast only at the very end + y = y.to(out_dtype) + return y + + +def quant_int8_dynamic_per_token_linear( + x, + w_vals_int8_t, + w_scales, + bias, + out_dtype=torch.float32, + use_fused_int_mm=0, +): + # like F.linear, but with int8 dynamic quantization of activation, + # and a quantized weight + x_vals_int8, x_scales = quantize_activation_per_token_absmax(x) + mm_out = quant_int8_per_token_matmul( + x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype, use_fused_int_mm + ) + if bias is not None: + mm_out += bias + return mm_out + + +def quant_int8_per_token_matmul( + x_vals_int8, + x_scales, + w_vals_int8_t, + w_scales, + output_dtype=torch.float32, + use_fused_int_mm=0, +): + # Quantized matmul of int8 operands that accumulates to int32 and returns + # output_dtype. For now, this is written for approximate numerical + # Assumes that activation and weight quantization are symmetric, + # i.e. act_zp and w_zp is 0. + # Assumes that weight quantization is per-channel. + + # see + # https://github.com/google/gemmlowp/blob/master/doc/quantization.md + # for an overview of quantized matmul compute + + # in scalar form, assuming output_dtype is fp32 and zw == 0: + # + # Y_i_j_fp32 = sx * sw dot(X_i, W_j) + # + + assert ( + x_vals_int8.dtype == torch.int8 + ), f"x dtype {x_vals_int8.dtype} not yet supported" + assert ( + w_vals_int8_t.dtype == torch.int8 + ), f"w dtype {w_vals_int8_t.dtype} not yet supported" + assert ( + w_scales.dtype == output_dtype + ), f"{w_scales.dtype} does not match {output_dtype}" + + # + # 1. 
do the matrix form of dot(X_i, W_j) + # + + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # these branches use external triton fused_int_mm kernel's which fuse either 1 or 2 mul operations + if use_fused_int_mm == 2: + y = torch.ops.custom_int_mm.int_mm_dequant( + tmp, w_vals_int8_t, x_scales.view(-1, 1), w_scales, output_dtype + ).reshape(*x_vals_int8.shape[:-1], -1) + return y + elif use_fused_int_mm == 1: + y = torch.ops.custom_int_mm.int_mm_one_mul( + tmp, w_vals_int8_t, x_scales.view(-1, 1), output_dtype + ).reshape(*x_vals_int8.shape[:-1], -1) + y = y * w_scales + return y.to(output_dtype) + y_dot_int32 = safe_int_mm(tmp, w_vals_int8_t) + + # + # 2. rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + assert x_scales.dtype in [ + torch.float, + torch.bfloat16, + ], f"x_scales needs to be a torch.float32 or torch.bfloat16 but got {x_scales.dtype}" + y = (y_dot_int32 * x_scales.view(-1, 1) * w_scales).reshape( + *x_vals_int8.shape[:-1], -1 + ) + + # can downcast only at the very end + y = y.to(output_dtype) + return y diff --git a/ao/quantization/smoothquant.py b/ao/quantization/smoothquant.py new file mode 100644 index 0000000000..80ce7893fd --- /dev/null +++ b/ao/quantization/smoothquant.py @@ -0,0 +1,237 @@ +""" +Testing out accuracy-only implementation of SmoothQuant +(https://arxiv.org/pdf/2211.10438.pdf) +Note: this is an application of input-weight equalization, with the addition that the +multiplication by scale is fused into the preceding layer, specifically for relevant +parts of transformer blocks. +""" + +import torch +import torch.nn.functional as F +import quant_api + +from quant_primitives import ( + dynamically_quantize_per_channel, + quant_int8_dynamic_per_token_linear, +) + +__all__ = [ + "get_scale", + "SmoothFakeDynQuantMixin", + "SmoothFakeDynamicallyQuantizedLinear", + "swap_linear_with_smooth_fq_linear", + "smooth_fq_linear_to_inference", + "set_smooth_fq_attribute", +] + +def get_scale(X_absmax, W_absmax, alpha=0.5): + """ + Calculate the scale based on abs(max(X)), abs(max(W)) and alpha + If X is of dimension `b*n*k` and W is dimension `k*m`, the returned + scale is of dimension `k`. + Note: X_absmax is calculated outside of this function because we + need to keep a running version of it during calibration. W_absmax + is calculated outside of this function for consistency with X_absmax. + """ + X_pow = torch.pow(X_absmax, alpha) + W_pow = torch.pow(W_absmax, 1.0 - alpha) + div = X_pow / W_pow + return div.reshape(-1) + + +class SmoothFakeDynQuantMixin(torch.nn.Module): + def init_smoothquant_variables(self, alpha): + self.calibrating = True + self.x_running_abs_max = None + self.register_buffer("smooth_scale", None) + self.alpha = alpha + # debug only + self.debug_skip_scaling = False + # self.debug_skip_scaling = True + + # Currently torch._int_mm cuBLAS underlying kernel does not work with + # non-contiguous weight. However, torch.compil'ing through + # torch._int_mm leads to triton code which is ~2x faster if the weight + # is transposed. So, for now we have a debug flag to toggle whether + # we store the quantized weight transposed, so that we can get correct + # numerics both in eager mode and after torch.compile. 
+ # The default is True for cuBLAS / eager mode, set to False for + # torch.compile. + # self.store_w_int_repr_t = True + self.store_w_int_repr_t = False + + def update_x_running_abs_max(self, X): + # update the running max of incoming activations + all_dims_except_last = tuple(range(len(X.shape) - 1)) + cur_abs_max = torch.amax(torch.abs(X), dim=all_dims_except_last) + if self.x_running_abs_max is None: + self.x_running_abs_max = cur_abs_max + else: + self.x_running_abs_max = torch.max(cur_abs_max, self.x_running_abs_max) + + def get_scaled_quantized_w(self): + # inference + assert ( + self.smooth_scale is not None + ), "self.smooth_scale is None, did you turn on inference?" + W = self.weight + + # scale weight + # in the future, this can be done ahead of time instead of + # during inference + if not self.debug_skip_scaling: + # TODO(future): do below in `to_inference` instead of here + W = torch.matmul( + torch.diag(self.smooth_scale), W.transpose(0, 1) + ).transpose(0, 1) + + # fake quantize input and weight, and then do matmul in fp32/fp16 + # in the future, this should be replaced with quantized kernels which + # work on NVIDIA GPUs (such as protoquant's implementation) + W_dq_dtype = W.dtype + W_int_repr, W_scales, W_zps = dynamically_quantize_per_channel( + W, -128, 127, torch.int8 + ) + W_int_repr = W_int_repr.contiguous() + return W_int_repr, W_scales, W_zps + + def to_inference(self): + raise NotImplementedError() + + def fold_weight(self): + # note: _W_zps are zeroes and they are ignored + # TODO(future PR): set up serialization for this + W_int_repr, self.W_scales, _W_zps = self.get_scaled_quantized_w() + # need to store transposed weights to make eager mode matmul + # op work in cuBlas, or non-transposed to make it fast in torch.compile + if self.store_w_int_repr_t: + self.register_buffer("W_int_repr", W_int_repr.transpose(0, 1).contiguous()) + else: + self.register_buffer("W_int_repr", W_int_repr.contiguous()) + del self.weight + + def set_debug_x_absmax(self): + """ + Sets `self.x_running_abs_max` to a value which will lead to smooth scale + of all ones if `alpha=0.5`, to enable performance benchmarking without + calibration. + """ + raise NotImplementedError() + + +class SmoothFakeDynamicallyQuantizedLinear(SmoothFakeDynQuantMixin, torch.nn.Linear): + """ + This is a replacement for `torch.nn.Linear` which implements fake quantization + based on Smoothquant scaling. + """ + + def __init__(self, *args, **kwargs): + alpha = kwargs.pop("alpha") + super().__init__(*args, **kwargs) + self.init_smoothquant_variables(alpha) + + def forward(self, X): + if self.calibrating: + self.update_x_running_abs_max(X) + Y = F.linear(X, self.weight, self.bias) + else: + if not self.debug_skip_scaling: + # TODO(future): fuse this into previous layer (LayerNorm, + # RMSNorm, etc) where appropriate + X = X / self.smooth_scale + W_int_repr_t = ( + self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t() + ) + Y = quant_int8_dynamic_per_token_linear( + X, W_int_repr_t, self.W_scales, self.bias, X.dtype + ) + return Y + + @classmethod + def from_float(cls, mod, alpha=0.5): + """ + Converts a `mod` of class `torch.nn.Linear` to the smooth fake quantized + version of it. Note: requires calibration. 
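+
+        Example (an illustrative sketch; assumes a plain torch.nn.Linear and at
+        least one representative calibration batch):
+
+            lin = torch.nn.Linear(16, 8)
+            lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float(lin, alpha=0.5)
+            _ = lin_smooth(torch.randn(4, 16))  # calibration forward pass(es)
+            lin_smooth.to_inference()           # compute smooth scale, fold and quantize weight
+            y = lin_smooth(torch.randn(4, 16))  # quantized inference path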
+ """ + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, fake_out_features, bias=mod.bias is not None, alpha=alpha + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + new_mod.weight = mod.weight + new_mod.bias = mod.bias + # TODO: test when creation is on cuda + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod + + def to_inference(self): + """ + Calculates the smoothquant scale based on calibration + in preparation for inference + """ + assert self.x_running_abs_max is not None, "no calibration data found" + self.calibrating = False + self.smooth_scale = get_scale( + self.x_running_abs_max, + torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values, + alpha=self.alpha, + ) + self.fold_weight() + + def set_debug_x_absmax(self): + w_absmax = torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values + self.x_running_abs_max = w_absmax + +# +# utils to use the smooth linear on real models +# + +source_cls_to_target_cls = { + torch.nn.Linear: SmoothFakeDynamicallyQuantizedLinear, +} + + +def swap_linear_with_smooth_fq_linear( + model, skip_fqn_list=None, cur_fqn="", alpha=0.5 +) -> None: + name_to_child = dict(model.named_children()) + for name, child in name_to_child.items(): + if cur_fqn == "": + new_fqn = name + else: + new_fqn = f"{cur_fqn}.{name}" + if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and isinstance( + child, tuple(source_cls_to_target_cls.keys()) + ): + target_cls = source_cls_to_target_cls[type(child)] + new_child = target_cls.from_float(child, alpha=alpha) + setattr(model, name, new_child) + else: + swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha) + + +# code moved, avoid breaking callsites +# TODO clean this up +replace_with_custom_fn_if_matches_filter = ( + quant_api.replace_with_custom_fn_if_matches_filter +) + + +def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None: + for _, mod in model.named_modules(): + if isinstance(mod, tuple(source_cls_to_target_cls.values())): + if debug_skip_calibration: + mod.set_debug_x_absmax() + mod.to_inference() + + +# useful for quickly toggling smoothquant debug settings on all smoothquant +# modules in a model +def set_smooth_fq_attribute(model, attribute_name, new_attribute_val): + for _, mod in model.named_modules(): + if isinstance(mod, tuple(source_cls_to_target_cls.values())): + if hasattr(mod, attribute_name): + setattr(mod, attribute_name, new_attribute_val) diff --git a/ao/quantization/subclass.py b/ao/quantization/subclass.py new file mode 100644 index 0000000000..48959df922 --- /dev/null +++ b/ao/quantization/subclass.py @@ -0,0 +1,117 @@ +import torch +from quant_primitives import ( + dequantize_per_channel, + dynamically_quantize_per_channel, + quant_int8_dynamic_per_token_linear, +) +from torch.utils._python_dispatch import return_and_correct_aliasing + +__all__ = ["DynamicallyQuantizedLinearWeight"] + + +class DynamicallyQuantizedLinearWeight(torch.Tensor): + @staticmethod + def __new__(cls, input_data, q_scales, transposed=False, **kwargs): + # input data is assumed to be input so that q_axis is the 1th axis + # also assumes input is non contiguous + kwargs["device"] = input_data.device + kwargs["dtype"] = kwargs.get("dtype", torch.int8) + if input_data is not None: + kwargs["dtype"] = input_data.dtype + size = input_data.shape[::-1] if transposed else 
input_data.shape + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else input_data.layout + ) + return torch.Tensor._make_wrapper_subclass(cls, size, **kwargs) # type: ignore[attr-defined] + + def __init__(self, input_data, q_scales, transposed=False): + self.transposed = transposed + self.int_data = input_data + self.q_scales = q_scales + + def __repr__(self): + return f"DynamicallyQuantizedLinearWeight(shape={self.shape}, data={self.dequantize()})" + + def dequantize(self, dtype=None): + out = dequantize_per_channel( + self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype + ) + return out if self.transposed else out.t() # already transposedd for dequantize + + def int_repr(self): + return self.int_data.t() if self.transposed else self.int_data + + def _detach(self): + return DynamicallyQuantizedLinearWeight( + self.int_data, self.q_scales, transposed=self.transposed + ) + + def _transposed(self): + return DynamicallyQuantizedLinearWeight( + self.int_data, self.q_scales, transposed=(not self.transposed) + ) + + def __tensor_flatten__(self): + return ["int_data", "q_scales"], self.transposed + + @staticmethod + def __tensor_unflatten__(tensor_data, transposed): + int_data, q_scales = tensor_data["int_data"], tensor_data["q_scales"] + return DynamicallyQuantizedLinearWeight( + int_data, q_scales, transposed=transposed + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # two scenarios where we currently fall back to vanilla mm: + # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation + # for consistency and to allow people to test + # 2 - we need to define what happens when we're given non-floats - quantizing long to int8 is probs craxy + if ( + func in [torch.ops.aten.mm.default, torch.ops.aten.addmm.default] + and args[0].is_floating_point() + and args[0].is_cuda + ): + if func == torch.ops.aten.addmm.default: + assert ( + args[1].shape[-1] == args[2].shape[0] + ), f"need mat1 shape: {args[1].shape} final dim to match mat2 shape: {args[2].shape} first dim " + mat1, mat2, scales, bias = ( + args[1], + args[2].int_data, + args[2].q_scales, + args[0], + ) + else: + assert ( + args[0].shape[-1] == args[1].shape[0] + ), f"need mat1 shape: {args[0].shape} final dim to match mat2 shape: {args[1].shape} first dim " + mat1, mat2, scales, bias = ( + args[0], + args[1].int_data, + args[1].q_scales, + None, + ) + return quant_int8_dynamic_per_token_linear( + mat1, mat2, scales, bias, mat1.dtype + ) + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing(func, args, kwargs, args[0]._detach()) + + if func is torch.ops.aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._transposed() + ) + breakpoint() + return NotImplemented + + @classmethod + def from_float(cls, input_float, qmin=-128, qmax=127, dtype=torch.int8): + w_int_repr, w_scales, _ = dynamically_quantize_per_channel( + input_float, qmin, qmax, dtype + ) + # always store with quantized axis in dim=1 for fast matmul + return cls(w_int_repr.contiguous().t(), w_scales, transposed=True) diff --git a/ao/quantization/test.py b/ao/quantization/test.py new file mode 100644 index 0000000000..53fdffda3f --- /dev/null +++ b/ao/quantization/test.py @@ -0,0 +1,1024 @@ +# mypy: ignore-errors +import copy +import unittest + +import torch +import torch.nn as nn +from torch._inductor.utils import run_and_get_code + +from 
torch.ao.quantization import MinMaxObserver, QConfigMapping + +from dynamic_quant import ( + DynamicallyPerAxisQuantizedLinear, +) +from quant_api import ( + apply_dynamic_quant, + apply_weight_only_int8_quant, + change_linear_weights_to_dqtensors, +) +from quant_primitives import ( + dequantize_per_channel, + dequantize_per_tensor, + dynamically_quantize_per_channel, + dynamically_quantize_per_tensor, + quant_int8_dynamic_linear, + quant_int8_dynamic_per_token_linear, + quantize_activation_per_token_absmax, + safe_int_mm, +) + +from smoothquant import ( + get_scale, + replace_with_custom_fn_if_matches_filter, + smooth_fq_linear_to_inference, + SmoothFakeDynamicallyQuantizedLinear, + swap_linear_with_smooth_fq_linear, +) +from subclass import ( + DynamicallyQuantizedLinearWeight, +) +from utils import ( + apply_logging_hook, + compute_error, + compute_error as SQNR, + fqn_to_op_to_shape_to_count, + LoggingTensorMode, +) +from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx + +torch.manual_seed(0) + + +class SmoothquantUnitTest(unittest.TestCase): + # first, let's reproduce the graphic from the paper, Figure 4, to ensure + # we are calculating the scales correctly + def test_figure_4(self): + X = torch.FloatTensor([1, -16, 2, 6, -2, 8, -1, -9]).reshape(1, 2, 4) + W = torch.FloatTensor([2, 1, -2, 1, -1, -1, 2, -1, -2, -1, -1, 1]).reshape(4, 3) + X_mul_W = torch.matmul(X, W) + + smoothquant_scale = get_scale( + torch.amax(torch.abs(X), dim=(0, 1)), + torch.amax(torch.abs(W), dim=1), + alpha=0.5, + ) + + # reproduce scaled calculation + X_scaled = X / smoothquant_scale.reshape(1, 1, -1) + W_scaled = torch.matmul(torch.diag(smoothquant_scale), W) + X_scaled_mul_scaled_W = torch.matmul(X_scaled, W_scaled) + assert torch.allclose(X_mul_W, X_scaled_mul_scaled_W), "not close!" + assert X_mul_W.shape == X_scaled_mul_scaled_W.shape + + # next, run the above test on a sample of representative inputs + def test_tensors(self): + x_shape = (1, 5, 7) + w_shape = (7, 9) + for i in range(3): + X = torch.randn(x_shape) * 10 + W = torch.randn(w_shape) + s = get_scale( + torch.amax(torch.abs(X), dim=(0, 1)), + torch.amax(torch.abs(W), dim=1), + alpha=0.5, + ) + + Y = torch.matmul(X, W) + Y_ref = torch.matmul( + X / s.reshape(1, 1, -1), + torch.matmul(torch.diag(s), W), + ) + assert torch.allclose(Y, Y_ref, atol=1e-3, rtol=1e-3), "not close!" 
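+
+    # The identity exercised by the two tests above: for any positive per-channel
+    # scale s, X @ W == (X / s) @ (diag(s) @ W), so the SmoothQuant rescaling is a
+    # numerical no-op in float. get_scale chooses
+    # s = max|X| ** alpha / max|W| ** (1 - alpha) per input channel, which shifts
+    # part of the activation outlier range into the weight before quantization.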
+ + def _test_smooth_linear_impl(self, x_shape, lin_shape, device): + # so we can use the full range + torch.backends.quantized.engine = "qnnpack" + + x = torch.randn(*x_shape, device=device) * 9 + 10 + + lin_fp32 = nn.Linear(*lin_shape, device=device) # misc: ignore + lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( + copy.deepcopy(lin_fp32), alpha=0.25 + ) + lin_smooth_skip_scaling = SmoothFakeDynamicallyQuantizedLinear.from_float( + copy.deepcopy(lin_fp32), alpha=0.25 + ) + + lin_fp32_copy = copy.deepcopy(lin_fp32) # assignment: ignore + lin_fp32_copy.qconfig = torch.ao.quantization.QConfig( # assignment: ignore + activation=None, + weight=torch.ao.quantization.default_per_channel_weight_observer, + ) + lin_dynamic_q = torch.ao.nn.quantized.dynamic.Linear.from_float( + lin_fp32_copy.cpu() + ) + + y_ref = lin_fp32(x) + + # calibrate the smoothquant versions + y_smooth_nocalib = lin_smooth(x) + _ = lin_smooth_skip_scaling(x) + lin_smooth.to_inference() + lin_smooth_skip_scaling.debug_skip_scaling = True + lin_smooth_skip_scaling.to_inference() + + # verify that with scaling turned off, numerics match quantized version + y_smooth_fq_only = lin_smooth_skip_scaling(x) + y_smooth_fq = lin_smooth(x) + y_dynamic_q = lin_dynamic_q(x.cpu()).to(device) + + # print('y_ref', y_ref) + # print('y_smooth_nocalib', y_smooth_nocalib) + # print('y_smooth_fq', y_smooth_fq) + # print('y_smooth_fq_only', y_smooth_fq_only) + # print('y_dynamic_q', y_dynamic_q) + + sqnr_smooth_fq = compute_error(y_ref, y_smooth_fq) + sqnr_dynamic_q = compute_error(y_ref, y_dynamic_q) + sqnr_fq = compute_error(y_smooth_fq_only, y_dynamic_q) + # print('sqnr_smooth', sqnr_smooth_fq, 'sqnr_dynamic', sqnr_dynamic_q, 'sqnr_fq', sqnr_fq) + + assert torch.allclose( + y_ref, y_smooth_nocalib + ), "y_ref not close to y_smooth_nocalib" + # after https://github.com/pytorch-labs/ao_benchmarks/pull/32, + # numerics do not match exactly between production c++ code + # and this Python code + # assert torch.allclose( + # y_smooth_fq_only, y_dynamic_q, + # atol=torch.max(y_smooth_fq_only).item()*0.01, + # rtol=0.00001), \ + # 'y_smooth_fq_only not close to y_dynamic_q' + + self.assertTrue(sqnr_smooth_fq.item() >= 40.0) + self.assertTrue(sqnr_dynamic_q.item() >= 40.0) + self.assertTrue(sqnr_fq.item() >= 40.0) + + def test_smooth_linear_cpu(self): + self._test_smooth_linear_impl((1, 5, 3), (3, 4), "cpu") + + def test_smooth_linear_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + self._test_smooth_linear_impl((1, 32, 32), (32, 16), "cuda") + + def test_smooth_linear_edge_cases(self): + # so we can use the full range + torch.backends.quantized.engine = "qnnpack" + lin_fp32 = nn.Linear(3, 4) + lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( + lin_fp32, alpha=0.25 + ) + + # test different ranks + x0 = torch.randn(4, 5, 3) + x1 = torch.randn(1, 8, 5, 3) + x2 = torch.randn(2, 3, 7, 5, 3) + + # calibrate + _ = lin_smooth(x0) + _ = lin_smooth(x1) + _ = lin_smooth(x2) + + # inference + lin_smooth.to_inference() + _ = lin_smooth(x0) + _ = lin_smooth(x1) + _ = lin_smooth(x2) + + def test_swap(self): + m = nn.Sequential( + nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)), + nn.Linear(4, 4), + ) + m_copy = copy.deepcopy(m) + swap_linear_with_smooth_fq_linear(m_copy, skip_fqn_list=["0.2"]) + + # verify all linears are swapped + assert isinstance(m_copy[0][0], SmoothFakeDynamicallyQuantizedLinear) + assert isinstance(m_copy[0][1], nn.ReLU) + # this one was skipped + assert 
isinstance(m_copy[0][2], nn.Linear) + assert isinstance(m_copy[1], SmoothFakeDynamicallyQuantizedLinear) + + # verify results do not change without smoothing + x = torch.randn(4, 4) + y_ref = m(x) + y = m_copy(x) + assert torch.allclose(y_ref, y) + + def test_weight_t_and_non_t_numerics_match(self): + # verify that numerics match whether weight is stored + # in transposed format (for cuBLAS) vs non-transposed format + # (for torch.compile) + if not torch.cuda.is_available(): + print("no cuda, skip") + return + dtype = torch.half + device = "cuda" + lin_ref = nn.Linear(32, 16, dtype=dtype, device=device) + lin_eager_t = copy.deepcopy(lin_ref) + lin_opt_t = copy.deepcopy(lin_eager_t) + lin_opt = copy.deepcopy(lin_eager_t) + lin_eager_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_eager_t) + lin_opt_t = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt_t) + lin_opt = SmoothFakeDynamicallyQuantizedLinear.from_float(lin_opt) + lin_opt.store_w_int_repr_t = False + + x = torch.randn(32, 32, dtype=dtype, device=device) + + y_calib_eager_t = lin_eager_t(x) + y_calib_opt_t = lin_opt_t(x) + y_calib_opt = lin_opt(x) + torch.testing.assert_close(y_calib_eager_t, y_calib_opt_t) + torch.testing.assert_close(y_calib_eager_t, y_calib_opt) + + lin_eager_t.to_inference() + lin_opt_t.to_inference() + lin_opt.to_inference() + + torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt_t.W_int_repr) + torch.testing.assert_close(lin_eager_t.W_int_repr, lin_opt.W_int_repr) + + lin_opt_t = torch.compile(lin_opt_t, mode="max-autotune") + lin_opt = torch.compile(lin_opt, mode="max-autotune") + + y_ref = lin_ref(x) + y_eager = lin_eager_t(x) + y_opt_t = lin_opt_t(x) + y_opt = lin_opt(x) + + if not torch.any(torch.isinf(y_ref)) and torch.any(torch.isinf(y_eager)): + # eager mode torch._int_mm is sometimes buggy, when this happens + # we can't really compare the compiled version against it properly + print("eager mode torch._int_mm known bad, test is inconclusive") + return + + sqnr_ref_eager = compute_error(y_ref, y_eager) + sqnr_eager_opt_t = compute_error(y_eager, y_opt_t) + sqnr_eager_opt = compute_error(y_eager, y_opt) + # since torch.compile for a torch.half model can + # change numerics significantly, we can only test for a high SQNR here + # and not for closeness + self.assertTrue(sqnr_eager_opt_t >= 45.0) + self.assertTrue(sqnr_eager_opt >= 45.0) + # y_opt_t and y_opt should be equivalent + torch.testing.assert_close(y_opt_t, y_opt) + + def test_selective_torch_compile(self): + m = nn.Sequential( + nn.Linear(4, 4), + nn.Sequential( + nn.Linear(4, 4), + nn.Linear(4, 4), + ), + nn.Linear(4, 4), + ) + x = torch.randn(4, 4) + y_ref = m(x) + + replace_with_custom_fn_if_matches_filter( + m, + lambda mod: torch.compile(mod), + lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "1.0", + ) + + self.assertTrue(isinstance(m[0], torch._dynamo.eval_frame.OptimizedModule)) + self.assertTrue(isinstance(m[1][0], nn.Linear)) + self.assertTrue(isinstance(m[1][1], torch._dynamo.eval_frame.OptimizedModule)) + self.assertTrue(isinstance(m[2], torch._dynamo.eval_frame.OptimizedModule)) + + y = m(x) + torch.testing.assert_close(y, y_ref) + + def test_debug_x_absmax(self): + m = nn.Sequential(nn.Linear(3, 4)) + x0 = torch.randn(4, 5, 3) + y0 = m(x0) + swap_linear_with_smooth_fq_linear(m) + # no calibration, straight to inference, should not crash + smooth_fq_linear_to_inference(m, debug_skip_calibration=True) + y1 = m(x0) + + +class PythonQuantPrimitivesUnitTest(unittest.TestCase): + def 
_test_dynamic_quant_per_tensor_numerics_impl( + self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device, qscheme + ): + x = torch.randn(256, dtype=float_dtype, device=device) + y_vals, y_scale, y_zero_point = dynamically_quantize_per_tensor( + x, qmin, qmax, int_dtype, qscheme + ) + + # reference + # quantize_per_tensor_dynamic doesn't work for half, so we cast there and back + x_for_ref = x.half().float() if float_dtype == torch.float16 else x + + # quantize_per_tensor_dynamic doesn't support qscheme, so we just do dynamic + # quant manually with observers + static quant + obs = MinMaxObserver( + dtype=qint_dtype, qscheme=qscheme, quant_min=qmin, quant_max=qmax + ).to(device) + obs(x_for_ref) + ref_scale, ref_zero_point = obs.calculate_qparams() + y_ref = torch.quantize_per_tensor( + x_for_ref, ref_scale, ref_zero_point, qint_dtype + ) + + # y_ref = torch.quantize_per_tensor_dynamic(x_for_ref, qint_dtype, False) + # print(y_ref) + if float_dtype == torch.float: + assert torch.equal(y_vals, y_ref.int_repr()) + else: + # numerics are not exactly aligned yet, off-by-one probably due + # to rounding + assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1 + torch.testing.assert_close( + y_scale, torch.tensor([y_ref.q_scale()], device=device, dtype=float_dtype) + ) + if y_zero_point is not None: + assert torch.equal( + y_zero_point, torch.tensor([y_ref.q_zero_point()], device=device) + ) + else: + self.assertTrue(y_ref.q_zero_point() == 0) + + # dequantize and check again + x_dq = dequantize_per_tensor(y_vals, y_scale, y_zero_point, float_dtype) + y_ref_dq = y_ref.dequantize().to(float_dtype) + if float_dtype == torch.float: + torch.testing.assert_close(x_dq, y_ref_dq) + else: + sqnr = compute_error(x_dq, y_ref_dq) + self.assertTrue(sqnr.item() > 45.0) + + def test_dynamic_quant_per_tensor_numerics_cpu(self): + # verifies that dynamic quant per tensor in plain pytorch matches + # numerics of production AO code + # TODO(future): test this on cpu-half, need to first make + # torch.aminmax support half on cpu + test_cases = ( + ( + 0, + 255, + torch.uint8, + torch.quint8, + torch.float32, + "cpu", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cpu", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cpu", + torch.per_tensor_symmetric, + ), + ( + -127, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cpu", + torch.per_tensor_symmetric, + ), + ) + for row in test_cases: + self._test_dynamic_quant_per_tensor_numerics_impl(*row) + + def test_dynamic_quant_per_tensor_numerics_cuda(self): + # verifies that dynamic quant per tensor in plain pytorch matches + # numerics of production AO code + if not torch.cuda.is_available(): + print("no cuda, skip") + return + test_cases = ( + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + torch.per_tensor_affine, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + torch.per_tensor_symmetric, + ), + ( + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + torch.per_tensor_symmetric, + ), + ( + -127, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + torch.per_tensor_symmetric, + ), + ( + -127, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + torch.per_tensor_symmetric, + ), + ) + for row in test_cases: + 
self._test_dynamic_quant_per_tensor_numerics_impl(*row) + + def _test_dynamic_quant_per_channel_numerics_impl( + self, qmin, qmax, int_dtype, qint_dtype, float_dtype, device + ): + # verifies that dynamic quant per channel in plain pytorch matches + # numerics of production AO code + # TODO(future): test this on cpu-half, need to first make + # torch.aminmax support half on cpu + + x = torch.randn(16, 32, device=device, dtype=float_dtype) + y_vals, y_scale, y_zero_point = dynamically_quantize_per_channel( + x, qmin, qmax, int_dtype + ) + + min_val, max_val = torch.aminmax(x, dim=1) + + # reference + weight_obs = torch.ao.quantization.MovingAveragePerChannelMinMaxObserver( + dtype=qint_dtype, + quant_min=qmin, + quant_max=qmax, + qscheme=torch.per_channel_symmetric, + averaging_constant=1.0, # make it ignore previous iterations + ) + weight_obs(x) + y_ref_scale, y_ref_zp = weight_obs.calculate_qparams() + y_ref_scale = y_ref_scale.to(device) + y_ref_zp = y_ref_zp.to(device) + # quantize_per_channel doesn't work for half, so we cast there and back + x_for_ref = x.half().float() if float_dtype == torch.float16 else x + y_ref = torch.quantize_per_channel( + x_for_ref, y_ref_scale, y_ref_zp, 0, qint_dtype + ) + + torch.testing.assert_close( + y_scale, y_ref.q_per_channel_scales().to(float_dtype) + ) + assert torch.equal(y_zero_point, y_ref.q_per_channel_zero_points()) + # this test case has one element where the rounding is off by one + # from Python-only code vs the c++ code, it's easy to repro with + # various shapes. + # Discussion here is relevant: https://github.com/pytorch/pytorch/issues/16498 + # TODO(future): figure out what to do about this + # assert torch.equal(int_vals, q_reference.int_repr()) + assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1 + + # dequantize + x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point) + x_ref_dq = y_ref.dequantize() + # off-by-one for scale is okay + torch.testing.assert_close( + x_dq, x_ref_dq, atol=torch.max(y_scale).item() * 1.01, rtol=0.0001 + ) + + def test_dynamic_quant_per_channel_numerics_cpu(self): + test_cases = ((-128, 127, torch.int8, torch.qint8, torch.float32, "cpu"),) + for row in test_cases: + self._test_dynamic_quant_per_channel_numerics_impl(*row) + + def test_dynamic_quant_per_channel_numerics_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + test_cases = ( + (-128, 127, torch.int8, torch.qint8, torch.float32, "cuda"), + (-128, 127, torch.int8, torch.qint8, torch.float16, "cuda"), + ) + for row in test_cases: + self._test_dynamic_quant_per_channel_numerics_impl(*row) + + def _test_quantize_per_token_impl(self, device, dtype): + x = torch.randn(3, 3, 3, device=device, dtype=dtype) + xq, scales = quantize_activation_per_token_absmax(x) + x_dq = dequantize_per_tensor(xq, scales, None).to(x.dtype) + sqnr = compute_error(x, x_dq) + self.assertTrue(sqnr >= 45.0) + + def test_quantize_per_token_cpu(self): + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_quantize_per_token_impl("cpu", dtype) + + def test_quantize_per_token_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_quantize_per_token_impl("cuda", dtype) + + def _test_per_token_linear_impl(self, device, dtype): + x = torch.randn(2, 16, 8, device=device, dtype=dtype) + w = torch.randn(16, 8, device=device, dtype=dtype) + wq, w_scales, _w_zp = dynamically_quantize_per_channel(w, -127, 127, 
torch.int8) + # Note: need to make the weight contiguous because we are + # testing in eager mode and cuBlas will not give correct results + # for a transposed weight + y = quant_int8_dynamic_per_token_linear( + x, wq.t().contiguous(), w_scales, None, dtype + ) + y_ref = torch.matmul(x, w.t()) + sqnr = compute_error(y_ref, y) + self.assertTrue(sqnr >= 42.0) + + def test_per_token_linear_cpu(self): + for dtype in (torch.float32,): + self._test_per_token_linear_impl("cpu", dtype) + + def test_per_token_linear_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + for dtype in (torch.float32, torch.float16, torch.bfloat16): + self._test_per_token_linear_impl("cuda", dtype) + + def test__int_mm(self): + # TODO(future): figure out what here needs to move to PT core, + # if it's not already tested there + if not torch.cuda.is_available(): + print("no cuda, skip") + return + + m, k, n = 32, 32, 16 + x = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda") + w = torch.randint(-128, 127, (k, n), dtype=torch.int8, device="cuda") + + y_ref = torch.matmul(x.float(), w.float()).to(torch.int32) + y_raw = safe_int_mm(x, w) + + wrap_in_mm_opt = torch.compile(safe_int_mm, mode="max-autotune") + # note: triton chokes on the line below on k == 8 and n == 8 with + # https://www.internalfb.com/phabricator/paste/view/P683467944 + # TODO(future): file an issue + y_opt = wrap_in_mm_opt(x, w) + + torch.testing.assert_close(y_ref, y_raw, atol=0, rtol=0) + torch.testing.assert_close(y_ref, y_opt, atol=0, rtol=0) + + def test__int_mm_eager_and_torch_compile_numerics(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + + def __int_mm_ref(x, w): + x = x.cpu().to(torch.int32) + w = w.cpu().to(torch.int32) + y = torch.matmul(x, w) + return y.cuda() + + shapes = ( + # minimal test shape + ((1, 32, 32), (32, 16)), + # paste of real linear shapes from LLaMa 1.5b + ((17, 1, 1536), (1536, 1536)), + ((17, 8, 4096), (4096, 1536)), + ((17, 1, 1536), (1536, 4096)), + ((17, 8, 1536), (1536, 1536)), + ((17, 1, 4096), (4096, 1536)), + ((17, 8, 1536), (1536, 4096)), + ) + + for x_shape, w_shape in shapes: + + def wrap_torch_int_mm(x, w): + b, n, k = x.shape + k, m = w.shape + x = x.reshape(b * n, k) + res = safe_int_mm(x, w) + res = res.reshape(b, n, m) + return res + + wrap_torch_int_mm_opt = torch.compile( + wrap_torch_int_mm, mode="max-autotune" + ) + + x = torch.randint(-128, 127, x_shape, dtype=torch.int8, device="cuda") + w = torch.randint(-128, 127, w_shape, dtype=torch.int8, device="cuda") + + z_ref = __int_mm_ref(x, w) + z_eager = wrap_torch_int_mm(x, w) + z_torch_compile = wrap_torch_int_mm_opt(x, w) + # print(z_ref) + # print(z_eager) + # print(z_torch_compile) + + torch.testing.assert_close(z_ref, z_eager, atol=0, rtol=0) + torch.testing.assert_close(z_ref, z_torch_compile, atol=0, rtol=0) + + def _test_qlinear_per_channel_numerics( + self, x_shape, lin_shape, qmin, qmax, int_dtype, qint_dtype, float_dtype, device + ): + qconfig = torch.ao.quantization.per_channel_dynamic_qconfig + + x = torch.randn(*x_shape, device=device, dtype=float_dtype) + + # TODO: test bias true and false + # Note: reference path only works on float because lack of aten quant primitives + # support of half, so we cast back and forth to emulate + lin_ref = ( + nn.Sequential(nn.Linear(*lin_shape)) + .eval() + .to(float_dtype) + .float() + .to(device) + ) + y_ref = lin_ref(x.float()) + weight = lin_ref[0].weight + bias = lin_ref[0].bias + + qconfig_mapping = 
QConfigMapping().set_global(qconfig) + lin_ref_p = prepare_fx(lin_ref, qconfig_mapping, (torch.randn(1, 1),)) + lin_ref_q = convert_to_reference_fx(lin_ref_p) + y_q_ref = lin_ref_q(x.float()) + + # scale, zp of weight (get from reference model) + w_obs = qconfig.weight() + w_obs(weight) + lin_ref_w_scale, lin_ref_w_zp = w_obs.calculate_qparams() + lin_ref_w_scale = lin_ref_w_scale.to(device).to(float_dtype) + # print('lin_ref_w', 'scale', lin_ref_w_scale, 'zp', lin_ref_w_zp) + + w_vals, _s, _z = dynamically_quantize_per_channel( + getattr(lin_ref_q, "0").weight.to(float_dtype), -128, 127, torch.int8 + ) + w_vals = w_vals.t().contiguous() + w_vals_sums = w_vals.sum(dim=0) + + # do our version of the quantized linear operator + y = quant_int8_dynamic_linear( + x, + qmin, + qmax, + int_dtype, + w_vals, + lin_ref_w_scale, + w_vals_sums, + bias, + float_dtype, + ) + + # print('y', y) + # print('y_q_ref', y_q_ref) + # print('y_ref', y_ref) + + sqnr_ref = compute_error(y_ref, y_q_ref) + sqnr_our = compute_error(y_ref, y) + # print('sqnr_ref', sqnr_ref, 'sqnr_our', sqnr_our) + # for large shapes, sqnr can be in the high 30s for float32 and float16 + self.assertTrue(sqnr_our.item() >= 37.5) + + def test_qlinear_per_channel_numerics_cpu(self): + # Note: the AO codebase doesn't easily support qint8 activations, + # so the test cases below are for the quant primitives defined in + # this file only. The AO reference is using quint8 here. + test_cases = ( + ((2, 3), (3, 4), 0, 255, torch.uint8, torch.quint8, torch.float32, "cpu"), + ((2, 3), (3, 4), -128, 127, torch.int8, torch.qint8, torch.float32, "cpu"), + ) + for test_case in test_cases: + self._test_qlinear_per_channel_numerics(*test_case) + + def test_qlinear_per_channel_numerics_cuda(self): + if not torch.cuda.is_available(): + print("no cuda, skip") + return + test_cases = ( + # Note: torch._int_mm needs int8 activations, so we don't test uint8 + # activations on CUDA at all + ( + (32, 32), + (32, 16), + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + ), + ( + (32, 32), + (32, 16), + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + ), + # a large shape from LLaMa 1.5B - currently fails for float16 + ( + (17, 4096), + (4096, 1536), + -128, + 127, + torch.int8, + torch.qint8, + torch.float32, + "cuda", + ), + ( + (17, 4096), + (4096, 1536), + -128, + 127, + torch.int8, + torch.qint8, + torch.float16, + "cuda", + ), + ) + for test_case in test_cases: + self._test_qlinear_per_channel_numerics(*test_case) + + +class TestSubclass(unittest.TestCase): + def test_dq_lin_weight_subclass_aot(self): + m, k, n = 32, 64, 32 + x = torch.randn(m, k, device="cuda", dtype=torch.float32) + lin = torch.nn.Linear(k, n, device="cuda") + + import copy + + linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin)) + + ref_f = lin(x) + ref_q = linq(x) + + print(SQNR(ref_f, ref_q), "float to dq") + + lin.weight = torch.nn.Parameter( + DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False + ) + test = lin(x) + print(SQNR(ref_f, test), "float to dq class") + print(SQNR(ref_q, test), "dq to dq class") + assert SQNR(ref_f, test) > 35 + assert SQNR(ref_q, test) > 35 + + lin_comp = torch.compile(lin, backend="aot_eager") + linq_comp = torch.compile(linq, backend="aot_eager") + test_comp = lin_comp(x) + ref_q_comp = linq_comp(x) + print(SQNR(ref_f, test_comp), "float to dq class compiled") + print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled") + assert SQNR(ref_f, test_comp) > 35 
+ assert SQNR(ref_q_comp, test_comp) > 35 + + def test_dq_lin_weight_subclass_max_autotune(self): + m, k, n = 32, 64, 32 + x = torch.randn(m, k, device="cuda", dtype=torch.float32) + lin = torch.nn.Linear(k, n, device="cuda") + + import copy + + linq = DynamicallyPerAxisQuantizedLinear.from_float(copy.deepcopy(lin)) + + ref_f = lin(x) + ref_q = linq(x) + + print(SQNR(ref_f, ref_q), "float to dq") + + lin.weight = torch.nn.Parameter( + DynamicallyQuantizedLinearWeight.from_float(lin.weight), requires_grad=False + ) + test = lin(x) + print(SQNR(ref_f, test), "float to dq class") + print(SQNR(ref_q, test), "dq to dq class") + assert SQNR(ref_f, test) > 35 + assert SQNR(ref_q, test) > 35 + + lin_comp = torch.compile(lin, mode="max-autotune") + linq_comp = torch.compile(linq, mode="max-autotune") + + test_comp = lin_comp(x) + ref_q_comp = linq_comp(x) + print(SQNR(ref_f, test_comp), "float to dq class compiled") + print(SQNR(ref_q_comp, test_comp), "dq compiled to dq class compiled") + assert SQNR(ref_f, test_comp) > 35 + assert SQNR(ref_q_comp, test_comp) > 35 + + @torch.no_grad() + def test_dq_lin_weight_subclass_max_autotune_api(self): + m, k, n = 32, 64, 32 + x = torch.randn(m, k, device="cuda", dtype=torch.float32) + + mod = nn.Sequential( + nn.Linear(k, n, device="cuda"), nn.ReLU(), nn.Linear(n, n, device="cuda") + ) + change_linear_weights_to_dqtensors(mod) + mod_qc = torch.compile(mod, mode="max-autotune") + mod_qc(x) + mod_qc(x) + + +class TestDynamicQuant(unittest.TestCase): + def test_dynamic_quant(self): + M, K, N = 8, 16, 8 + x = torch.randn(M, K) + m = nn.Sequential(nn.Linear(K, N)) + + y_ref = m(x) + apply_dynamic_quant(m) + y_test = m(x) + + sqnr = compute_error(y_ref, y_test) + self.assertGreater(sqnr, 40.0) + self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear)) + + +class TestWeightOnlyInt8Quant(unittest.TestCase): + def test_weight_only_quant(self): + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + x = torch.randn(*x_shape) + m = nn.Sequential(nn.Linear(4, 5)) + y_ref = m(x) + apply_weight_only_int8_quant(m) + y_wo = m(x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 44.0) + + @torch.no_grad() + def test_weight_only_quant_force_mixed_mm(self): + torch._inductor.config.epilogue_fusion = True + torch._inductor.config.force_mixed_mm = True + for x_dtype in [torch.float16, torch.bfloat16, torch.float32]: + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + torch._dynamo.reset() + x = torch.randn(*x_shape).to("cuda").to(x_dtype) + m = nn.Sequential(nn.Linear(4, 5)).to("cuda").to(x_dtype) + y_ref = m(x) + apply_weight_only_int8_quant(m) + m(x) + m_c = torch.compile(m, mode="max-autotune") + y_wo, (code,) = run_and_get_code(m_c, x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 43.0) + self.assertTrue("mixed_mm" in code) + + def test_weight_only_quant_use_mixed_mm(self): + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.use_mixed_mm = True + for x_dtype in [torch.float32, torch.float16, torch.bfloat16]: + for x_shape in [[2, 4], [5, 5, 5, 4], [1, 4, 4]]: + torch._dynamo.reset() + x = torch.randn(*x_shape).to("cuda").to(x_dtype) + m = nn.Sequential(nn.Linear(4, 5)).to("cuda").to(x_dtype) + y_ref = m(x) + apply_weight_only_int8_quant(m) + m_c = torch.compile(m, mode="max-autotune") + y_wo, (code,) = run_and_get_code(m_c, x) + sqnr = compute_error(y_ref, y_wo) + self.assertGreater(sqnr, 43.0) + + +class TorchCompileUnitTest(unittest.TestCase): + def test_fullgraph(self): + if not torch.cuda.is_available(): + 
print("no cuda, skip") + return + lin_fp16 = nn.Linear(32, 16, device="cuda", dtype=torch.float16) + lin_smooth = SmoothFakeDynamicallyQuantizedLinear.from_float( + lin_fp16, alpha=0.25 + ) + + x0 = torch.randn(17, 1, 32, device="cuda", dtype=torch.float16) + + # calibrate + _ = lin_smooth(x0) + + # inference + lin_smooth.to_inference() + + # torch.compile + lin_smooth_opt = torch.compile(lin_smooth, fullgraph=True) + # print(lin_smooth_opt) + + y = lin_smooth_opt(x0) + # print(y) + + +class UtilsUnitTest(unittest.TestCase): + def test_shape_logger(self): + x = torch.randn(4, 4) + + m = nn.Sequential( + nn.Linear(4, 4), + nn.Sequential( + nn.Linear(4, 4), + ), + ) + + apply_logging_hook(m) + with LoggingTensorMode(): + m(x) + m(x) + + for fqn, d1 in fqn_to_op_to_shape_to_count.items(): # noqa: PERF102 + for op, d2 in d1.items(): # noqa: PERF102 + for shape, count in d2.items(): # noqa: PERF102 + # print(fqn, op, shape, count) + pass + + +class SmoothquantIntegrationTest(unittest.TestCase): + @torch.inference_mode() + def test_on_dummy_distilbert(self): + # https://huggingface.co/distilbert-base-uncased#how-to-use + from transformers import ( # type: ignore[import-untyped] + DistilBertModel, + DistilBertTokenizer, + ) + + tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased") + model = DistilBertModel.from_pretrained("distilbert-base-uncased") + # print(model) + text = "Replace me by any text you'd like." + encoded_input = tokenizer(text, return_tensors="pt") + output_ref = model(**encoded_input) + # print(output_ref) + + # + # smooth_quant + # + model_copy = copy.deepcopy(model) + swap_linear_with_smooth_fq_linear(model_copy, alpha=0.75) + # calibrate + output_1_1 = model_copy(**encoded_input) + # inference + smooth_fq_linear_to_inference(model_copy) + output_1_2 = model_copy(**encoded_input) + # print(output_1_1) + # print(output_1_2) + sqnr_sq = compute_error( + output_ref.last_hidden_state, output_1_2.last_hidden_state + ) + print("sqnr_sq", sqnr_sq) + self.assertTrue(sqnr_sq >= 20.0) + + # + # reference - dynamic linear quant + # + model_copy2 = copy.deepcopy(model) + qconfig = torch.ao.quantization.QConfig( + activation=None, + weight=torch.ao.quantization.default_per_channel_weight_observer, + ) + model_copy2 = torch.ao.quantization.quantize_dynamic( + model_copy2, + {torch.nn.Linear: qconfig}, + ) + output_2_2 = model_copy2(**encoded_input) + # print(output_2_2) + sqnr_pt_quant = compute_error( + output_ref.last_hidden_state, output_2_2.last_hidden_state + ) + print("sqnr_pt_quant", sqnr_pt_quant) + self.assertTrue(sqnr_pt_quant >= 8.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/ao/quantization/utils.py b/ao/quantization/utils.py new file mode 100644 index 0000000000..c53395579a --- /dev/null +++ b/ao/quantization/utils.py @@ -0,0 +1,98 @@ +import os +from typing import Dict, Optional + +import torch +from torch.utils._python_dispatch import TorchDispatchMode + +__all__ = [ + "log_with_rank", + "clear_logs", + "compute_error", + "apply_logging_hook", + "get_model_size_in_bytes", +] + + +def log_with_rank(*args): + # append + # + # {thing_to_log} + # + # to {file}_{rank}.txt, for printing stuff from multiple GPUs + if not os.path.exists(log_dir): + os.makedirs(log_dir) + with open(log_fname, "a") as f: + f.write(" ".join([str(s) for s in args]) + "\n") + if local_rank == 0: + print(*args) + + +def clear_logs(): + if os.path.isfile(log_fname): + os.remove(log_fname) + + +# basic SQNR +def compute_error(x, y): + Ps = torch.norm(x) + Pn = torch.norm(x -
y) + return 20 * torch.log10(Ps / Pn) + + +# logger for fqn + op + shape +# note: not safe for any kind of multithreading +_cur_fqn: Optional[str] = None + + +def _get_logging_hook(fqn): + def forward_hook(module, input): + global _cur_fqn + _cur_fqn = fqn + + return forward_hook + + +def apply_logging_hook(model): + for name, mod in model.named_modules(): + mod.register_forward_pre_hook(_get_logging_hook(name)) + + +# collections.defaultdict printing is weird with lambdas, so hand writing for now +fqn_to_op_to_shape_to_count: Dict[ + Optional[str], Dict[Optional[str], Dict[Optional[str], int]] +] = {} + + +class LoggingTensorMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + rs = func(*args, **kwargs) + global _cur_fqn + op_name: str = f"{func.__module__}.{func.__name__}" + shape_str = "" + for arg in args: + if isinstance(arg, torch.Tensor): + shape_str += str(list(arg.shape)) + ", " + if shape_str != "": + shape_str = shape_str[:-2] + + if _cur_fqn not in fqn_to_op_to_shape_to_count: + fqn_to_op_to_shape_to_count[_cur_fqn] = {} + if op_name not in fqn_to_op_to_shape_to_count[_cur_fqn]: + fqn_to_op_to_shape_to_count[_cur_fqn][op_name] = {} + if shape_str not in fqn_to_op_to_shape_to_count[_cur_fqn][op_name]: + fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] = 0 + fqn_to_op_to_shape_to_count[_cur_fqn][op_name][shape_str] += 1 + + return rs + + +# https://discuss.pytorch.org/t/finding-model-size/130275 +def get_model_size_in_bytes(model): + s = 0 + for p in model.parameters(): + s += p.nelement() * p.element_size() + for b in model.buffers(): + s += b.nelement() * b.element_size() + return s diff --git a/ao/quantization/weight_only.py b/ao/quantization/weight_only.py new file mode 100644 index 0000000000..0be9c8867b --- /dev/null +++ b/ao/quantization/weight_only.py @@ -0,0 +1,49 @@ +import torch +from quant_primitives import ( + dynamically_quantize_per_channel, +) + +__all__ = ["WeightOnlyInt8QuantLinear"] + + +class WeightOnlyInt8QuantLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + w_int8 = kwargs.pop("w_int8") + scales = kwargs.pop("scales") + super().__init__(*args, **kwargs) + self.w_int8 = w_int8 + self.scales = scales + + def forward(self, x): + # if len(x.shape)<=2: + # y = torch.mm(x, self.w_int8.to(x.dtype)) * self.scales + # else: # turn x into 2d tensor, then undo it for y + x_view = x.view(-1, x.shape[-1]) + y = torch.mm(x_view, self.w_int8.to(x.dtype)) * self.scales + y = y.reshape(*x.shape[:-1], -1) + if self.bias is not None: + y += self.bias + return y + + @classmethod + def from_float(cls, mod): + w_fp32 = mod.weight + w_int8, scales, _zp = dynamically_quantize_per_channel( + w_fp32, -128, 127, torch.int8 + ) + # create the new module with a toy size to ensure initialization is fast + fake_in_features, fake_out_features = 8, 8 + new_mod = cls( + fake_in_features, + fake_out_features, + bias=mod.bias is not None, + w_int8=w_int8.t().contiguous(), + scales=scales, + ) + new_mod.in_features = mod.in_features + new_mod.out_features = mod.out_features + del new_mod.weight + new_mod.bias = mod.bias + device_to_use = next(mod.parameters()).device + new_mod.to(device_to_use) + return new_mod diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..102c90b1da --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +from setuptools import setup, find_packages + +setup( + name='ao', + version='0.1', + packages=find_packages(), + install_requires=[ + 'torch', + ], + 
description='Package for applying ao techniques to GPU models', + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + url='https://github.com/pytorch-labs/ao', +)
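Usage sketch: a minimal example of how the three quantization entry points exercised by the tests fit together. This is a sketch, not a reference implementation; the flat import paths (quant_api, utils) and the module each helper is imported from are assumptions based on how the files in this diff import each other and on the package __init__, so adjust the imports if the code is consumed through the ao.quantization package instead.

# Minimal usage sketch. Assumptions: the flat-module layout from this diff is
# importable, and a CUDA device is available for the tensor-subclass branch.
import copy

import torch
import torch.nn as nn

from quant_api import (
    apply_dynamic_quant,
    apply_weight_only_int8_quant,
    change_linear_weights_to_dqtensors,
)
from utils import compute_error, get_model_size_in_bytes

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(16, 64)
y_ref = model(x)
print("fp32 model size (bytes):", get_model_size_in_bytes(model))

# 1) int8 weight-only quantization (module swap; runs on CPU in the tests)
m_wo = copy.deepcopy(model)
apply_weight_only_int8_quant(m_wo)
print("weight-only SQNR (dB):", compute_error(y_ref, m_wo(x)).item())

# 2) dynamic activation + int8 weight quantization
#    (swaps nn.Linear for DynamicallyPerAxisQuantizedLinear)
m_dq = copy.deepcopy(model)
apply_dynamic_quant(m_dq)
print("dynamic quant SQNR (dB):", compute_error(y_ref, m_dq(x)).item())

# 3) tensor-subclass weights + torch.compile (CUDA only in the tests,
#    since the int8 matmul path relies on torch._int_mm)
if torch.cuda.is_available():
    m_sub = copy.deepcopy(model).cuda()
    change_linear_weights_to_dqtensors(m_sub)
    m_sub_c = torch.compile(m_sub, mode="max-autotune")
    y_sub = m_sub_c(x.cuda())
    print("subclass SQNR (dB):", compute_error(y_ref.cuda(), y_sub).item())

The third route keeps the nn.Linear modules and only replaces their weights with DynamicallyQuantizedLinearWeight tensors, which is why it composes directly with torch.compile in the subclass tests above.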