Commit 1dc5fe0
Add gguf q4_k_s quantization
Summary:
Didn't implement the algorithm to choose_qparams from gguf, since it's complicated, e.g.
https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L744 and
https://github.com/ggml-org/llama.cpp/blob/f423981ac806bf031d83784bcb47d2721bc70f97/ggml/src/ggml-quants.c#L827C14-L827C28,
but implemented a simple choose_qparams that can fit the gguf format:

Q4_K: w = q * block_scale(6-bit) + block_min(6-bit)

Test Plan:
python test/prototype/test_gguf_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent f38c272 commit 1dc5fe0
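The Q4_K formula above composes two levels of scales: each 256-element super block carries one floating-point scale-of-scales and one min-of-mins (fp16 in GGUF Q4_K), while each 32-element block carries a 6-bit quantized scale and min. Below is a minimal sketch of that arithmetic for a single super block; the tensor names mirror the ones introduced in this commit, but the values are made up for illustration and this is not the committed kernel.

import torch

QK_K, N_BLOCKS, BLOCK = 256, 8, 32  # one super block of 256 = 8 blocks of 32

q = torch.randint(0, 16, (N_BLOCKS, BLOCK))                 # 4-bit quantized weights
quantized_block_scale = torch.randint(0, 64, (N_BLOCKS,))   # 6-bit per-block scale
quantized_block_min = torch.randint(0, 64, (N_BLOCKS,))     # 6-bit per-block min
super_block_scale_scale = torch.tensor(0.01, dtype=torch.float16)  # per super block
super_block_min_scale = torch.tensor(0.02, dtype=torch.float16)    # per super block

# Reconstruct the per-block scale/min, then dequantize each block of 32 values:
# w = q * block_scale(6-bit) + block_min(6-bit), as stated in the summary.
block_scale = super_block_scale_scale * quantized_block_scale.to(torch.float16)
block_min = super_block_min_scale * quantized_block_min.to(torch.float16)
w = q.to(torch.float16) * block_scale[:, None] + block_min[:, None]
print(w.shape)  # torch.Size([8, 32])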

4 files changed, +487 -0 lines changed

test/prototype/test_gguf_quant.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import unittest

import torch

from torchao.prototype.quantization.gguf import (
    GGUFQuantizedTensor,
    GGUFWeightOnlyConfig,
    choose_qparams_gguf,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error


class TestGGUFQuantization(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(123)
        self.input = torch.randn(2, 256, dtype=torch.float32)
        self.n_super_blocks = 8
        self.block_size = (1, 32)
        self.dtype = torch.uint4

    def test_choose_qparams_gguf(self):
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        ) = choose_qparams_gguf(self.input, self.block_size, self.dtype)

        assert super_block_scale_scale.shape == (2, 8)
        assert super_block_min_scale.shape == (2, 8)
        assert quantized_block_scale.shape == (2, 32)

    def test_gguf_quantized_tensor_from_float(self):
        gqt = GGUFQuantizedTensor.from_float(
            self.input,
            self.n_super_blocks,
            self.dtype,
        )

        dequant = gqt.dequantize()

        sqnr = compute_error(dequant, self.input)
        self.assertGreater(sqnr, 30)

    def test_quantize_api(self):
        m = torch.nn.Sequential(torch.nn.Linear(256, 64))
        quantize_(m, GGUFWeightOnlyConfig())
        assert type(m[0].weight) == GGUFQuantizedTensor


if __name__ == "__main__":
    unittest.main()
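test_gguf_quantized_tensor_from_float gates on an SQNR above 30 dB between the dequantized tensor and the original float input. As a rough sketch of what that threshold measures (assuming compute_error is the usual signal-to-quantization-noise ratio in dB; check torchao.quantization.utils.compute_error for the exact definition):

import torch

def sqnr(reference: torch.Tensor, reconstructed: torch.Tensor) -> torch.Tensor:
    # Ratio of signal power to quantization-noise power, expressed in dB.
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - reconstructed)
    return 20 * torch.log10(signal / noise)

A value above 30 dB means the quantization error is small relative to the weights themselves, so the dequantized tensor tracks the original closely.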
torchao/prototype/quantization/gguf/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from .gguf_quantized_tensor import (
    GGUFQuantizedTensor,
    GGUFWeightOnlyConfig,
    choose_qparams_gguf,
)

__all__ = [
    "GGUFQuantizedTensor",
    "choose_qparams_gguf",
    "GGUFWeightOnlyConfig",
]
torchao/prototype/quantization/gguf/gguf_quantized_tensor.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Define a Tensor subclass to wrap around ggml q4_0 tensor layout.
# The layout is the following:
# ┌─────────────────────┬───────────────────────────┐
# │                     │                           │
# │                     │                           │
# │  2 bytes (1xfp16)   │    16 bytes (32xint4)     │
# │  group-wise scale   │    group-wise weights     │
# │                     │                           │
# │                     │                           │
# └─────────────────────┴───────────────────────────┘
#
# Notice that the 16 bytes (32 int4) are interleaved:
# [0th value, 16th value, 1st value, 17th value, ..., 15th, 31st]
#
# This layout is handled internally in the tensor subclass.
from dataclasses import dataclass
from typing import Optional

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.quant_primitives import (
    choose_qparams_gguf,
    dequantize_gguf,
    quantize_gguf,
)
from torchao.quantization.transform_module import register_quantize_module_handler
from torchao.utils import TorchAOBaseTensor

_QK_K = 256

__all__ = [
    "GGUFQuantizedTensor",
    "choose_qparams_gguf",
    "quantize_gguf",
    "dequantize_gguf",
    "GGUFWeightOnlyConfig",
]


class GGUFQuantizedTensor(TorchAOBaseTensor):
    """
    A Tensor subclass that when applied to a weight used in a linear op/module,
    changes that linear op to a weight-only int4 quantized linear op with groupwise
    affine quantization on the weight.
    """

    @staticmethod
    def __new__(
        cls,
        n_super_blocks,
        super_block_scale_scale,
        super_block_min_scale,
        quantized_block_scale,
        quantized_block_min,
        int_data,
        shape,
        **kwargs,
    ):
        kwargs["device"] = kwargs.get("device", super_block_scale_scale.device)
        kwargs["dtype"] = kwargs.get("dtype", super_block_scale_scale.dtype)
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        n_super_blocks,
        super_block_scale_scale,
        super_block_min_scale,
        quantized_block_scale,
        quantized_block_min,
        int_data,
        shape,
        **kwargs,
    ):
        self.n_super_blocks = n_super_blocks
        self.super_block_scale_scale = super_block_scale_scale
        self.super_block_min_scale = super_block_min_scale
        self.quantized_block_scale = quantized_block_scale
        self.quantized_block_min = quantized_block_min
        self.int_data = int_data

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            self.n_super_blocks,
            fn(self.super_block_scale_scale),
            fn(self.super_block_min_scale),
            fn(self.quantized_block_scale),
            fn(self.quantized_block_min),
            fn(self.int_data),
            self.shape,
            dtype=self.dtype,
        )

    def __tensor_flatten__(self):
        return [
            "super_block_scale_scale",
            "super_block_min_scale",
            "quantized_block_scale",
            "quantized_block_min",
            "int_data",
        ], (
            self.n_super_blocks,
            self.dtype,
            self.shape,
        )

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, attributes, outer_size=None, outer_stride=None
    ):
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
            int_data,
        ) = (
            tensor_data_dict["super_block_scale_scale"],
            tensor_data_dict["super_block_min_scale"],
            tensor_data_dict["quantized_block_scale"],
            tensor_data_dict["quantized_block_min"],
            tensor_data_dict["int_data"],
        )
        n_super_blocks, dtype, shape = attributes
        return cls(
            n_super_blocks,
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
            int_data,
            shape if outer_size is None else outer_size,
            dtype=dtype,
        )

    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        block_size = tuple(
            [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks]
        )
        return dequantize_gguf(
            self.int_data,
            block_size,
            self.dtype,
            self.super_block_scale_scale,
            self.super_block_min_scale,
            self.quantized_block_scale,
            self.quantized_block_min,
        )

    def detach(self):
        """
        Returns a new `GGUFQuantizedTensor`.
        """
        return self.__class__(
            self.n_super_blocks,
            self.super_block_scale_scale.detach(),
            self.super_block_min_scale.detach(),
            self.quantized_block_scale.detach(),
            self.quantized_block_min.detach(),
            self.int_data.detach(),
            self.shape,
            dtype=self.dtype,
        )

    def requires_grad_(self, requires_grad=False):
        """
        Modifies the tensor's `requires_grad` status in-place.
        """
        assert not requires_grad, "Only requires_grad == False is supported"
        return self

    @classmethod
    def from_float(cls, input_float, n_super_blocks, target_dtype):
        """
        Method used to convert a linear weight tensor to an instance of the
        GGUFQuantizedTensor subclass.

        Example usage::

            model.lin_mod.weight = (
                GGUFQuantizedTensor.from_float(model.lin_mod.weight, 8, torch.uint4)
            )
        """
        assert (
            target_dtype == torch.uint4
        ), "only uint4 quantization is supported right now"
        block_size = (1, _QK_K // n_super_blocks)
        (
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        ) = choose_qparams_gguf(input_float, block_size, target_dtype)

        int_data = quantize_gguf(
            input_float,
            block_size,
            target_dtype,
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
        )
        return cls(
            n_super_blocks,
            super_block_scale_scale,
            super_block_min_scale,
            quantized_block_scale,
            quantized_block_min,
            int_data,
            input_float.shape,
            dtype=torch.float16,
        )


@dataclass
class GGUFWeightOnlyConfig(AOBaseConfig):
    dtype: torch.dtype = torch.uint4
    n_super_blocks: int = 8


@register_quantize_module_handler(GGUFWeightOnlyConfig)
def _gguf_weight_only_transform(
    module: torch.nn.Module,
    config: GGUFWeightOnlyConfig,
):
    """
    Applies gguf weight-only quantization to linear layers.

    Args:
        dtype: torch.uint1 to torch.uint8, torch.int32 supported.
        n_super_blocks: the number of blocks in a 256-element super block for gguf,
          e.g. when it is 8 we have 8 blocks of 32 elements in each super block of 256.
    Returns:
        The module with its weight replaced by a GGUFQuantizedTensor (unchanged if the
        weight shape is unsupported).
    """
    weight = module.weight
    if (weight.ndim != 2) or (weight.shape[-1] % 256 != 0):
        return module

    quantized_weight = GGUFQuantizedTensor.from_float(
        weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    return module
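For reference, a small end-to-end usage sketch of the API added in this commit, mirroring test_quantize_api above. The transform skips any weight whose last dimension is not a multiple of 256, so the linear layer below uses in_features=256; the dequantize() call at the end is only for inspection.

import torch

from torchao.prototype.quantization.gguf import GGUFQuantizedTensor, GGUFWeightOnlyConfig
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 64))

# Replace the linear weight with a GGUF-style uint4 weight-only quantized tensor:
# super blocks of 256 elements, split into 8 blocks of 32 (the config defaults).
quantize_(model, GGUFWeightOnlyConfig(dtype=torch.uint4, n_super_blocks=8))

assert type(model[0].weight) == GGUFQuantizedTensor
deq = model[0].weight.dequantize()  # reconstruct a floating-point view of the weight
print(deq.shape)  # torch.Size([64, 256])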
