|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +from dataclasses import dataclass |
| 8 | +from typing import Optional |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +from torchao.core.config import AOBaseConfig |
| 13 | +from torchao.quantization.quant_primitives import ( |
| 14 | + choose_qparams_gguf, |
| 15 | + dequantize_gguf, |
| 16 | + quantize_gguf, |
| 17 | +) |
| 18 | +from torchao.quantization.transform_module import register_quantize_module_handler |
| 19 | +from torchao.utils import TorchAOBaseTensor |
| 20 | + |
| 21 | +_QK_K = 256 |
| 22 | + |
| 23 | +__all__ = [ |
| 24 | + "GGUFQuantizedTensor", |
| 25 | + "choose_qparams_gguf", |
| 26 | + "quantize_gguf", |
| 27 | + "dequantize_gguf", |
| 28 | + "GGUFWeightOnlyConfig", |
| 29 | +] |
| 30 | + |
| 31 | + |
| 32 | +class GGUFQuantizedTensor(TorchAOBaseTensor): |
| 33 | + """ |
| 34 | + A Tensor subclass that when applied to a weight used in a linear op/module, |
| 35 | + changes that linear op to a weight-only int4 quantized linear op with groupwise |
| 36 | + affine quantization on the weight. |
| 37 | + """ |
| 38 | + |
| 39 | + @staticmethod |
| 40 | + def __new__( |
| 41 | + cls, |
| 42 | + n_super_blocks, |
| 43 | + super_block_scale_scale, |
| 44 | + super_block_min_scale, |
| 45 | + quantized_block_scale, |
| 46 | + quantized_block_min, |
| 47 | + int_data, |
| 48 | + shape, |
| 49 | + **kwargs, |
| 50 | + ): |
| 51 | + kwargs["device"] = kwargs.get("device", super_block_scale_scale.device) |
| 52 | + kwargs["dtype"] = kwargs.get("dtype", super_block_scale_scale.dtype) |
| 53 | + kwargs["requires_grad"] = False |
| 54 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 55 | + |
| 56 | + def __init__( |
| 57 | + self, |
| 58 | + n_super_blocks, |
| 59 | + super_block_scale_scale, |
| 60 | + super_block_min_scale, |
| 61 | + quantized_block_scale, |
| 62 | + quantized_block_min, |
| 63 | + int_data, |
| 64 | + shape, |
| 65 | + **kwargs, |
| 66 | + ): |
| 67 | + self.n_super_blocks = n_super_blocks |
| 68 | + self.super_block_scale_scale = super_block_scale_scale |
| 69 | + self.super_block_min_scale = super_block_min_scale |
| 70 | + self.quantized_block_scale = quantized_block_scale |
| 71 | + self.quantized_block_min = quantized_block_min |
| 72 | + self.int_data = int_data |
| 73 | + |
| 74 | + def _apply_fn_to_data(self, fn): |
| 75 | + return self.__class__( |
| 76 | + self.n_super_blocks, |
| 77 | + fn(self.super_block_scale_scale), |
| 78 | + fn(self.super_block_min_sclae), |
| 79 | + fn(self.quantized_block_scale), |
| 80 | + fn(self.quantized_block_min), |
| 81 | + fn(self.int_data), |
| 82 | + self.shape, |
| 83 | + dtype=self.dtype, |
| 84 | + ) |
| 85 | + |
| 86 | + def __tensor_flatten__(self): |
| 87 | + return [ |
| 88 | + "super_block_scale_scale", |
| 89 | + "super_block_min_scale", |
| 90 | + "quantized_block_scale", |
| 91 | + "quantized_block_min", |
| 92 | + "int_data", |
| 93 | + ], ( |
| 94 | + self.n_super_blocks, |
| 95 | + self.dtype, |
| 96 | + self.shape, |
| 97 | + ) |
| 98 | + |
| 99 | + @classmethod |
| 100 | + def __tensor_unflatten__( |
| 101 | + cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None |
| 102 | + ): |
| 103 | + ( |
| 104 | + super_block_scale_scale, |
| 105 | + super_block_min_scale, |
| 106 | + quantized_block_scale, |
| 107 | + quantized_block_min, |
| 108 | + int_data, |
| 109 | + ) = ( |
| 110 | + tensor_data_dict["super_block_scale_scale"], |
| 111 | + tensor_data_dict["super_block_min_scale"], |
| 112 | + tensor_data_dict["quantized_block_scale"], |
| 113 | + tensor_data_dict["quantized_block_min"], |
| 114 | + tensor_data_dict["int_data"], |
| 115 | + ) |
| 116 | + n_super_blocks, dtype, shape = attributes |
| 117 | + return cls( |
| 118 | + n_super_blocks, |
| 119 | + super_block_scale_scale, |
| 120 | + super_block_min_scale, |
| 121 | + quantized_block_scale, |
| 122 | + quantized_block_min, |
| 123 | + int_data, |
| 124 | + shape if outer_size is None else outer_size, |
| 125 | + dtype=dtype, |
| 126 | + ) |
| 127 | + |
| 128 | + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: |
| 129 | + block_size = tuple( |
| 130 | + [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks] |
| 131 | + ) |
| 132 | + return dequantize_gguf( |
| 133 | + self.int_data, |
| 134 | + block_size, |
| 135 | + self.dtype, |
| 136 | + self.super_block_scale_scale, |
| 137 | + self.super_block_min_scale, |
| 138 | + self.quantized_block_scale, |
| 139 | + self.quantized_block_min, |
| 140 | + ) |
| 141 | + |
| 142 | + def detach(self): |
| 143 | + """ |
| 144 | + Returns a new `CodebookQuantizedTensor`. |
| 145 | + """ |
| 146 | + return self.__class__( |
| 147 | + self.n_super_blocks, |
| 148 | + self.super_block_scale_scale.detach(), |
| 149 | + self.super_block_min_scale.detach(), |
| 150 | + self.quantized_block_scale.detach(), |
| 151 | + self.quantized_block_min.detach(), |
| 152 | + self.int_data.detach(), |
| 153 | + self.shape, |
| 154 | + dtype=self.dtype, |
| 155 | + ) |
| 156 | + |
| 157 | + def requires_grad_(self, requires_grad=False): |
| 158 | + """ |
| 159 | + Modifies the tensor's `requires_grad` status in-place. |
| 160 | + """ |
| 161 | + assert not requires_grad, "Only requires_grad == False is supported" |
| 162 | + return self |
| 163 | + |
| 164 | + @classmethod |
| 165 | + def from_float(cls, input_float, n_super_blocks, target_dtype): |
| 166 | + """ |
| 167 | + Method used to convert a linear weight tensor to an instance of the |
| 168 | + GGMLInt4LinearWeight subclass. |
| 169 | +
|
| 170 | + Example usage:: |
| 171 | +
|
| 172 | + model.lin_mod.weight = ( |
| 173 | + GGMLInt4LinearWeight.from_float(model.lin_mod.weight) |
| 174 | + ) |
| 175 | + """ |
| 176 | + assert ( |
| 177 | + target_dtype == torch.uint4 |
| 178 | + ), "only uint4 quantization is supported right now" |
| 179 | + block_size = (1, _QK_K // n_super_blocks) |
| 180 | + ( |
| 181 | + super_block_scale_scale, |
| 182 | + super_block_min_scale, |
| 183 | + quantized_block_scale, |
| 184 | + quantized_block_min, |
| 185 | + ) = choose_qparams_gguf(input_float, block_size, target_dtype) |
| 186 | + |
| 187 | + int_data = quantize_gguf( |
| 188 | + input_float, |
| 189 | + block_size, |
| 190 | + target_dtype, |
| 191 | + super_block_scale_scale, |
| 192 | + super_block_min_scale, |
| 193 | + quantized_block_scale, |
| 194 | + quantized_block_min, |
| 195 | + ) |
| 196 | + return cls( |
| 197 | + n_super_blocks, |
| 198 | + super_block_scale_scale, |
| 199 | + super_block_min_scale, |
| 200 | + quantized_block_scale, |
| 201 | + quantized_block_min, |
| 202 | + int_data, |
| 203 | + input_float.shape, |
| 204 | + dtype=torch.float16, |
| 205 | + ) |
| 206 | + |
| 207 | + |
| 208 | +@dataclass |
| 209 | +class GGUFWeightOnlyConfig(AOBaseConfig): |
| 210 | + dtype: torch.dtype = torch.uint4 |
| 211 | + n_super_blocks: int = 8 |
| 212 | + |
| 213 | + |
| 214 | +@register_quantize_module_handler(GGUFWeightOnlyConfig) |
| 215 | +def _gguf_weight_only_transform( |
| 216 | + module: torch.nn.Module, |
| 217 | + config: GGUFWeightOnlyConfig, |
| 218 | +): |
| 219 | + """ |
| 220 | + Applies gguf weight-only quantization to linear layers. |
| 221 | +
|
| 222 | + Args: |
| 223 | + dtype: torch.uint1 to torch.uint8, torch.int32 supported. |
| 224 | + n_super_blocks: the number of super blocks in a 256 element block for gguf, e.g. when it is 8 |
| 225 | + it means we have blocks of 32 and 8 blocks in a superblock of 256 elements. |
| 226 | + Returns: |
| 227 | + Callable for quantization transformation. |
| 228 | + """ |
| 229 | + weight = module.weight |
| 230 | + if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0): |
| 231 | + return module |
| 232 | + |
| 233 | + quantized_weight = GGUFQuantizedTensor.from_float( |
| 234 | + weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype |
| 235 | + ) |
| 236 | + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) |
| 237 | + return module |
0 commit comments