Skip to content

Commit 72920b1

Browse files
committed
Don't split MoE weights.
As per ggml-org#7058 (comment). This helps avoid a memory copy (memcpy) when running.
1 parent 7c54f47 commit 72920b1

File tree

1 file changed

+13
-22
lines changed

1 file changed

+13
-22
lines changed

convert_grok.py

Lines changed: 13 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -185,14 +185,14 @@ def maybe_quantize_tensor(tensor, ggml_type):
185185

186186

187187
def get_dtype_and_ggml_type(tensor, ggml_type):
188-
if tensor.ndim == 2:
188+
if tensor.ndim in (2, 3):
189189
if tensor.shape[1] % GGML_QK8_0 == 0:
190190
return np.int8, ggml_type
191191
else:
192192
return np.float16, gguf.GGMLQuantizationType.F16
193193
else:
194194
# 1d weight: convert it to float32
195-
assert tensor.ndim == 1
195+
assert tensor.ndim == 1, tensor
196196
return np.float32, gguf.GGMLQuantizationType.F32
197197

198198

@@ -236,15 +236,15 @@ def dump_state_dict(f, weight_names, model_files, ggml_type, config):
236236
cache.update(state_dict)
237237
tensor = cache.pop(key)
238238
_, tensor_ggml_type = get_dtype_and_ggml_type(tensor, ggml_type)
239-
tensor = maybe_quantize_tensor(tensor, tensor_ggml_type)
239+
array = maybe_quantize_tensor(tensor, tensor_ggml_type).numpy()
240240

241-
array = tensor.numpy()
242241
print(
243-
f"dumping {key}: {tensor_ggml_type.name}/{array.dtype}, {array.shape}, {array.nbytes} bytes"
242+
f"dumping {key}:",
243+
f"{tensor_ggml_type.name}/{array.dtype}, {list(tensor.shape)}, {array.nbytes} bytes",
244244
)
245245
f.write_tensor_data(array)
246246

247-
tensor_info.append((key, tensor.shape, tensor_ggml_type.name))
247+
tensor_info.append((key, list(tensor.shape), tensor_ggml_type.name))
248248

249249
try:
250250
print(tabulate(tensor_info, headers=["name", "shape", "dtype"], tablefmt="psql"))
@@ -282,15 +282,10 @@ def convert_weight(tensor_name, weight, scales, experts, dtype=torch.float32, de
282282
if len(weight.shape) >= 2 and "token_embd" not in tensor_name:
283283
weight = weight.transpose(-1, -2)
284284

285-
if tensor_name.endswith("ffn_gate_inp.weight"):
285+
if tensor_name.endswith("ffn_gate_inp.weight") or tensor_name.endswith("_exps.weight"):
286286
result[tensor_name] = weight[experts] # gather.
287287
elif "experts" not in tensor_name:
288288
result[tensor_name] = weight
289-
else:
290-
# split moe
291-
for i, expert in enumerate(experts):
292-
key = tensor_name.replace("experts", str(i))
293-
result[key] = weight[expert]
294289

295290
return result
296291

@@ -328,14 +323,10 @@ def extract_vocabulary_from_model(vocab):
328323
def get_weight_names(config):
329324
weight_names = ["token_embd.weight"]
330325
for i in range(config.num_hidden_layers):
331-
for j in range(config.num_experts):
332-
weight_names += [
333-
f"blk.{i}.ffn_gate.{j}.weight",
334-
f"blk.{i}.ffn_down.{j}.weight",
335-
f"blk.{i}.ffn_up.{j}.weight",
336-
]
337-
338326
weight_names += [
327+
f"blk.{i}.ffn_gate_exps.weight",
328+
f"blk.{i}.ffn_down_exps.weight",
329+
f"blk.{i}.ffn_up_exps.weight",
339330
f"blk.{i}.attn_k.weight",
340331
f"blk.{i}.attn_output.weight",
341332
f"blk.{i}.attn_q.weight",
@@ -399,9 +390,9 @@ def ffn_size(emb_size, widening_factor):
399390
]
400391
for i in range(config.num_hidden_layers):
401392
tensor_names += [
402-
f"blk.{i}.ffn_gate.experts.weight",
403-
f"blk.{i}.ffn_down.experts.weight",
404-
f"blk.{i}.ffn_up.experts.weight",
393+
f"blk.{i}.ffn_gate_exps.weight",
394+
f"blk.{i}.ffn_down_exps.weight",
395+
f"blk.{i}.ffn_up_exps.weight",
405396
f"blk.{i}.attn_k.weight",
406397
f"blk.{i}.attn_output.weight",
407398
f"blk.{i}.attn_q.weight",

0 commit comments

Comments (0)