import torch
import torch.nn as nn
from typing import Tuple, Optional

from functools import partial

from torchao.quantization.quant_primitives import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
    quantize_activation_per_token_absmax
)
from torchao.quantization import quant_api
from torchao.sparsity import apply_fake_sparsity

# Quant + Sparse helper functions
def sparse_quant_int8_dynamic_per_token_linear(
    x,
    w_vals_int8,
    w_scales,
    bias,
    out_dtype=torch.float32,
    fuse_dequant=True,
):
    # like F.linear, but with int8 dynamic quantization of the activation
    # and a quantized (2:4-sparse, cuSPARSELt-compressed) weight
    x_vals_int8, x_scales = quantize_activation_per_token_absmax(x)
    mm_out = sparse_quant_int8_per_token_matmul(
        x_vals_int8, x_scales, w_vals_int8, w_scales, out_dtype, fuse_dequant=fuse_dequant)
    if bias is not None:
        mm_out += bias
    return mm_out

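
# Illustrative shape walkthrough for the helper above (added commentary, not part of the
# original commit). Names follow the function signature; the layout of the compressed weight
# is an assumption based on how `from_float` below constructs it:
#
#   x            (..., in_features)      fp32/fp16 activation
#   w_vals_int8  opaque compressed int8  output of torch._cslt_compress on a 2:4-sparse,
#                                        logically (out_features, in_features) weight
#   w_scales     (out_features,)         per-channel weight scales, dtype == out_dtype
#   returns      (..., out_features)     in out_dtype
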
def sparse_quant_int8_per_token_matmul(
    x_vals_int8,
    x_scales,
    w_vals_int8,
    w_scales,
    out_dtype=torch.float32,
    fuse_dequant=True,
):
    # Quantized sparse matmul of int8 operands that accumulates to fp16 and returns
    # out_dtype. This matmul uses cuSPARSELt as a backend.

    # Assumes that activation and weight quantization are symmetric,
    # i.e. act_zp and w_zp are 0.
    # Assumes that weight quantization is per-channel.
    # NOTE: sparsity is only compatible with symmetric (zero-preserving) quantization techniques.

    # see
    # https://github.com/google/gemmlowp/blob/master/doc/quantization.md
    # for an overview of quantized matmul compute

    # in scalar form, assuming out_dtype is fp32 and zw == 0:
    #
    #   Y_i_j_fp32 = sx * sw * dot(X_i, W_j)
    #
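    # illustrative numbers (added example, not from the original code): with a per-token
    # activation scale sx = 0.05, a per-channel weight scale sw = 0.02, and an integer
    # dot product dot(X_i, W_j) = 1000, the rescaled output is
    #
    #   Y_i_j_fp32 = 0.05 * 0.02 * 1000 = 1.0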

    assert x_vals_int8.dtype == torch.int8, \
        f'x dtype {x_vals_int8.dtype} not yet supported'
    assert w_vals_int8.dtype == torch.int8, \
        f'w dtype {w_vals_int8.dtype} not yet supported'
    assert w_scales.dtype == out_dtype, \
        f'{w_scales.dtype} does not match {out_dtype}'

    #
    # 1. do the matrix form of dot(X_i, W_j)
    #

    # For sparse matmul, we need one of the input operands to be transposed.
    # This is because cuSPARSELt only supports int8 matmul for specific formats:
    # https://docs.nvidia.com/cuda/cusparselt/functions.html#matmul-descriptor-functions
    # Because we currently only support the first operand being sparse,
    # we cannot transpose w_vals_int8, so instead we transpose x_vals_int8.
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]).contiguous()
    # Since cuSPARSELt does not support int32 output, we use the fp16 kernel
    # instead, by setting out_dtype.
    # y_dot_fp16 = torch._sparse_semi_structured_linear(tmp, w_vals_int8, out_dtype=torch.float16)
    y_dot_fp16 = torch._cslt_sparse_mm(w_vals_int8, tmp.t(), out_dtype=torch.float16).t()
    y_dot_fp32 = y_dot_fp16.reshape(*x_vals_int8.shape[:-1], -1).to(out_dtype)

    #
    # 2. rescale the output
    #
    # in cases with large matrices, the fp16 matmul output can grow large enough that
    # multiplying it by a single float16 scale overflows float16 (resulting in inf,
    # even if multiplying by the other scale would bring it back into range),
    # which is why we upcast to out_dtype before rescaling

    # assert x_scales.dtype == torch.float, f"x_scales needs to be a torch.float32 but got {x_scales.dtype}"

    y = y_dot_fp32 * x_scales * w_scales

    # can downcast only at the very end
    y = y.to(out_dtype)
    return y

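
# A minimal dense reference for the sparse path above, added as an illustration (this helper
# is not part of the original commit). It assumes an uncompressed int8 weight of shape
# (out_features, in_features); on small inputs its output should closely match
# sparse_quant_int8_per_token_matmul once that weight has been 2:4-pruned and compressed.
def _dense_reference_int8_per_token_matmul(
    x_vals_int8, x_scales, w_vals_int8_dense, w_scales, out_dtype=torch.float32
):
    # matrix form of dot(X_i, W_j), accumulated in fp32
    y_dot_fp32 = torch.matmul(
        x_vals_int8.to(torch.float32), w_vals_int8_dense.to(torch.float32).t())
    # rescale by per-token activation scales and per-channel weight scales
    y = y_dot_fp32 * x_scales * w_scales
    return y.to(out_dtype)
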
class SparseDynamicallyPerAxisQuantizedLinear(torch.nn.Linear):
    """
    This class is a replacement for `torch.nn.Linear`, implementing sparse dynamic quantization
    of the input across all axes except for the last axis.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True
    ):
        super().__init__(in_features, out_features, bias)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass of the sparse quantized linear layer.

        This method applies dynamic quantization to the input tensor across all axes except
        the last axis using the `sparse_quant_int8_dynamic_per_token_linear` function.

        We artificially limit the quantized weight values to the int4 range to ensure we stay
        within the dynamic range of fp16. This method uses cuSPARSELt to perform the sparse
        matmul.

        Args:
            X (torch.Tensor): The input tensor to the sparse quantized linear layer.
        Returns:
            torch.Tensor: The output tensor after the sparse quantized matmul and rescale.
        """
        Y = sparse_quant_int8_dynamic_per_token_linear(
            X, self.W_int_repr, self.W_scales, self.bias, X.dtype, fuse_dequant=self.fuse_dequant)
        return Y

    @classmethod
    def from_float(cls, mod: torch.nn.Linear, fuse_dequant=True) -> 'SparseDynamicallyPerAxisQuantizedLinear':
        """
        Converts a `mod` of class `torch.nn.Linear` to the sparse dynamically quantized version of it.
        Note: this class does not require calibration.
        Args:
            mod (torch.nn.Linear): The original `torch.nn.Linear` module to convert.
        Returns:
            SparseDynamicallyPerAxisQuantizedLinear: The converted sparse quantized linear module.
        """

        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features, fake_out_features, bias=mod.bias is not None)
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        # NOTE: We artificially clamp the weight values to the int4 range to ensure we stay
        # within the dynamic range of fp16
        W_int_repr, W_scales, _W_zps = dynamically_quantize_per_channel(
            mod.weight, -8, 7, torch.int8)
        new_mod.register_buffer('W_int_repr', torch._cslt_compress(W_int_repr.contiguous()))
        new_mod.register_buffer('W_scales', W_scales)
        new_mod.bias = mod.bias
        new_mod.fuse_dequant = fuse_dequant
        del new_mod.weight

        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

def apply_int4_dynamic_quant_sparse(model, fuse_dequant=False):
    apply_fake_sparsity(model)
    quant_api._replace_with_custom_fn_if_matches_filter(
        model,
        partial(SparseDynamicallyPerAxisQuantizedLinear.from_float, fuse_dequant=fuse_dequant),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear))
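
# Hedged usage sketch (added for illustration; not part of the original commit). Running it
# requires a CUDA GPU with cuSPARSELt support and a torchao/PyTorch build that exposes the
# private ops used above; the layer and batch sizes below are arbitrary assumptions chosen to
# satisfy cuSPARSELt's dimension requirements for int8.
if __name__ == "__main__":
    if torch.cuda.is_available():
        model = nn.Sequential(nn.Linear(128, 128)).cuda()
        # 2:4-prune the weights, then swap in the sparse int4-range dynamically quantized linears
        apply_int4_dynamic_quant_sparse(model)
        x = torch.randn(1, 64, 128, device="cuda")
        with torch.no_grad():
            y = model(x)
        print(y.shape)  # expected: torch.Size([1, 64, 128])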