Skip to content

Commit cd38c87

Browse files
committed
Use Q8_0 quantization from gguf module.
This makes the quantized tensors exactly match those in https://huggingface.co/Arki05/Grok-1-GGUF/tree/main/Q8_0
1 parent 60fa5a9 commit cd38c87

File tree

1 file changed: 6 additions, 5 deletions

convert_grok.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def get_weights(fn):
124124

125125

126126
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
127-
# equivalent to ggml_quantize_q8_0 in ggml.c
127+
# equivalent to ggml_quantize_q8_0 in ggml.c (modulo rounding away from zero)
128128
assert tensor.shape[1] % GGML_QK8_0 == 0
129129
tensor = tensor.reshape(-1, GGML_QK8_0)
130130
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
@@ -135,7 +135,7 @@ def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
135135

136136

137137
def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
138-
# equivalent to ggml_quantize_q4_0 in ggml.c
138+
# equivalent to ggml_quantize_q4_0 in ggml.c (modulo rounding away from zero)
139139
assert tensor.shape[1] % GGML_QK4_0 == 0
140140
tensor = tensor.reshape(-1, GGML_QK4_0)
141141
abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
@@ -150,7 +150,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
150150

151151

152152
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
153-
# equivalent to ggml_quantize_q4_1 in ggml.c
153+
# equivalent to ggml_quantize_q4_1 in ggml.c (modulo rounding away from zero)
154154
assert tensor.shape[1] % GGML_QK4_1 == 0
155155
tensor = tensor.reshape(-1, GGML_QK4_1)
156156
abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
@@ -170,13 +170,14 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
170170

171171
def maybe_quantize_tensor(tensor, ggml_type):
172172
assert tensor.dtype == torch.float32
173-
174173
if ggml_type == gguf.GGMLQuantizationType.F32:
175174
return tensor.float()
176175
elif ggml_type == gguf.GGMLQuantizationType.F16:
177176
return tensor.half()
178177
elif ggml_type == gguf.GGMLQuantizationType.Q8_0:
179-
return quantize_q8_0(tensor)
178+
if tensor.device.type == "meta":
179+
return quantize_q8_0(tensor) # Cannot convert into numpy array.
180+
return torch.from_numpy(gguf.quantize_q8_0(tensor.numpy()))
180181
elif ggml_type == gguf.GGMLQuantizationType.Q4_0:
181182
return quantize_q4_0(tensor)
182183
elif ggml_type == gguf.GGMLQuantizationType.Q4_1:

0 commit comments

Comments (0)