diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py
index 93dc754a571c..dbcd864516ec 100644
--- a/vllm/model_executor/model_loader/gguf_loader.py
+++ b/vllm/model_executor/model_loader/gguf_loader.py
@@ -72,6 +72,10 @@ def _get_gguf_weights_map(self, model_config: ModelConfig):
         # hack: ggufs have a different name than transformers
         if model_type == "cohere":
             model_type = "command-r"
+        if model_type == "gemma3_text":
+            # Gemma3 models use "gemma3_text" in HuggingFace but
+            # "gemma3" in GGUF architecture naming
+            model_type = "gemma3"
         if model_type in ("deepseek_v3", "deepseek_v2"):
             model_type = "deepseek2"
         # GGUF layer map assumes that we will have a merged expert weights
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 9fa8e1c78b12..7e6fc401757a 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -372,6 +372,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
             prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
@@ -442,6 +443,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            # Revert +1 during llama.cpp conversion
+            # see: https://github.com/ggml-org/llama.cpp/blob/be7c3034108473beda214fd1d7c98fd6a7a3bdf5/convert_hf_to_gguf.py#L3397-L3400
+            if (
+                self.quant_config
+                and self.quant_config.get_name() == "gguf"
+                and name.endswith("norm.weight")
+            ):
+                loaded_weight -= 1
+
             if self.quant_config is not None and (
                 scale_name := self.quant_config.get_cache_scale(name)
             ):
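
Context for the first hunk: the `gguf` Python package keys its tensor-name maps by GGUF architecture names, while HF's text-only Gemma3 config reports `model_type == "gemma3_text"`. Below is a minimal sketch of the lookup that `_get_gguf_weights_map` builds on, assuming the `gguf` package's `MODEL_ARCH_NAMES` table and `get_tensor_name_map` helper; without the alias, the `next(...)` lookup finds no match and raises `StopIteration`. The block count is a hypothetical value for illustration.

```python
import gguf

# HF reports text-only Gemma3 checkpoints as "gemma3_text"; GGUF registers
# the architecture as "gemma3", so the loader aliases the former to the latter.
model_type = "gemma3_text"
if model_type == "gemma3_text":
    model_type = "gemma3"

# Resolve the GGUF architecture enum from its registered name, then build
# the HF-name -> GGUF-name tensor map for a given number of layers.
arch = next(k for k, v in gguf.MODEL_ARCH_NAMES.items() if v == model_type)
name_map = gguf.get_tensor_name_map(arch, 26)  # 26 = hypothetical block count
```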
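
And for the third hunk: per the linked `convert_hf_to_gguf.py` lines, llama.cpp bakes a `+1` into Gemma norm weights at conversion time, because its RMSNorm kernel multiplies by the stored weight directly, whereas HF/vLLM's Gemma RMSNorm stores a zero-centered weight and applies `1 + weight` at runtime. A minimal, self-contained sketch of why subtracting 1 at load time restores equivalence; both norm helpers here are illustrative reimplementations, not vLLM's or llama.cpp's actual kernels.

```python
import torch

def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma-style RMSNorm: the stored weight is zero-centered and the
    # kernel applies (1 + weight), as in HF/vLLM's GemmaRMSNorm.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * (1.0 + weight)

def llamacpp_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # llama.cpp-style RMSNorm: multiplies by the stored weight directly,
    # which is why the GGUF file stores (weight + 1) for Gemma.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * weight

hf_weight = torch.randn(4) * 0.1   # zero-centered weight, as HF/vLLM store it
gguf_weight = hf_weight + 1.0      # what convert_hf_to_gguf.py writes
restored = gguf_weight - 1.0       # what load_weights now does for GGUF

x = torch.randn(2, 4)
# The llama.cpp kernel with the baked-in weight matches the Gemma kernel
# with the zero-centered weight...
assert torch.allclose(llamacpp_rmsnorm(x, gguf_weight), gemma_rmsnorm(x, hf_weight))
# ...and subtracting 1 at load time recovers the weight vLLM expects.
assert torch.allclose(gemma_rmsnorm(x, restored), gemma_rmsnorm(x, hf_weight))
```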