Skip to content

Commit e974a68

Browse files
committed
Implement Q8_0 quantization fully in PyTorch.
This is equivalent to gguf.quantize_q8_0 but doesn't round-trip through NumPy.
1 parent 4eaa704 commit e974a68

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

convert_grok.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,21 @@ def get_weights(fn):
123123
assert len(arrays) in (1, 2)
124124

125125

126+
def torch_roundf(t: torch.Tensor) -> torch.Tensor:
    """Element-wise rounding that breaks ties away from zero, like C's roundf(3).

    (Cf. gguf/quants.py.) Works on the magnitudes and reapplies the sign at
    the end, so positive and negative halves behave symmetrically.
    """
    magnitude = torch.abs(t)
    whole = torch.floor(magnitude)
    # The fractional part doubled and floored is 1 exactly when frac >= 0.5,
    # i.e. halfway cases get bumped up (away from zero), otherwise 0.
    rounded = whole + (magnitude - whole).mul(2).floor()
    return torch.sign(t) * rounded
132+
133+
126134
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
    """Quantize *tensor* into GGML Q8_0 blocks entirely in PyTorch.

    Equivalent to ggml_quantize_q8_0 in ggml.c (halfway cases are rounded
    away from zero via torch_roundf). Each output row is one block: the fp16
    scale reinterpreted as int8 bytes, followed by QK8_0 quantized values.
    """
    assert tensor.shape[1] % QK8_0 == 0
    blocks = tensor.reshape(-1, QK8_0)
    # Per-block scale: largest magnitude mapped to 127.
    scale = blocks.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
    # Multiply by the reciprocal instead of dividing; a zero scale
    # (all-zero block) gets iscale 0 so the whole block quantizes to 0.
    iscale = torch.where(scale != 0.0, 1.0 / scale, 0.0)
    quants = torch_roundf(blocks * iscale).clamp(min=-128, max=127).char()
    # Prepend the scale bytes to each block of quantized values.
    return torch.cat((scale.half().view(torch.int8), quants), dim=-1)
@@ -175,9 +184,7 @@ def maybe_quantize_tensor(tensor, ggml_type):
175184
elif ggml_type == gguf.GGMLQuantizationType.F16:
176185
return tensor.half()
177186
elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
178-
if tensor.device.type == "meta":
179-
return quantize_q8_0(tensor) # Cannot convert into numpy array.
180-
return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
187+
return quantize_q8_0(tensor)
181188
elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
182189
return quantize_q4_0(tensor)
183190
elif ggml_type == gguf.GGMLQuantizationType.Q4_1:

0 commit comments

Comments
 (0)