Skip to content

Commit 60fa5a9

Browse files
committed
Fix layer order.
1 parent d57efae commit 60fa5a9

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

convert_grok.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def get_weights(fn):
126126
def quantize_q8_0(tensor: torch.Tensor) -> torch.CharTensor:
127127
# equivalent to ggml_quantize_q8_0 in ggml.c
128128
assert tensor.shape[1] % GGML_QK8_0 == 0
129-
tensor = tensor.view(-1, GGML_QK8_0)
129+
tensor = tensor.reshape(-1, GGML_QK8_0)
130130
scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
131131
tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
132132
# add scale into each block
@@ -152,7 +152,7 @@ def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
152152
def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
153153
# equivalent to ggml_quantize_q4_1 in ggml.c
154154
assert tensor.shape[1] % GGML_QK4_1 == 0
155-
tensor = tensor.view(-1, GGML_QK4_1)
155+
tensor = tensor.reshape(-1, GGML_QK4_1)
156156
abs_max_indices = tensor.max(dim=-1, keepdim=True).indices
157157
max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
158158
abs_min_indices = tensor.min(dim=-1, keepdim=True).indices
@@ -185,15 +185,13 @@ def maybe_quantize_tensor(tensor, ggml_type):
185185
raise NotImplementedError(f"Cannot quantize tensor of dtype {tensor.dtype} ({ggml_type})")
186186

187187

188-
def get_dtype_and_ggml_type(tensor, ggml_type):
189-
if tensor.ndim in (2, 3):
188+
def get_dtype_and_ggml_type(name, tensor, ggml_type):
189+
if tensor.ndim in (2, 3) and "ffn_gate_inp" not in name:
190190
if tensor.shape[1] % GGML_QK8_0 == 0:
191191
return np.int8, ggml_type
192192
else:
193193
return np.float16, gguf.GGMLQuantizationType.F16
194194
else:
195-
# 1d weight: convert it to float32
196-
assert tensor.ndim == 1, tensor
197195
return np.float32, gguf.GGMLQuantizationType.F32
198196

199197

@@ -205,7 +203,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
205203
for idx, name in enumerate(weight_names):
206204
weight, scales = get_weights(f"{input_dir}/tensor{idx:05}_000")
207205
meta_tensor = convert_weight(name, weight, scales, config, device="meta")
208-
dtype, tensor_ggml_type = get_dtype_and_ggml_type(meta_tensor, ggml_type)
206+
dtype, tensor_ggml_type = get_dtype_and_ggml_type(name, meta_tensor, ggml_type)
209207
quantized_meta_tensor = maybe_quantize_tensor(meta_tensor, tensor_ggml_type)
210208
f.add_tensor_info(
211209
f"{name}.weight",
@@ -227,7 +225,7 @@ def dump_state_dict(f, ggml_type, input_dir, config):
227225
for name in weight_names:
228226
weight, scales = weights.pop(name)
229227
tensor = convert_weight(name, weight, scales, config)
230-
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
228+
_, tensor_ggml_type = get_dtype_and_ggml_type(name, tensor, ggml_type)
231229
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
232230

233231
logging.info(
@@ -317,7 +315,10 @@ def get_weight_names(num_hidden_layers=64):
317315
gguf.MODEL_TENSOR.FFN_GATE_INP,
318316
)
319317

320-
for bid in range(num_hidden_layers):
318+
layers = [str(bid) for bid in range(64)]
319+
layers.sort() # Lexicographic sort: 0 < 1 < 10 < 11 ... < 2 < 20 < ...
320+
321+
for bid in layers[:num_hidden_layers]:
321322
for key in layer:
322323
weight_names.append(gguf.TENSOR_NAMES[key].format(bid=bid))
323324

@@ -333,7 +334,6 @@ def ffn_size(emb_size, widening_factor):
333334
return _ffn_size
334335

335336
config = {
336-
"vocab_size": 128 * 1024,
337337
"hidden_act": "gelu",
338338
"pad_token_id": 0,
339339
"eos_token_id": 2,
@@ -366,8 +366,7 @@ def ffn_size(emb_size, widening_factor):
366366

367367
f = gguf.GGUFWriter(args.save_path, "grok", endianess=gguf.GGUFEndian.LITTLE)
368368

369-
f.add_name("grok")
370-
f.add_vocab_size(config.vocab_size)
369+
f.add_name("grok-1")
371370
f.add_context_length(config.max_position_embeddings)
372371
f.add_embedding_length(config.hidden_size)
373372
f.add_block_count(config.num_hidden_layers)
@@ -389,6 +388,8 @@ def ffn_size(emb_size, widening_factor):
389388
f.add_token_scores(scores)
390389
f.add_token_types(toktypes)
391390

391+
f.add_quantization_version(ggml_type)
392+
392393
dump_state_dict(f, ggml_type, args.input_dir, config)
393394
f.close()
394395

0 commit comments

Comments
 (0)