@@ -123,12 +123,21 @@ def get_weights(fn):
     assert len(arrays) in (1, 2)
 
 
+def torch_roundf(t: torch.Tensor) -> torch.Tensor:
+    """Round halfway cases away from zero like roundf(3). Cf. gguf/quants.py."""
+    a = abs(t)
+    floored = torch.floor(a)
+    b = floored + torch.floor(2 * (a - floored))
+    return torch.sign(t) * b
+
+
 def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
     # equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero)
     assert tensor.shape[1] % QK8_0 == 0
     tensor = tensor.reshape(-1, QK8_0)
     scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
-    tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
+    iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
+    tensor = torch_roundf(tensor * iscale).clamp(min=-128, max=127).char()
     # add scale into each block
     tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
     return tensor
@@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type):
     elif ggml_type == gguf.GGMLQuantizationType.F16:
         return tensor.half()
     elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
-        if tensor.device.type == "meta":
-            return quantize_q8_0(tensor)  # Cannot convert into numpy array.
-        return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
+        return quantize_q8_0(tensor)
     elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
         return quantize_q4_0(tensor)
     elif ggml_type == gguf.GGMLQuantizationType.Q4_1:
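For reference, a minimal round-trip sketch of the Q8_0 path added above (not part of the diff): it assumes torch_roundf and quantize_q8_0 from the first hunk are in scope, and that QK8_0 is 32, the GGUF Q8_0 block size. Each output row is one block: the fp16 scale stored as two int8 bytes, followed by the 32 quantized values.

import torch

QK8_0 = 32  # assumed GGUF Q8_0 block size, matching the converter's constant

t = torch.randn(4, 4 * QK8_0)       # row length must be a multiple of QK8_0
q = quantize_q8_0(t)                 # quantize_q8_0 from the first hunk above
assert q.dtype == torch.int8
assert q.shape == (t.numel() // QK8_0, 2 + QK8_0)   # 2 scale bytes + 32 quants per block

# Dequantize to check the layout: the first two bytes of each block are the fp16 scale.
scale = q[:, :2].contiguous().view(torch.float16).float()
dq = (q[:, 2:].float() * scale).reshape(t.shape)
print(f"max abs quantization error: {(dq - t).abs().max():.4f}")

Note that torch.round() uses round-half-to-even, so torch_roundf is needed to match C's roundf(3): for example, torch_roundf gives 3 for 2.5 where torch.round gives 2.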