
Commit c10ba15

Add gguf q4_k_s quantization
Summary:
This doesn't implement the choose_qparams algorithm from gguf, since that algorithm is complicated, e.g. https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L744 and https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L827C14-L827C28. Instead it implements a simple choose_qparams that fits the gguf format:

Q4_K: w = q * block_scale(6-bit) + block_min(6-bit)

Test Plan:
python test/prototype/test_gguf_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 3bbf42a commit c10ba15
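The commit message above only sketches the simplified scheme. As a rough illustration of what such a "simple" Q4_K-style choose_qparams can look like, here is a hypothetical toy sketch, not the code added in this commit: the helper name, the rounding and clamping choices, and the sign convention for block mins are all assumptions, and real gguf packs the 6-bit values quite differently.

    import torch


    def toy_choose_qparams_q4_k(w: torch.Tensor, block: int = 32, blocks_per_super: int = 8):
        # Group the last dim into super blocks of 256 = 8 blocks of 32 elements.
        rows, cols = w.shape
        super_size = block * blocks_per_super
        assert cols % super_size == 0
        grouped = w.reshape(rows, cols // super_size, blocks_per_super, block)

        # Per-block affine params for 4-bit quantization: w ~= q * block_scale + block_min.
        bmax = grouped.amax(dim=-1)
        bmin = grouped.amin(dim=-1)
        block_scale = (bmax - bmin) / 15.0  # 4-bit quantized values live in [0, 15]
        block_min = bmin

        # Quantize the per-block scales/mins themselves to 6 bits, with one floating
        # point "scale of scales" / "scale of mins" per 256-element super block.
        scale_scale = block_scale.amax(dim=-1, keepdim=True) / 63.0
        min_scale = block_min.abs().amax(dim=-1, keepdim=True) / 63.0
        q_block_scale = torch.round(block_scale / scale_scale.clamp(min=1e-8)).clamp(0, 63)
        q_block_min = torch.round(block_min / min_scale.clamp(min=1e-8)).clamp(-63, 63)
        return scale_scale.squeeze(-1), min_scale.squeeze(-1), q_block_scale, q_block_min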

4 files changed: +472 −0 lines changed

test/prototype/test_gguf_quant.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import unittest

import torch

from torchao.prototype.quantization.gguf import (
    GGUFQuantizedTensor,
    GGUFWeightOnlyConfig,
    choose_qparams_gguf,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error


class TestGGUFQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.input = torch.randn(2, 256, dtype=torch.float32)
        self.n_super_blocks = 8
        self.block_size = (1, 32)
        self.dtype = torch.uint4

    def test_choose_qparams_gguf(self):
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        ) = choose_qparams_gguf(self.input, self.block_size, self.dtype)

        assert super_block_scale_scale.shape == (2, 8)
        assert super_block_min_scale.shape == (2, 8)
        assert quantized_block_scale.shape == (2, 32)

    def test_gguf_quantized_tensor_from_float(self):
        gqt = GGUFQuantizedTensor.from_float(
            self.input,
            self.n_super_blocks,
            self.dtype,
        )

        dequant = gqt.dequantize()

        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 30)

    def test_quantize_api(self):
        m = torch.nn.Sequential(torch.nn.Linear(256, 64))
        quantize_(m, GGUFWeightOnlyConfig())
        assert type(m[0].weight) == GGUFQuantizedTensor


if __name__ == "__main__":
    unittest.main()
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from .gguf_quantized_tensor import (
    GGUFQuantizedTensor,
    GGUFWeightOnlyConfig,
    choose_qparams_gguf,
)

__all__ = [
    "GGUFQuantizedTensor",
    "choose_qparams_gguf",
    "GGUFWeightOnlyConfig",
]
Lines changed: 237 additions & 0 deletions
@@ -0,0 +1,237 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.quant_primitives import (
    choose_qparams_gguf,
    dequantize_gguf,
    quantize_gguf,
)
from torchao.quantization.transform_module import register_quantize_module_handler
from torchao.utils import TorchAOBaseTensor

_QK_K = 256

__all__ = [
    "GGUFQuantizedTensor",
    "choose_qparams_gguf",
    "quantize_gguf",
    "dequantize_gguf",
    "GGUFWeightOnlyConfig",
]


class GGUFQuantizedTensor(TorchAOBaseTensor):
    """
    A Tensor subclass that, when applied to a weight used in a linear op/module,
    changes that linear op to a weight-only int4 quantized linear op with groupwise
    affine quantization on the weight.
    """

    @staticmethod
    def __new__(
        cls,
        n_super_blocks,
        super_block_scale_scale,
        super_block_min_scale,
        quantized_block_scale,
        quantized_block_min,
        int_data,
        shape,
        **kwargs,
    ):
        kwargs["device"] = kwargs.get("device", super_block_scale_scale.device)
        kwargs["dtype"] = kwargs.get("dtype", super_block_scale_scale.dtype)
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        n_super_blocks,
        super_block_scale_scale,
        super_block_min_scale,
        quantized_block_scale,
        quantized_block_min,
        int_data,
        shape,
        **kwargs,
    ):
        self.n_super_blocks = n_super_blocks
        self.super_block_scale_scale = super_block_scale_scale
        self.super_block_min_scale = super_block_min_scale
        self.quantized_block_scale = quantized_block_scale
        self.quantized_block_min = quantized_block_min
        self.int_data = int_data

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            self.n_super_blocks,
            fn(self.super_block_scale_scale),
            fn(self.super_block_min_scale),
            fn(self.quantized_block_scale),
            fn(self.quantized_block_min),
            fn(self.int_data),
            self.shape,
            dtype=self.dtype,
        )

    def __tensor_flatten__(self):
        return [
            "super_block_scale_scale",
            "super_block_min_scale",
            "quantized_block_scale",
            "quantized_block_min",
            "int_data",
        ], (
            self.n_super_blocks,
            self.dtype,
            self.shape,
        )

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None
    ):
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
            int_data,
        ) = (
            tensor_data_dict["super_block_scale_scale"],
            tensor_data_dict["super_block_min_scale"],
            tensor_data_dict["quantized_block_scale"],
            tensor_data_dict["quantized_block_min"],
            tensor_data_dict["int_data"],
        )
        n_super_blocks, dtype, shape = attributes
        return cls(
            n_super_blocks,
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
            int_data,
            shape if outer_size is None else outer_size,
            dtype=dtype,
        )

    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        block_size = tuple(
            [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks]
        )
        return dequantize_gguf(
            self.int_data,
            block_size,
            self.dtype,
            self.super_block_scale_scale,
            self.super_block_min_scale,
            self.quantized_block_scale,
            self.quantized_block_min,
        )

    def detach(self):
        """
        Returns a new `GGUFQuantizedTensor` with detached tensor data.
        """
        return self.__class__(
            self.n_super_blocks,
            self.super_block_scale_scale.detach(),
            self.super_block_min_scale.detach(),
            self.quantized_block_scale.detach(),
            self.quantized_block_min.detach(),
            self.int_data.detach(),
            self.shape,
            dtype=self.dtype,
        )

    def requires_grad_(self, requires_grad=False):
        """
        Modifies the tensor's `requires_grad` status in-place.
        """
        assert not requires_grad, "Only requires_grad == False is supported"
        return self

    @classmethod
    def from_float(cls, input_float, n_super_blocks, target_dtype):
        """
        Method used to convert a linear weight tensor to an instance of the
        GGUFQuantizedTensor subclass.

        Example usage::

            model.lin_mod.weight = GGUFQuantizedTensor.from_float(
                model.lin_mod.weight, n_super_blocks=8, target_dtype=torch.uint4
            )
        """
        assert (
            target_dtype == torch.uint4
        ), "only uint4 quantization is supported right now"
        block_size = (1, _QK_K // n_super_blocks)
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        ) = choose_qparams_gguf(input_float, block_size, target_dtype)

        int_data = quantize_gguf(
            input_float,
            block_size,
            target_dtype,
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        )
        return cls(
            n_super_blocks,
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
            int_data,
            input_float.shape,
            dtype=torch.float16,
        )


@dataclass
class GGUFWeightOnlyConfig(AOBaseConfig):
    dtype: torch.dtype = torch.uint4
    n_super_blocks: int = 8


@register_quantize_module_handler(GGUFWeightOnlyConfig)
def _gguf_weight_only_transform(
    module: torch.nn.Module,
    config: GGUFWeightOnlyConfig,
):
    """
    Applies gguf weight-only quantization to linear layers.

    Args:
        module: the module whose weight will be quantized.
        config: the config, with `dtype` (currently only torch.uint4 is supported)
            and `n_super_blocks`, the number of blocks in a 256-element super block
            for gguf, e.g. when it is 8 we have blocks of 32 elements, with 8 blocks
            in a super block of 256 elements.
    Returns:
        The module with its weight replaced by a `GGUFQuantizedTensor`; the module
        is returned unchanged if the weight shape is not compatible.
    """
    weight = module.weight
    if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
        return module

    quantized_weight = GGUFQuantizedTensor.from_float(
        weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
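
For reference, a minimal usage sketch mirroring test_quantize_api in the new test file (a standalone snippet, assuming the modules added in this commit are importable; the linear layer's last weight dimension must be divisible by 256 to pass the shape check above):

    import torch

    from torchao.prototype.quantization.gguf import (
        GGUFQuantizedTensor,
        GGUFWeightOnlyConfig,
    )
    from torchao.quantization import quantize_

    # Quantize the linear weights of a small model with the gguf Q4_K-style config.
    model = torch.nn.Sequential(torch.nn.Linear(256, 64))
    quantize_(model, GGUFWeightOnlyConfig(dtype=torch.uint4, n_super_blocks=8))
    assert isinstance(model[0].weight, GGUFQuantizedTensor)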
