@@ -39,7 +39,7 @@ class GGUFQuantizedTensor(TorchAOBaseTensor):
     @staticmethod
     def __new__(
         cls,
-        n_super_blocks,
+        n_blocks_per_superblock,
         super_block_scale_scale,
         super_block_min_scale,
         quantized_block_scale,
@@ -55,7 +55,7 @@ def __new__(

     def __init__(
         self,
-        n_super_blocks,
+        n_blocks_per_superblock,
         super_block_scale_scale,
         super_block_min_scale,
         quantized_block_scale,
@@ -64,7 +64,7 @@ def __init__(
         shape,
         **kwargs,
     ):
-        self.n_super_blocks = n_super_blocks
+        self.n_blocks_per_superblock = n_blocks_per_superblock
         self.super_block_scale_scale = super_block_scale_scale
         self.super_block_min_scale = super_block_min_scale
         self.quantized_block_scale = quantized_block_scale
@@ -73,7 +73,7 @@ def __init__(

     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
            fn(self.super_block_scale_scale),
            fn(self.super_block_min_scale),
            fn(self.quantized_block_scale),
@@ -91,7 +91,7 @@ def __tensor_flatten__(self):
             "quantized_block_min",
             "int_data",
         ], (
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
             self.dtype,
             self.shape,
         )
@@ -113,9 +113,9 @@ def __tensor_unflatten__(
             tensor_data_dict["quantized_block_min"],
             tensor_data_dict["int_data"],
         )
-        n_super_blocks, dtype, shape = attributes
+        n_blocks_per_superblock, dtype, shape = attributes
         return cls(
-            n_super_blocks,
+            n_blocks_per_superblock,
             super_block_scale_scale,
             super_block_min_scale,
             quantized_block_scale,
@@ -127,7 +127,7 @@ def __tensor_unflatten__(

     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         block_size = tuple(
-            [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_super_blocks]
+            [1] * (self.int_data.ndim - 1) + [_QK_K // self.n_blocks_per_superblock]
         )
         return dequantize_gguf(
             self.int_data,
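For reference, the block-size arithmetic in `dequantize` works out as below. This is a sketch, not code from the PR: it assumes `_QK_K = 256` (the GGUF k-quant superblock size used by this module) and a 2-D weight, and all variable values are illustrative.

```python
# Hypothetical re-derivation of the block_size computed in dequantize().
_QK_K = 256                   # GGUF k-quant superblock size
n_blocks_per_superblock = 8   # the GGUFWeightOnlyConfig default

int_data_ndim = 2             # e.g. an (out_features, in_features) weight
block_size = tuple([1] * (int_data_ndim - 1) + [_QK_K // n_blocks_per_superblock])
assert block_size == (1, 32)  # per-row quantization groups of 32 along the last dim
```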
@@ -144,7 +144,7 @@ def detach(self):
         Returns a new `GGUFQuantizedTensor`.
         """
         return self.__class__(
-            self.n_super_blocks,
+            self.n_blocks_per_superblock,
             self.super_block_scale_scale.detach(),
             self.super_block_min_scale.detach(),
             self.quantized_block_scale.detach(),
@@ -162,7 +162,7 @@ def requires_grad_(self, requires_grad=False):
         return self

     @classmethod
-    def from_float(cls, input_float, n_super_blocks, target_dtype):
+    def from_float(cls, input_float, n_blocks_per_superblock, target_dtype):
         """
         Converts a linear weight tensor to an instance of the
         GGUFQuantizedTensor subclass.
@@ -176,7 +176,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
         assert (
             target_dtype == torch.uint4
         ), "only uint4 quantization is supported right now"
-        block_size = (1, _QK_K // n_super_blocks)
+        block_size = (1, _QK_K // n_blocks_per_superblock)
         (
             super_block_scale_scale,
             super_block_min_scale,
@@ -194,7 +194,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
             quantized_block_min,
         )
         return cls(
-            n_super_blocks,
+            n_blocks_per_superblock,
             super_block_scale_scale,
             super_block_min_scale,
             quantized_block_scale,
@@ -208,7 +208,7 @@ def from_float(cls, input_float, n_super_blocks, target_dtype):
 @dataclass
 class GGUFWeightOnlyConfig(AOBaseConfig):
     dtype: torch.dtype = torch.uint4
-    n_super_blocks: int = 8
+    n_blocks_per_superblock: int = 8


 @register_quantize_module_handler(GGUFWeightOnlyConfig)
@register_quantize_module_handler (GGUFWeightOnlyConfig )
@@ -221,7 +221,7 @@ def _gguf_weight_only_transform(
221
221
222
222
Args:
223
223
dtype: torch.uint1 to torch.uint8, torch.int32 supported.
224
- n_super_blocks : the number of super blocks in a 256 element block for gguf, e.g. when it is 8
224
+ n_blocks_per_superblock : the number of super blocks in a 256 element block for gguf, e.g. when it is 8
225
225
it means we have blocks of 32 and 8 blocks in a superblock of 256 elements.
226
226
Returns:
227
227
Callable for quantization transformation.
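To make the renamed parameter concrete, a small sketch of how it divides a superblock (8 is the config default; the other values are assumptions for illustration, not settings the code necessarily supports):

```python
_QK_K = 256  # elements per GGUF superblock

for n_blocks_per_superblock in (4, 8, 16):
    elements_per_block = _QK_K // n_blocks_per_superblock
    print(n_blocks_per_superblock, "blocks/superblock ->", elements_per_block, "elements/block")
# 8 blocks/superblock -> 32 elements/block, matching the docstring above
```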
@@ -231,7 +231,9 @@ def _gguf_weight_only_transform(
         return module

     quantized_weight = GGUFQuantizedTensor.from_float(
-        weight, n_super_blocks=config.n_super_blocks, target_dtype=config.dtype
+        weight,
+        n_blocks_per_superblock=config.n_blocks_per_superblock,
+        target_dtype=config.dtype,
     )
     module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
     return module
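After this change, caller code picks up the renamed field roughly as follows. This is a sketch rather than the PR's own test: it assumes the config is wired into torchao's `quantize_` entry point via the `register_quantize_module_handler` decorator above, and the import for `GGUFWeightOnlyConfig` is omitted because its final path depends on where this module lands.

```python
import torch
from torchao.quantization import quantize_
# GGUFWeightOnlyConfig is the dataclass defined in this diff; its exact
# import path is not shown here (assumption).

model = torch.nn.Sequential(torch.nn.Linear(256, 128))
config = GGUFWeightOnlyConfig(dtype=torch.uint4, n_blocks_per_superblock=8)
quantize_(model, config)  # each Linear weight becomes a GGUFQuantizedTensor
```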