Commit 074e42a

convert : converting mmproj for Qwen2/2.5VL from convert_hf_to_gguf (#13209)
* wip
* qwen2.5vl ok
* vision: fix models missing "text_config"
* add test
* fix test repo name
* fix 32B model
* Revert "fix 32B model" (this reverts commit 651752f)
* clarify about 32B
* rm qwen surgery script
* update llava/readme
* move V_ENC_EMBD_PATCH handling to Qwen2VLVisionModel
Parent: c642bc0

File tree: 7 files changed, +132 -233 lines


convert_hf_to_gguf.py

Lines changed: 78 additions & 0 deletions
```diff
@@ -1089,6 +1089,8 @@ def __init__(self, *args, **kwargs):
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
 
         # get n_embd of the text model
+        if "text_config" not in self.hparams:
+            self.hparams["text_config"] = {}
         text_config = {**self.hparams, **self.hparams["text_config"]}
         self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"
```
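The two added lines make the converter tolerate vision configs that lack a nested `text_config` object (the commit message's "fix models missing \"text_config\""). A minimal sketch of the fallback, on a hypothetical `hparams` dict rather than a real config.json:

```python
# Sketch of the "text_config" fallback; values are illustrative only.
hparams = {"model_type": "qwen2_vl", "hidden_size": 1536}  # no "text_config" key

if "text_config" not in hparams:
    hparams["text_config"] = {}  # empty dict keeps the merge below well-defined

# merged view: nested "text_config" entries override top-level ones
text_config = {**hparams, **hparams["text_config"]}
n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
assert n_embd_text == 1536  # resolved from the top level despite the missing key
```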
```diff
@@ -2583,6 +2585,82 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLVisionModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["image_size"] = self.hparams.get("image_size", 560)
+        # rename config.json values
+        self.hparams["num_attention_heads"] = self.hparams.get("num_heads")
+        self.hparams["num_hidden_layers"] = self.hparams.get("depth")
+        if "embed_dim" in self.hparams: # qwen2vl
+            self.hparams["intermediate_size"] = self.hparams.get("hidden_size")
+            self.hparams["hidden_size"] = self.hparams.get("embed_dim")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if self.global_config['model_type'] == 'qwen2_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN2VL)
+        elif self.global_config['model_type'] == 'qwen2_5_vl':
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.QWEN25VL)
+            self.gguf_writer.add_vision_use_silu(True)
+            # find n_wa_pattern (window attention pattern)
+            fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+            assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl"
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            # validate n_wa_pattern
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+        else:
+            raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}")
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("visual."):
+            # process visual tensors
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2: # weight
+                    c3, _ = data_torch.shape
+                else: # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("qkv", "q")), wq),
+                    (self.map_tensor_name(name.replace("qkv", "k")), wk),
+                    (self.map_tensor_name(name.replace("qkv", "v")), wv),
+                ]
+            elif 'patch_embed.proj.weight' in name:
+                # split Conv3D into Conv2Ds
+                c1, c2, kt, kh, kw = data_torch.shape
+                del c1, c2, kh, kw  # unused
+                assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
+                return [
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]),
+                    (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]),
+                ]
+            else:
+                return [(self.map_tensor_name(name), data_torch)]
+        return [] # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
```
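The `n_wa_pattern` derivation and the two tensor transformations in `modify_tensors` are easier to see with concrete numbers. The sketch below is standalone and uses made-up shapes and block indexes (real values come from the checkpoint's config.json); it mirrors the logic above rather than calling it:

```python
import torch

# n_wa_pattern: Qwen 2.5 VL interleaves window attention with full attention.
# fullatt_block_indexes lists the full-attention layers; if they are evenly
# spaced, the spacing (first index + 1) is the repeating pattern length.
fullatt_block_indexes = [7, 15, 23, 31]  # hypothetical indexes
n_wa_pattern = fullatt_block_indexes[0] + 1  # -> 8
for i in range(1, len(fullatt_block_indexes)):
    assert fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] == n_wa_pattern

# QKV split: the fused tensor stacks Q, K, V along dim 0, so each part is one
# third of that dimension (hence the converter's `assert c3 % 3 == 0`).
hidden = 4  # made-up embedding width
qkv_weight = torch.randn(3 * hidden, hidden)
c = qkv_weight.shape[0] // 3
wq, wk, wv = qkv_weight[:c], qkv_weight[c:c * 2], qkv_weight[c * 2:]
assert wq.shape == wk.shape == wv.shape == (hidden, hidden)

# Conv3D -> 2x Conv2D: the patch-embedding kernel has a temporal axis of
# size 2; slicing it yields one 2D kernel per temporal frame.
conv3d = torch.randn(8, 3, 2, 14, 14)  # [c_out, c_in, kt=2, kh, kw], made-up sizes
frame0 = conv3d[:, :, 0, ...]
frame1 = conv3d[:, :, 1, ...]
assert frame0.shape == frame1.shape == (8, 3, 14, 14)
```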

examples/llava/README.md

Lines changed: 21 additions & 8 deletions
````diff
@@ -35,6 +35,16 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
 
+# Qwen 2 VL
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF
+
+# Qwen 2.5 VL
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF
+
 # Mistral Small 3.1 24B (IQ2_M quantization)
 llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```
````
```diff
@@ -60,7 +70,17 @@ Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advantages
 
 ## How to obtain `mmproj`
 
-Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
+Multimodal projector (`mmproj`) files are specific to each model architecture.
+
+For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+
+For older models, please refer to the relevant guide for instructions on how to obtain or create them:
 
 - [LLaVA](../../docs/multimodal/llava.md)
 - [MobileVLM](../../docs/multimodal/MobileVLM.md)
```
```diff
@@ -70,10 +90,3 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
 - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
 - [IBM Granite Vision](../../docs/multimodal/granitevision.md)
 - [Google Gemma 3](../../docs/multimodal/gemma3.md)
-
-For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` flag to get the `mmproj` file:
-- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
-- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
-- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
-- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
```
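With this commit, the `mmproj` for Qwen 2 VL and Qwen 2.5 VL comes straight out of `convert_hf_to_gguf.py`, which is why the dedicated surgery script below is deleted. As a usage sketch (the local checkpoint path is illustrative, and the exact output flags are best confirmed with `--help`): `python convert_hf_to_gguf.py ./Qwen2.5-VL-3B-Instruct --mmproj` should emit the projector GGUF from the Hugging Face checkpoint directory.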

examples/llava/qwen2_vl_surgery.py

Lines changed: 0 additions & 217 deletions
This file was deleted.

examples/llava/tests.sh

Lines changed: 10 additions & 8 deletions
```diff
@@ -36,12 +36,6 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
-add_test_big() {
-    if [ "$RUN_BIG_TESTS" = true ]; then
-        add_test "$@"
-    fi
-}
-
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
```
```diff
@@ -58,8 +52,16 @@ add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
-add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
-add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+if [ "$RUN_BIG_TESTS" = true ]; then
+    add_test "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
+fi
 
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
```
