
Commit 5795b94

convert-hf : match model part name prefix and suffix (#7687)
In #7075, to fix the conversion of (some) models using model-00001-of-00001.safetensors instead of model.safetensors for a single model part, we simply used the same logic as the part count to get the part names.

But this doesn't always work correctly, for example when unusual additional model files like consolidated.safetensors in https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3 are present.

Matching both the prefix and the suffix of the model part names should fix this problem without breaking any previously-supported upstream models. According to a report by @teleprint-me, some problem still persists, but this should do in the meantime.
1 parent ed9f252 commit 5795b94

1 file changed: +4 -4 lines changed


convert-hf-to-gguf.py

Lines changed: 4 additions & 4 deletions
@@ -73,10 +73,10 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.use_temp_file = use_temp_file
         self.lazy = not eager
-        self.part_names = Model.get_model_part_names(self.dir_model, ".safetensors")
+        self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
         self.is_safetensors = len(self.part_names) > 0
         if not self.is_safetensors:
-            self.part_names = Model.get_model_part_names(self.dir_model, ".bin")
+            self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
         self.hparams = Model.load_hparams(self.dir_model)
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
@@ -335,10 +335,10 @@ def write_vocab(self):
         self.gguf_writer.close()
 
     @staticmethod
-    def get_model_part_names(dir_model: Path, suffix: str) -> list[str]:
+    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
         part_names: list[str] = []
         for filename in os.listdir(dir_model):
-            if filename.endswith(suffix):
+            if filename.startswith(prefix) and filename.endswith(suffix):
                 part_names.append(filename)
 
         part_names.sort()
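
The effect of the extra prefix check can be illustrated with a small standalone sketch. It is not the code from convert-hf-to-gguf.py: `match_part_names` is a hypothetical helper that filters a plain list of filenames instead of calling `os.listdir`, and the `listing` below is an assumed directory layout modeled on the file names mentioned in the commit message.

```python
from __future__ import annotations


def match_part_names(filenames: list[str], prefix: str, suffix: str) -> list[str]:
    """Return the sorted filenames that start with `prefix` and end with `suffix`.

    Hypothetical stand-in for the patched get_model_part_names; it takes a
    list of names rather than a directory so it can run without any files.
    """
    return sorted(
        name for name in filenames
        if name.startswith(prefix) and name.endswith(suffix)
    )


# Assumed directory listing (illustrative only): two sharded model parts plus
# an extra consolidated checkpoint and some unrelated files.
listing = [
    "consolidated.safetensors",
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
    "model.safetensors.index.json",
    "tokenizer.model",
]

# Old behavior: suffix only. An empty prefix matches every name, so the
# consolidated checkpoint is picked up as if it were a model part.
suffix_only = match_part_names(listing, "", ".safetensors")

# New behavior: prefix and suffix. Only the files named model*.safetensors remain.
prefix_and_suffix = match_part_names(listing, "model", ".safetensors")

print(suffix_only)
print(prefix_and_suffix)
```

With the suffix-only match, consolidated.safetensors would be counted as a third part; requiring the "model" prefix leaves only the two sharded files. The index file is excluded in both cases because it ends in .json.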
