Commit 8132fec

fix

1 parent aa18642
3 files changed: +32 -23 lines changed

test/prototype/test_gguf_quant.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ class TestGGUFQuantization(unittest.TestCase):
     def setUp(self):
         torch.manual_seed(123)
         self.input = torch.randn(2, 256, dtype=torch.float32)
-        self.n_super_blocks = 8
+        self.n_blocks_per_superblock = 8
         self.block_size = (1, 32)
         self.dtype = torch.uint4

@@ -34,7 +34,7 @@ def test_choose_qparams_gguf(self):
     def test_gguf_quantized_tensor_from_float(self):
         gqt = GGUFQuantizedTensor.from_float(
             self.input,
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
             self.dtype,
         )
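For reference, the renamed from_float API as the updated test exercises it; a minimal sketch, assuming the import path matches the file location below:

    import torch
    from torchao.prototype.quantization.gguf.gguf_quantized_tensor import (
        GGUFQuantizedTensor,
    )

    weight = torch.randn(2, 256, dtype=torch.float32)
    # 8 blocks of 32 elements per 256-element superblock, quantized to uint4
    gqt = GGUFQuantizedTensor.from_float(weight, 8, torch.uint4)
    assert gqt.dequantize(torch.float32).shape == weight.shape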

torchao/prototype/quantization/gguf/gguf_quantized_tensor.py

Lines changed: 17 additions & 15 deletions
@@ -39,7 +39,7 @@ class GGUFQuantizedTensor(TorchAOBaseTensor):
     @staticmethod
     def __new__(
         cls,
-        n_super_blocks,
+        n_blocks_per_superblock,
         super_block_scale_scale,
         super_block_min_scale,
         quantized_block_scale,
@@ -55,7 +55,7 @@ def __new__(
 
     def __init__(
         self,
-        n_super_blocks,
+        n_blocks_per_superblock,
         super_block_scale_scale,
         super_block_min_scale,
         quantized_block_scale,
@@ -64,7 +64,7 @@ def __init__(
         shape,
         **kwargs,
     ):
-        self.n_super_blocks = n_super_blocks
+        self.n_blocks_per_superblock = n_blocks_per_superblock
         self.super_block_scale_scale = super_block_scale_scale
         self.super_block_min_scale = super_block_min_scale
         self.quantized_block_scale = quantized_block_scale
@@ -73,7 +73,7 @@ def __init__(
 
     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
             fn(self.super_block_scale_scale),
             fn(self.super_block_min_sclae),
             fn(self.quantized_block_scale),
@@ -91,7 +91,7 @@ def __tensor_flatten__(self):
             "quantized_block_min",
             "int_data",
         ], (
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
             self.dtype,
             self.shape,
         )
@@ -113,9 +113,9 @@ def __tensor_unflatten__(
             tensor_data_dict["quantized_block_min"],
             tensor_data_dict["int_data"],
         )
-        n_super_blocks, dtype, shape = attributes
+        n_blocks_per_superblock, dtype, shape = attributes
         return cls(
-            n_super_blocks,
+            n_blocks_per_superblock,
             super_block_scale_scale,
             super_block_min_scale,
             quantized_block_scale,
@@ -127,7 +127,7 @@ def __tensor_unflatten__(
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         block_size = tuple(
-            [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks]
+            [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock]
         )
         return dequantize_gguf(
             self.int_data,
@@ -144,7 +144,7 @@ def detach(self):
         Returns a new `CodebookQuantizedTensor`.
         """
         return self.__class__(
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
             self.super_block_scale_scale.detach(),
             self.super_block_min_scale.detach(),
             self.quantized_block_scale.detach(),
@@ -162,7 +162,7 @@ def requires_grad_(self, requires_grad=False):
         return self
 
     @classmethod
-    def from_float(cls, input_float, n_super_blocks, target_dtype):
+    def from_float(cls, input_float, n_blocks_per_superblock, target_dtype):
         """
         Method used to convert a linear weight tensor to an instance of the
         GGMLInt4LinearWeight subclass.
@@ -176,7 +176,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
         assert (
             target_dtype == torch.uint4
         ), "only uint4 quantization is supported right now"
-        block_size = (1, _QK_K // n_super_blocks)
+        block_size = (1, _QK_K // n_blocks_per_superblock)
         (
             super_block_scale_scale,
             super_block_min_scale,
@@ -194,7 +194,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
             quantized_block_min,
         )
         return cls(
-            n_super_blocks,
+            n_blocks_per_superblock,
             super_block_scale_scale,
             super_block_min_scale,
             quantized_block_scale,
@@ -208,7 +208,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
 @dataclass
 class GGUFWeightOnlyConfig(AOBaseConfig):
     dtype: torch.dtype = torch.uint4
-    n_super_blocks: int = 8
+    n_blocks_per_superblock: int = 8
 
 
 @register_quantize_module_handler(GGUFWeightOnlyConfig)
@@ -221,7 +221,7 @@ def _gguf_weight_only_transform(
 
     Args:
         dtype: torch.uint1 to torch.uint8, torch.int32 supported.
-        n_super_blocks: the number of super blocks in a 256 element block for gguf, e.g. when it is 8
+        n_blocks_per_superblock: the number of blocks in a 256 element superblock for gguf, e.g. when it is 8
         it means we have blocks of 32 and 8 blocks in a superblock of 256 elements.
     Returns:
         Callable for quantization transformation.
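To make the docstring concrete, a quick arithmetic check (a sketch; _QK_K = 256 is the superblock size this file divides by in from_float and dequantize):

    _QK_K = 256  # elements per superblock
    n_blocks_per_superblock = 8
    # block_size = (1, _QK_K // n_blocks_per_superblock) -> 8 blocks of 32 elements
    assert _QK_K // n_blocks_per_superblock == 32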
@@ -231,7 +231,9 @@ def _gguf_weight_only_transform(
         return module
 
     quantized_weight = GGUFQuantizedTensor.from_float(
-        weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype
+        weight,
+        n_blocks_per_superblock=config.n_blocks_per_superblock,
+        target_dtype=config.dtype,
     )
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
     return module
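End to end, the renamed config field would be used as below; a sketch assuming GGUFWeightOnlyConfig is importable from the gguf prototype package and is dispatched through torchao's quantize_ by the register_quantize_module_handler hook above:

    import torch
    from torchao.quantization import quantize_
    # import path assumed from the file tree above
    from torchao.prototype.quantization.gguf import GGUFWeightOnlyConfig

    model = torch.nn.Sequential(torch.nn.Linear(256, 256))
    quantize_(model, GGUFWeightOnlyConfig(dtype=torch.uint4, n_blocks_per_superblock=8))
    # each Linear weight is replaced by a GGUFQuantizedTensor parameter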

torchao/quantization/quant_primitives.py

Lines changed: 13 additions & 6 deletions
@@ -1098,19 +1098,26 @@ def choose_qparams_gguf(
     )
 
     # 2. get super_block_scale_scale and super_block_min_scale
-    quant_max = 2**6 - 1
-    quant_min = 0
-    super_block_scale_scale = block_scale_absmax / float(quant_max - quant_min)
-    super_block_min_scale = block_min_absmax / float(quant_max - quant_min)
+    # TODO: make this configurable
+    # we also quantize the quantization parameters (scale and min) for each block to 6 bit
+    # for Q4_K
+    qparam_quant_max = 2**6 - 1
+    qparam_quant_min = 0
+    super_block_scale_scale = block_scale_absmax / float(
+        qparam_quant_max - qparam_quant_min
+    )
+    super_block_min_scale = block_min_absmax / float(
+        qparam_quant_max - qparam_quant_min
+    )
     super_block_scale_scale_view = super_block_scale_scale.view(shape_after_reduction)
     super_block_min_scale_view = super_block_min_scale.view(shape_after_reduction)
 
     # 3. quantize block scale and min are stored in 6 bits using super_block_scale_scale and super_block_min_scale
     quantized_block_scale = torch.clamp(
-        block_scale / super_block_scale_scale_view, quant_min, quant_max
+        block_scale / super_block_scale_scale_view, qparam_quant_min, qparam_quant_max
     )
     quantized_block_min = torch.clamp(
-        block_min / super_block_min_scale_view, quant_min, quant_max
+        block_min / super_block_min_scale_view, qparam_quant_min, qparam_quant_max
     )
     return (
         super_block_scale_scale,
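The renamed qparam_quant_min/qparam_quant_max make it explicit that this 6-bit range quantizes the per-block quantization parameters, not the weights themselves. A standalone sketch of steps 2 and 3 above, with hypothetical block_scale data:

    import torch

    qparam_quant_max = 2**6 - 1  # 63: 6-bit range for Q4_K scale/min params
    qparam_quant_min = 0
    block_scale = torch.rand(8)  # hypothetical per-block scales of one superblock
    super_block_scale_scale = block_scale.abs().amax() / float(
        qparam_quant_max - qparam_quant_min
    )
    quantized_block_scale = torch.clamp(
        block_scale / super_block_scale_scale, qparam_quant_min, qparam_quant_max
    )
    # block scales are recovered as quantized_block_scale * super_block_scale_scale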
