From e8b00edab2ab170fc15ec2de828d36267cb3ff8b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 26 Apr 2025 11:23:08 +0200 Subject: [PATCH 1/6] convert : improve model arch handling --- convert_hf_to_gguf.py | 61 ++++++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index cf35fb86ecfec..ec298cc72d68d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -66,8 +66,6 @@ class ModelBase: part_names: list[str] is_safetensors: bool hparams: dict[str, Any] - block_count: int - tensor_map: gguf.TensorNameMap tensor_names: set[str] | None gguf_writer: gguf.GGUFWriter model_name: str | None @@ -78,6 +76,10 @@ class ModelBase: # subclasses should define this! model_arch: gguf.MODEL_ARCH + # subclasses should initialize this! + block_count: int + tensor_map: gguf.TensorNameMap + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, @@ -113,8 +115,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: if not self.is_safetensors: self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) - self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) self.tensor_names = None self.metadata_override = metadata_override self.model_name = model_name @@ -418,14 +418,7 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str] @staticmethod def load_hparams(dir_model: Path): with open(dir_model / "config.json", "r", encoding="utf-8") as f: - hparams = json.load(f) - architectures = hparams.get("architectures") - if "text_config" in hparams: - hparams = {**hparams, **hparams["text_config"]} - if architectures is not None: - # preserve "architectures" from root level config - hparams["architectures"] = architectures - return hparams + return json.load(f) @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -454,6 +447,16 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type class TextModel(ModelBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if "text_config" in self.hparams: + # move the text_config to the root level + self.hparams = {**self.hparams, **self.hparams["text_config"]} + + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + @classmethod def __init_subclass__(cls): # can't use an abstract property, because overriding it without type errors @@ -1078,8 +1081,12 @@ def __init__(self, *args, **kwargs): raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION") # small hack to correct the number of layers - self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128) - self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"]) + self.block_count = 512 # vision models are small, this "ought to be enough for anybody" + self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count) + + # get n_embd of the text model + text_config = {**self.hparams, 
**self.hparams["text_config"]} + self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) assert self.n_embd_text > 0, "n_embd not found in hparams" if "vision_config" not in self.hparams: @@ -1726,8 +1733,7 @@ def prepare_tensors(self): "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", - "Idefics3ForConditionalGeneration", - "SmolVLMForConditionalGeneration", + "VLlama3ForCausalLM", "LlavaForConditionalGeneration") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -1735,11 +1741,12 @@ class LlamaModel(TextModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + arch = get_model_architecture(self.dir_model, ModelType.TEXT, self.hparams) # fix for SmolVLM2, missing `num_attention_heads` in config.json - if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration": + if arch == "VLlama3ForCausalLM": self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) # fix for Pixtral, missing `num_attention_heads` in config.json - if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \ + if arch == "LlavaForConditionalGeneration" \ and self.hparams.get("model_type") == "mistral": self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) @@ -5805,6 +5812,19 @@ def split_str_to_n_bytes(split_str: str) -> int: return n +def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str: + hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams + text_config = hparams.get("text_config", {}) + vision_config = hparams.get("vision_config", {}) + arch = hparams["architectures"][0] + # if "architectures" is found in the sub-config, use that instead + if model_type == ModelType.TEXT and text_config.get("architectures") is not None: + arch = text_config["architectures"][0] + elif model_type == ModelType.VISION and vision_config.get("architectures") is not None: + arch = vision_config["architectures"][0] + return arch + + def main() -> None: args = parse_args() @@ -5857,16 +5877,15 @@ def main() -> None: logger.info(f"Loading model: {dir_model.name}") - hparams = ModelBase.load_hparams(dir_model) - if args.mmproj: if "mmproj" not in fname_out.name: fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") with torch.inference_mode(): output_type = ftype_map[args.outtype] - model_architecture = hparams["architectures"][0] model_type = ModelType.VISION if args.mmproj else ModelType.TEXT + model_architecture = get_model_architecture(dir_model, model_type) + logger.info(f"Model architecture: {model_architecture}") try: model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type) except NotImplementedError: From 4840d2fe916f7220bd2f825fa7c988258696dba7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 26 Apr 2025 16:28:09 +0200 Subject: [PATCH 2/6] use AutoConfig --- convert_hf_to_gguf.py | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ec298cc72d68d..ccf65a889917f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -16,6 +16,7 @@ from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast from itertools import chain +from transformers import AutoConfig import math import numpy as np @@ -417,8 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: 
str) -> list[str] @staticmethod def load_hparams(dir_model: Path): - with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + try: + return AutoConfig.from_pretrained(dir_model, trust_remote_code=True).to_dict() + except Exception as e: + logger.warning(f"Failed to load model config from {dir_model}: {e}") + logger.warning("Trying to load config.json instead") + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + return json.load(f) @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -1080,10 +1086,6 @@ def __init__(self, *args, **kwargs): if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION: raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION") - # small hack to correct the number of layers - self.block_count = 512 # vision models are small, this "ought to be enough for anybody" - self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count) - # get n_embd of the text model text_config = {**self.hparams, **self.hparams["text_config"]} self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) @@ -1095,6 +1097,9 @@ def __init__(self, *args, **kwargs): self.global_config = self.hparams self.hparams = self.hparams["vision_config"] + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count) + # load preprocessor config with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: self.preprocessor_config = json.load(f) @@ -1739,17 +1744,6 @@ class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA undo_permute = True - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - arch = get_model_architecture(self.dir_model, ModelType.TEXT, self.hparams) - # fix for SmolVLM2, missing `num_attention_heads` in config.json - if arch == "VLlama3ForCausalLM": - self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) - # fix for Pixtral, missing `num_attention_heads` in config.json - if arch == "LlavaForConditionalGeneration" \ - and self.hparams.get("model_type") == "mistral": - self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) - def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1912,11 +1906,7 @@ class LlavaVisionModel(VisionModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.hparams["model_type"] == "pixtral": - # fix missing config.json values - self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16) - self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24) - self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096) - self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024) + # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) self.img_break_tok_id = 12 # see tokenizer_config.json else: @@ -1927,7 +1917,6 @@ def set_gguf_parameters(self): hparams = self.hparams if hparams["model_type"] == "pixtral": self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL) - # default values below are taken from HF tranformers code self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) self.gguf_writer.add_vision_use_silu(True) @@ -1958,13 +1947,12 @@ 
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class SmolVLMModel(VisionModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # fix for SmolVLM2, missing some keys in config.json - # default values are taken from transformers code if self.hparams["model_type"] == "smolvlm_vision": + # fix for SmolVLM2, missing some keys in config.json + # default values are taken from transformers code self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152) self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16) self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072) - self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12) def set_gguf_parameters(self): super().set_gguf_parameters() From d11dccb6ac31133a385fbe758be608b4f097e728 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 26 Apr 2025 21:09:47 +0200 Subject: [PATCH 3/6] rm trust_remote_code --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index ccf65a889917f..1451dcbedcd6a 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -419,7 +419,7 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str] @staticmethod def load_hparams(dir_model: Path): try: - return AutoConfig.from_pretrained(dir_model, trust_remote_code=True).to_dict() + return AutoConfig.from_pretrained(dir_model).to_dict() except Exception as e: logger.warning(f"Failed to load model config from {dir_model}: {e}") logger.warning("Trying to load config.json instead") From e5c5fd764b07cf5c27b27d326ff1d5a9b25cde11 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 30 Apr 2025 11:50:05 +0200 Subject: [PATCH 4/6] Update convert_hf_to_gguf.py --- convert_hf_to_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3267e45e0dfff..fc5282d392b63 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -460,7 +460,7 @@ def __init__(self, *args, **kwargs): # move the text_config to the root level self.hparams = {**self.hparams, **self.hparams["text_config"]} - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @classmethod From 1a0485d52e84c1cc7f73fad140599307e2f1b286 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 30 Apr 2025 11:53:37 +0200 Subject: [PATCH 5/6] fix self.block_count for vision --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index fc5282d392b63..514cea246f4f5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -460,7 +460,7 @@ def __init__(self, *args, **kwargs): # move the text_config to the root level self.hparams = {**self.hparams, **self.hparams["text_config"]} - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @classmethod @@ -1097,7 +1097,7 @@ def __init__(self, *args, **kwargs): self.global_config = self.hparams self.hparams = 
self.hparams["vision_config"] - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"]) self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count) # load preprocessor config @@ -1117,7 +1117,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"])) self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"])) self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"])) - self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"])) + self.gguf_writer.add_vision_block_count(self.block_count) self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"])) # preprocessor config From a21e755324f7627e0c1b79b71b3331642f798ad6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 30 Apr 2025 11:58:02 +0200 Subject: [PATCH 6/6] fix NomicBertModel --- convert_hf_to_gguf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 514cea246f4f5..249f72b1504bc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3507,6 +3507,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("NomicBertModel") class NomicBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): hparams = kwargs.pop("hparams", None) if hparams is None: