diff --git a/common/arg.cpp b/common/arg.cpp
index f5e9b294f3048..0b814e58c27a3 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1406,14 +1406,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.mmproj = value;
}
- ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+ ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_COGAGENT}));
add_opt(common_arg(
{"--image"}, "FILE",
"path to an image file. use with multimodal models. Specify multiple times for batching",
[](common_params & params, const std::string & value) {
params.image.emplace_back(value);
}
- ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+ ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION, LLAMA_EXAMPLE_COGAGENT}));
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
diff --git a/common/common.h b/common/common.h
index b208d0c7ece59..535c292f36c48 100644
--- a/common/common.h
+++ b/common/common.h
@@ -80,6 +80,8 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
+ LLAMA_EXAMPLE_VISION,
+ LLAMA_EXAMPLE_COGAGENT,
LLAMA_EXAMPLE_COUNT,
};
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 018a2a588ae9d..08b00a42d473d 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -17,6 +17,7 @@
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain
+from transformers import AutoConfig
import math
import numpy as np
import torch
@@ -66,6 +67,13 @@ class Model:
metadata_override: Path | None
dir_model_card: Path
+ # for vision model
+ vision_arch: gguf.MODEL_ARCH | None = None
+ preprocessor_config: dict[str, Any] | None = None
+ vparams: dict[str, Any] | None = None
+ v_tensor_map: gguf.TensorNameMap | None = None
+ v_tensor_names: set[str] | None
+
# subclasses should define this!
model_arch: gguf.MODEL_ARCH
@@ -126,6 +134,16 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
return None
raise KeyError(f"could not find any of: {keys}")
+ def find_vparams(self, keys: Iterable[str], optional: bool = False) -> Any:
+ if self.vparams is None:
+ raise ValueError("vision model parameters not set")
+ key = next((k for k in keys if k in self.vparams), None)
+ if key is not None:
+ return self.vparams[key]
+ if optional:
+ return None
+ raise KeyError(f"(vision) could not find any of: {keys}")
+
def set_vocab(self):
self._set_vocab_gpt2()
@@ -186,9 +204,10 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
f"Missing tensors: {missing}\n"
f"Extra tensors: {extra}")
- def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
- if key not in gguf.MODEL_TENSORS[self.model_arch]:
- raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
+ def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight", is_vision = False) -> str:
+ arch = self.vision_arch if is_vision and self.vision_arch is not None else self.model_arch
+ if key not in gguf.MODEL_TENSORS[arch]:
+ raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {arch!r}")
name: str = gguf.TENSOR_NAMES[key]
if "{bid}" in name:
assert bid is not None
@@ -210,9 +229,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int |
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
- if new_name is None:
+ new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None
+ if new_name is not None:
+ return new_name
+ elif new_name_vision is not None:
+ return new_name_vision
+ else:
raise ValueError(f"Can not map tensor {name!r}")
- return new_name
def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
@@ -257,6 +280,23 @@ def set_gguf_parameters(self):
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)
+ # Vision model parameters
+ if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
+ self.gguf_writer.add_vision_type("vit")
+ self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
+ self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
+ self.gguf_writer.add_vision_vit_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
+ self.gguf_writer.add_vision_vit_block_count(self.vparams["num_hidden_layers"])
+ self.gguf_writer.add_vision_vit_embedding_length(self.vparams["hidden_size"])
+ self.gguf_writer.add_vision_vit_feed_forward_length(self.vparams["intermediate_size"])
+ self.gguf_writer.add_vision_vit_head_count(self.vparams["num_attention_heads"])
+ self.gguf_writer.add_vision_vit_image_mean(self.preprocessor_config["image_mean"])
+ self.gguf_writer.add_vision_vit_image_std(self.preprocessor_config["image_std"])
+ try:
+ self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))
+ except KeyError:
+ self.gguf_writer.add_vision_vit_select_layer(0)
+
self.gguf_writer.add_file_type(self.ftype)
logger.info(f"gguf: file type = {self.ftype}")
@@ -466,7 +506,25 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
- return json.load(f)
+ hparams = json.load(f)
+ if "text_config" in hparams:
+ text_config = hparams["text_config"]
+ model_id = text_config.get("_name_or_path", None)
+ # some checkpoints (for example llava-1.5-7b-hf) are missing the language model config,
+ # so it has to be retrieved via the referenced model ID
+ if model_id is not None and model_id != "None" and model_id != "":
+ text_config = AutoConfig.from_pretrained(model_id).to_dict()
+ hparams = {**text_config, **hparams}
+ return hparams
+
+ @staticmethod
+ def load_preprocessor_config(dir_model: Path):
+ # TODO: this varies vastly among models, need to handle more cases in the future
+ file_path = dir_model / "preprocessor_config.json"
+ if os.path.exists(file_path):
+ with open(file_path, "r", encoding="utf-8") as f:
+ return json.load(f)
+ else:
+ raise Exception(f"Preprocessor config not found at {file_path}")
@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -948,6 +1006,29 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
+# TODO: maybe merge this with Model in the future
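+# Small helper around a Model instance that exposes the language model's token-embedding
+# table, so that embedding rows for special image tokens can be exported as vision tensors.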
+class VisionModelHelper:
+ model: Model
+ tok_embd_tensor: Tensor | None = None
+
+ def __init__(self, model: Model):
+ self.model = model
+ # TODO: how to do this without reading the whole safetensor file?
+ for tname, tensor in model.get_tensors():
+ if tname.endswith("embed_tokens.weight"):
+ self.tok_embd_tensor = tensor
+
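+ # Yield (gguf tensor name, embedding row) pairs for the given special tokens,
+ # looking each token up by id with the model's Hugging Face tokenizer.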
+ def get_embd_for_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, gguf.MODEL_TENSOR]], tensor_name_postfix = '.weight') -> Iterable[tuple[str, Tensor]]:
+ if self.tok_embd_tensor is None:
+ raise ValueError("Token embedding tensor not found")
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.model.dir_model, trust_remote_code=True)
+ for token, tensor_name in map_token_to_tensor_name:
+ tok_id = tokenizer.get_vocab()[token]
+ row = self.tok_embd_tensor[tok_id]
+ yield gguf.TENSOR_NAMES[tensor_name] + tensor_name_postfix, row
+
+
@Model.register("GPTNeoXForCausalLM")
class GPTNeoXModel(Model):
model_arch = gguf.MODEL_ARCH.GPTNEOX
@@ -1560,10 +1641,38 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed norms: {norms}")
-@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM", "Idefics3ForConditionalGeneration")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ model_type = self.hparams.get("model_type", None)
+ self.vision_arch = None
+
+ # only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf
+ if "vision_config" in self.hparams and model_type == "llava":
+ self.vparams = self.hparams["vision_config"]
+ self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
+ self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA
+
+ # only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B
+ if "mm_vision_tower" in self.hparams and model_type == "mobilevlm":
+ from transformers import AutoImageProcessor
+ vision_model_id = self.hparams["mm_vision_tower"]
+ self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
+ self.preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict()
+ self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM
+
+ if "vision_config" in self.hparams and model_type == "idefics3":
+ self.vparams = self.hparams["vision_config"]
+ self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
+ self.vision_arch = gguf.MODEL_ARCH.VISION_IDEFICS3
+
+ if self.vparams is not None and self.vision_arch is not None:
+ self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
+
def set_vocab(self):
try:
self._set_vocab_sentencepiece()
@@ -1613,6 +1722,24 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+ # For vision model
+ if self.vparams is not None:
+ max_pos_embd = -1
+ self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
+ # TODO: should not hardcode these, but they are currently missing from config.json
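+ # max_pos_embd = number of image patches, plus one for the CLS token where the encoder uses one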
+ if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
+ self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP)
+ max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
+ if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
+ self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
+ max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
+ if self.vision_arch == gguf.MODEL_ARCH.VISION_IDEFICS3:
+ self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP)
+ self.gguf_writer.add_vision_vit_scale_factor(self.hparams["scale_factor"])
+ max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
+ self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-05)
+ self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
+
@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
@@ -1626,11 +1753,24 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
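+ # HF multimodal checkpoints prefix the vision encoder weights with either "vision_tower" or "vision_model"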
+ is_vision_tensor = "vision_tower" in name or "vision_model" in name
- if name.endswith(("q_proj.weight", "q_proj.bias")):
- data_torch = LlamaModel.permute(data_torch, n_head, n_head)
- if name.endswith(("k_proj.weight", "k_proj.bias")):
- data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+ if is_vision_tensor:
+ if name.startswith("model.text_model"):
+ name = name.replace("text_model.", "") # for SmolVLM
+ else:
+ name = name.replace("model.vision_tower.", "")
+ if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3:
+ return [] # skip post_layernorm
+
+ if not is_vision_tensor:
+ if name.startswith("language_model"):
+ # language model tensors, remove the prefix
+ name = name.replace("language_model.", "")
+ if name.endswith(("q_proj.weight", "q_proj.bias")):
+ data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+ if name.endswith(("k_proj.weight", "k_proj.bias")):
+ data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
@@ -2234,6 +2374,173 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
yield name, data
+@Model.register("MiniCPMV")
+class MiniCPMVModel(Qwen2Model):
+ # MiniCPM-V 2.5 is Qwen2 and 2.6 is Qwen-2.5
+ model_arch = gguf.MODEL_ARCH.QWEN2
+ proj_type: gguf.constants.CLIPProjectorType | None
+ resampler_n_embd = 0
+ vhelper: VisionModelHelper | None
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ model_type = self.hparams.get("model_type", None)
+
+ # only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
+ if "vision_config" in self.hparams and model_type == "minicpmv":
+ self.vparams = self.hparams["vision_config"]
+ self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
+ self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
+ version = str(self.hparams.get("version", "unknown"))
+ if version == "2.5":
+ self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
+ elif version == "2.6":
+ self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
+ else:
+ raise ValueError(f"Unsupported MiniCPM-V version: {version}")
+ self.vhelper = VisionModelHelper(self)
+ # TODO: how to do this without reading the whole safetensor file?
+ for tname, tensor in self.get_tensors():
+ if tname == "resampler.ln_post.bias":
+ self.resampler_n_embd = tensor.shape[0]
+ if self.resampler_n_embd < 2:
+ raise ValueError("Failed to detect resampler embedding size")
+ else:
+ raise ValueError("Expected vision_config, but not found")
+
+ assert self.vparams is not None
+ assert self.vision_arch is not None
+ assert self.preprocessor_config is not None
+ self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
+ self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
+ self.hparams["vision_feature_layer"] = 0
+ self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ assert self.vparams is not None and self.proj_type is not None
+ self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
+ self.gguf_writer.add_vision_vit_projector_type(self.proj_type)
+ self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06)
+ max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
+ self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd)
+
+
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ # because the model operates exclusively on a 70x70 patch grid for now, we precompute the positional embeddings for performance
+ # in the future, we can do it in cpp if we figure out how to do it efficiently
+ yield (
+ self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
+ torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
+ )
+ assert self.vhelper is not None
+ added_tokens = [
+ ("", gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE),
+ ("", gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE),
+ ("", gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE),
+ ("", gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE),
+ ]
+ for tensor_name, tensor in self.vhelper.get_embd_for_tokens(added_tokens):
+ yield tensor_name, tensor
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+
+ # for language part
+ if name.startswith("llm."):
+ return [(self.map_tensor_name(name.replace("llm.", "")), data_torch)]
+
+ # split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v
+ if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"):
+ assert data_torch.shape[0] == 3 * self.resampler_n_embd
+ split_tensor = data_torch.chunk(3, dim=0)
+ name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias)
+ name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias)
+ name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias)
+ return [
+ # TODO: permute these
+ (self.map_tensor_name(name_q), split_tensor[0]),
+ (self.map_tensor_name(name_k), split_tensor[1]),
+ (self.map_tensor_name(name_v), split_tensor[2]),
+ ]
+
+ # append .weight to these tensors
+ if name == "resampler.proj" or name == "resampler.query":
+ name += ".weight"
+
+ if name.startswith("resampler.proj"):
+ data_torch = data_torch.transpose(-1, -2).contiguous()
+
+ if "post_layernorm" in name:
+ return [] # skip post_layernorm
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
+ del name, bid # unused
+ if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name:
+ return gguf.GGMLQuantizationType.F32
+ if "v.resmpl." in new_name:
+ return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16
+ return False
+
+ # utils to work with MiniCPM-V resampler
+
+ # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
+ def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray:
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, int):
+ grid_h_size, grid_w_size = grid_size, grid_size
+ else:
+ grid_h_size, grid_w_size = grid_size[0], grid_size[1]
+
+ grid_h = np.arange(grid_h_size, dtype=np.float32)
+ grid_w = np.arange(grid_w_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
+ pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+ def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray:
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+ def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray:
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000 ** omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
@Model.register("WavTokenizerDec")
class WavTokenizerDecModel(Model):
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
@@ -4877,6 +5184,58 @@ def _reverse_hf_permute(data_torch, n_heads, hidden_dim):
data_torch = data_torch.repeat_interleave(n_heads, 0)
return data_torch
+@Model.register("CogAgentForCausalLM")
+class CogVLMModel(Model):
+ model_arch = gguf.MODEL_ARCH.COGVLM
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
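+ # the CogAgent conversion currently only produces F32 output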
+ self.ftype = gguf.LlamaFileType.ALL_F32
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # Skip boi and eoi tensors for now
+ if name.endswith("boi"):
+ return []
+ if name.endswith("eoi"):
+ return []
+ if name.startswith("model.vision"):
+ return []
+ if name.startswith("model.cross_vision"):
+ return []
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def set_vocab(self):
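+ # CogAgent is built on Vicuna-7B v1.5, so the tokenizer is loaded from that upstream checkpoint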
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+ vocab_size = len(tokenizer.vocab.items())
+
+ reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
+ added_vocab = tokenizer.get_added_vocab()
+ tokens: list[str] = []
+ toktypes: list[int] = []
+
+ for i in range(vocab_size):
+ if i not in reverse_vocab:
+ tokens.append(f"[PAD{i}]")
+ toktypes.append(gguf.TokenType.UNUSED)
+ else:
+ token: str = reverse_vocab[i]
+ if token in added_vocab:
+ if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token):
+ toktypes.append(gguf.TokenType.CONTROL)
+ else:
+ token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces
+ toktypes.append(gguf.TokenType.USER_DEFINED)
+ else:
+ toktypes.append(gguf.TokenType.NORMAL)
+ tokens.append(token)
+
+ self.gguf_writer.add_tokenizer_model("llama")
+ self.gguf_writer.add_tokenizer_pre("default")
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
###### CONVERSION LOGIC ######
@@ -4949,7 +5308,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
- description="Convert a huggingface model to a GGML compatible file")
+ description="Convert a huggingface model to a GGML compatible file\n\nNote: When converting vision models, this script may use internet connection to download configuration files via Hugging Face.")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 66cfab2c3b796..59973bb3395d9 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -18,6 +18,7 @@ if (EMSCRIPTEN)
else()
add_subdirectory(batched-bench)
add_subdirectory(batched)
+ add_subdirectory(cogagent)
add_subdirectory(embedding)
add_subdirectory(eval-callback)
@@ -53,6 +54,7 @@ else()
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(gen-docs)
+ add_subdirectory(vision)
if (NOT GGML_BACKEND_DL)
# these examples use the backends directly and cannot be built with dynamic loading
add_subdirectory(convert-llama2c-to-ggml)
diff --git a/examples/cogagent/CMakeLists.txt b/examples/cogagent/CMakeLists.txt
new file mode 100644
index 0000000000000..e83b1c3febcfd
--- /dev/null
+++ b/examples/cogagent/CMakeLists.txt
@@ -0,0 +1,18 @@
+set(TARGET llama-cogagent-cli)
+add_executable(${TARGET} cogagent-cli.cpp)
+add_library(cogagent OBJECT
+ vision_encoder.cpp
+ vision_encoder.h
+ cross_vision.cpp
+ cross_vision.h
+ cogagent_util.cpp
+ cogagent_util.h
+ image_util.cpp
+ image_util.h)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-cogagent-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common cogagent ggml ${CMAKE_THREAD_LIBS_INIT})
+target_include_directories(cogagent PUBLIC ../../ggml/include)
+target_include_directories(cogagent PUBLIC ../../include)
+target_include_directories(cogagent PUBLIC ../../common)
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
\ No newline at end of file
diff --git a/examples/cogagent/cogagent-cli.cpp b/examples/cogagent/cogagent-cli.cpp
new file mode 100644
index 0000000000000..f195fd4d2b736
--- /dev/null
+++ b/examples/cogagent/cogagent-cli.cpp
@@ -0,0 +1,285 @@
+#include "arg.h"
+#include "base64.hpp"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "llama.h"
+
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "cogagent.h"
+
+cogagent_ctx cogagent_global;
+
+// This function is mostly adapted from the llava-cli example
+static bool eval_string_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
+ int N = (int) tokens.size();
+
+ //// Processing the input tokens in batches
+ for (int i = 0; i < N; i += n_batch) {
+ int n_eval = (int) tokens.size() - i;
+ if (n_eval > n_batch) {
+ n_eval = n_batch;
+ }
+
+ std::vector<llama_pos> pos;
+ pos.resize(n_eval);
+ for (int i=0; i<n_eval; i++) {
+ pos[i] = *n_past + i;
+ }
+
+ // Build the text-token batch by hand, mirroring the layout used for the image batch below
+ llama_batch batch = {int32_t(n_eval), &tokens[i], nullptr, pos.data(), nullptr, nullptr, nullptr, nullptr, nullptr, };
+ batch.cross_embd = cogagent_global.cross_vision_image_tensor;
+ if (llama_decode(ctx_llama, batch)) {
+ LOG_ERR("%s : failed to eval\n", __func__);
+ return false;
+ }
+ *n_past += n_eval;
+ }
+ return true;
+}
+
+// Feed the image embedding produced by the vision encoder to the decoder.
+// CogAgent uses 258 image-token slots (256 patch embeddings plus the boi/eoi tokens) of width 4096,
+// and the whole image block only advances the position counter by 3.
+static bool eval_image_tokens(struct llama_context * ctx_llama, std::vector<float> &img_data,
+ int n_batch, int * n_past) {
+ int n_embd = 4096;
+ int num_tokens = 258;
+ int positions[258];
+
+ positions[0] = *n_past;
+ for (int i=1; i<num_tokens-1; i++) {
+ positions[i] = *n_past + 1;
+ }
+ positions[num_tokens-1] = *n_past + 2;
+
+ float * data_ptr = img_data.data();
+ for (int i = 0; i < num_tokens; i += n_batch) {
+ int n_eval = num_tokens - i;
+ if (n_eval > n_batch) {
+ n_eval = n_batch;
+ }
+ llama_batch batch = {int32_t(n_eval), nullptr, data_ptr, positions, nullptr, nullptr, nullptr, nullptr, nullptr, };
+ batch.cross_embd = cogagent_global.cross_vision_image_tensor;
+ if (llama_decode(ctx_llama, batch)) {
+ LOG_ERR("%s : failed to eval\n", __func__);
+ return false;
+ }
+ data_ptr += i * n_embd;
+ }
+ *n_past += 3;
+ return true;
+}
+
+static void print_usage(int, char ** argv) {
+ LOG("\n example usage:\n");
+ LOG("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
+ LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
+}
+
+static const char * sample(struct common_sampler * smpl,
+ struct llama_context * ctx_llama,
+ int * n_past) {
+ const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
+ common_sampler_accept(smpl, id, true);
+
+ const llama_model * model = llama_get_model(ctx_llama);
+ const llama_vocab * vocab = llama_model_get_vocab(model);
+
+ static std::string ret;
+ if (llama_vocab_is_eog(vocab, id)) {
+ ret = "";
+ } else {
+ ret = common_token_to_piece(ctx_llama, id);
+ }
+ // Feed the sampled token back to the model so that it is added to the KV cache
+ // and conditions the next prediction.
+ std::vector<llama_token> tokens;
+ tokens.push_back(id);
+ eval_string_tokens(ctx_llama, tokens, 1, n_past);
+
+ return ret.c_str();
+}
+
+static bool run_vision_encoders(const char* vision_encoder_path, const char* image_path) {
+ // Load image and resize for the encoders
+ std::vector<float> small_image_data; // For vision encoder
+ std::vector<float> large_image_data; // For cross vision encoder
+ if (!load_and_stretch_image(image_path, cogagent_global.vision_encoder_img_size,
+ small_image_data, cogagent_global.norm_mean, cogagent_global.norm_deviation)) {
+ printf("Failed to load the specified image file.\n");
+ return false;
+ }
+ if (!load_and_stretch_image(image_path, cogagent_global.cross_vision_img_size,
+ large_image_data, cogagent_global.norm_mean, cogagent_global.norm_deviation)) {
+ printf("Failed to load the specified image file.\n");
+ return false;
+ }
+
+ // For debugging purposes
+ const char * vision_encoder_resized_image = "cogagent_encoders/llama_vision_encoder_input.gguf";
+ int dims[3] = {cogagent_global.vision_encoder_img_size,
+ cogagent_global.vision_encoder_img_size, 3};
+ save_tensor_from_data(small_image_data, dims, vision_encoder_resized_image);
+ const char * cross_vision_resized_image = "cogagent_encoders/llama_cross_vision_input.gguf";
+ dims[0] = cogagent_global.cross_vision_img_size;
+ dims[1] = cogagent_global.cross_vision_img_size;
+ save_tensor_from_data(large_image_data, dims, cross_vision_resized_image);
+
+ // const char * reference_vision_encoder_input = "/home/tianyue/myworkspace"
+ // "/vlm_intermediate/vision_encoder_input.gguf";
+ // const char * reference_cross_vision_input = "/home/tianyue/myworkspace"
+ // "/vlm_intermediate/cross_vision_input.gguf";
+ // // Load the reference input
+ // if (get_input(small_image_data, reference_vision_encoder_input) < 0) {
+ // printf("Failed to load small image input\n");
+ // return false;
+ // }
+ // if (get_input(large_image_data, reference_cross_vision_input) < 0) {
+ // printf("Failed to load big image input\n");
+ // return false;
+ // }
+ printf("Loaded and resized the specified image.\n");
+
+ // Load the vision encoder weights
+ if (!vision_encoder_init_load(vision_encoder_path)) {
+ printf("Failed to load vision encoder model file.\n");
+ return false;
+ }
+ printf("Vision encoder weights loaded.\n");
+
+ // Run the vision encoder
+ run_vision_encoder(small_image_data);
+ printf("Completed vision encoder run on image file.\n");
+
+ free_vision_encoder_ctx();
+
+ // Load and run the cross vision encoder
+ if (!cross_vision_init_load(vision_encoder_path)) {
+ printf("Failed to load cross vision encoder model file.\n");
+ return false;
+ }
+ printf("Cross vision encoder weights loaded.\n");
+
+ run_cross_vision(large_image_data);
+ printf("Completed cross vision encoder run on image file.\n");
+
+ free_cross_vision_ctx();
+ return true;
+}
+
+int main(int argc, char ** argv) {
+ ggml_time_init();
+ common_params params;
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COGAGENT, print_usage)) {
+ return 1;
+ }
+ common_init();
+
+ llama_backend_init();
+ llama_numa_init(params.numa);
+
+ // Initialize a GGML context to store the encoded image tensors
+ struct ggml_init_params token_ctx_params = {
+ size_t(40000000),
+ NULL,
+ false,
+ };
+ cogagent_global.token_ctx = ggml_init(token_ctx_params);
+ if (!cogagent_global.token_ctx) {
+ printf("Failed to initialize token storage context.\n");
+ return 1;
+ }
+ // Allocate the tensor for cross vision encoded image
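+ // (presumably 80 x 80 = 6400 patch tokens from the 1120 px input, each with a 1024-dim embedding)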
+ cogagent_global.cross_vision_image_tensor = ggml_new_tensor_2d(
+ cogagent_global.token_ctx, GGML_TYPE_F32, 1024, 6400
+ );
+
+ // Load the images and the encoder models
+ // Then run the encoder models
+ if (!run_vision_encoders(params.mmproj.c_str(), params.image[0].c_str())) {
+ return 1;
+ }
+
+ llama_model_params model_params = common_model_params_to_llama(params);
+ llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params);
+ if (model == nullptr) {
+ printf("Failed to load decoder model\n");
+ return 1;
+ }
+
+ llama_context_params ctx_params = common_context_params_to_llama(params);
+ printf("Context size is %d tokens\n", ctx_params.n_ctx);
+ llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
+
+ if (ctx_llama == nullptr) {
+ printf("Failed to create the llama context\n");
+ return 1;
+ }
+
+ cogagent_global.ctx_llama = ctx_llama;
+ cogagent_global.cogvlm_model = model;
+
+ // Note: the KV cache tensors are not part of the per-batch compute graph; they live in
+ // their own backend buffers owned by the llama context, which is why their contents
+ // persist across llama_decode() calls even though a new graph is built for every batch
+ // (the worst-case graph is only used to reserve compute buffers).
+
+ // TODO: Check if system prompt is compatible
+ std::vector<llama_token> begin_token;
+ const llama_vocab * vocab = llama_model_get_vocab(cogagent_global.cogvlm_model);
+ begin_token.push_back(llama_vocab_bos(vocab));
+
+ int n_past = 0;
+ printf("Run model with bos token.\n");
+ eval_string_tokens(cogagent_global.ctx_llama,
+ begin_token, params.n_batch, &n_past);
+ printf("Run model with image tokens.\n");
+ eval_image_tokens(cogagent_global.ctx_llama, cogagent_global.vision_encoder_image,
+ params.n_batch, &n_past);
+ // Tokenize user prompt
+ // Third option set to false to that the tokenizer doesn't add
+ // beginning of sentence and end of sentence
+ std::vector<llama_token> user_prompt_tokens = common_tokenize(
+ cogagent_global.ctx_llama, params.prompt, false, true
+ );
+ printf("Run model with user entered text tokens.\n");
+ eval_string_tokens(cogagent_global.ctx_llama, user_prompt_tokens,
+ params.n_batch, &n_past);
+
+ printf("Parsed maximum sampling length %d.\n", params.n_predict);
+ int max_len = params.n_predict < 0 ? 256 : params.n_predict;
+
+ struct common_sampler * smpl = common_sampler_init(cogagent_global.cogvlm_model, params.sampling);
+ if (!smpl) {
+ printf("Failed to initialize sampler.\n");
+ return 1;
+ }
+ printf("\nReprinting entered prompt.\n %s \n", params.prompt.c_str());
+ printf("\n\n Beginning of response.\n");
+ std::string response = "";
+ for (int i=0; i<max_len; i++) {
+ const char * tmp = sample(smpl, ctx_llama, &n_past);
+ response += tmp;
+ if (strcmp(tmp, "</s>") == 0) {
+ if (i < 10) {
+ continue;
+ }
+ break;
+ }
+ printf("%s", tmp);
+ fflush(stdout);
+ }
+ common_sampler_free(smpl);
+
+ llama_model_free(model);
+ ggml_free(cogagent_global.token_ctx);
+ return 0;
+}
\ No newline at end of file
diff --git a/examples/cogagent/cogagent.h b/examples/cogagent/cogagent.h
new file mode 100644
index 0000000000000..d51a87a539a35
--- /dev/null
+++ b/examples/cogagent/cogagent.h
@@ -0,0 +1,36 @@
+#ifndef COGAGENT_H
+#define COGAGENT_H
+
+#include "vision_encoder.h"
+#include "cross_vision.h"
+#include "cogagent_util.h"
+#include "image_util.h"
+#include "ggml.h"
+#include "gguf.h"
+
+struct cogagent_ctx {
+ // Vision encoder and cross vision encoder models
+ vision_encoder_ctx vision_encoder;
+ cross_vision_ctx cross_vision;
+
+ struct llama_context * ctx_llama;
+ struct llama_model * cogvlm_model;
+
+ // Context for storing vision tokens and cross vision
+ // embedded picture tensor
+ ggml_context * token_ctx;
+
+ std::string user_prompt;
+ std::vector<float> vision_encoder_image; // Image encoded by the vision encoder
+ struct ggml_tensor * cross_vision_image_tensor; // Image encoded by the cross vision encoder
+
+ int vision_encoder_img_size = 224;
+ int cross_vision_img_size = 1120;
+
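+ // CLIP-style per-channel image normalization constants (mean / standard deviation)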
+ float norm_mean[3] = {0.48145466, 0.4578275, 0.40821073};
+ float norm_deviation[3] = {0.26862954, 0.26130258, 0.27577711};
+};
+
+extern struct cogagent_ctx cogagent_global;
+
+#endif
\ No newline at end of file
diff --git a/examples/cogagent/cogagent_util.cpp b/examples/cogagent/cogagent_util.cpp
new file mode 100644
index 0000000000000..fce62dda3ddab
--- /dev/null
+++ b/examples/cogagent/cogagent_util.cpp
@@ -0,0 +1,162 @@
+#include "cogagent_util.h"
+
+void print_dims(struct ggml_tensor * input_tensor, const char * name) {
+ printf("Tensor %s has shape %ld x %ld x %ld x %ld and data type %d\n", name,
+ input_tensor->ne[0], input_tensor->ne[1], input_tensor->ne[2],
+ input_tensor->ne[3], input_tensor->type);
+}
+
+struct ggml_tensor * get_tensor(struct ggml_context * dst_ctx, struct ggml_context * src_ctx, std::string tensor_name, int &count_failed) {
+ struct ggml_tensor * cur_tensor = ggml_get_tensor(src_ctx, tensor_name.c_str());
+ if (!cur_tensor) {
+ printf("Retrieval of tensor %s from model context failed\n", tensor_name.c_str());
+ count_failed++;
+ return nullptr;
+ }
+ struct ggml_tensor * new_tensor = ggml_dup_tensor(dst_ctx, cur_tensor);
+ ggml_set_name(new_tensor, cur_tensor->name);
+ return new_tensor;
+}
+
+void save_tensor_filename(struct ggml_tensor * input_tensor, std::string filename) {
+ std::string prefix = "/home/tianyue/myworkspace/";
+ filename = prefix + filename;
+ gguf_context * gguf_ctx = gguf_init_empty();
+ gguf_set_val_str(gguf_ctx, "model.architecture", "cogagent");
+ gguf_set_val_u32(gguf_ctx, "general.file_type", GGML_TYPE_F32);
+
+ struct ggml_init_params params = {
+ ggml_nbytes(input_tensor) + 1000000, // Memory to allocate
+ nullptr, // Buffer location
+ false, // no_alloc=false, so that tensor data is allocated
+ };
+ struct ggml_context * tensor_ctx = ggml_init(params);
+ struct ggml_tensor * tensor_with_data = ggml_dup(tensor_ctx, input_tensor);
+ ggml_backend_tensor_get(input_tensor, tensor_with_data->data,
+ 0, ggml_nbytes(input_tensor));
+
+ ggml_set_name(tensor_with_data, "output_tensor");
+ gguf_add_tensor(gguf_ctx, tensor_with_data);
+ gguf_write_to_file(gguf_ctx, filename.c_str(), false);
+ gguf_free(gguf_ctx);
+ ggml_free(tensor_ctx);
+}
+
+void save_tensor_from_data(std::vector<float> tensor_data, int* dims, std::string filename) {
+ std::string prefix = "/home/tianyue/myworkspace/";
+ filename = prefix + filename;
+ gguf_context * gguf_ctx = gguf_init_empty();
+ gguf_set_val_str(gguf_ctx, "model.architecture", "cogagent");
+ gguf_set_val_u32(gguf_ctx, "general.file_type", GGML_TYPE_F32);
+
+ struct ggml_init_params params = {
+ tensor_data.size() * sizeof(float) + 1000000, // Memory to allocate
+ nullptr, // Buffer location
+ false, // Allocate tensor data
+ };
+ struct ggml_context * tensor_ctx = ggml_init(params);
+ struct ggml_tensor * tensor_with_data = ggml_new_tensor_3d(tensor_ctx,
+ GGML_TYPE_F32, dims[0], dims[1], dims[2]);
+ // copy the data
+ memcpy(tensor_with_data->data, tensor_data.data(), ggml_nbytes(tensor_with_data));
+
+ ggml_set_name(tensor_with_data, "output_tensor");
+ gguf_add_tensor(gguf_ctx, tensor_with_data);
+ gguf_write_to_file(gguf_ctx, filename.c_str(), false);
+ gguf_free(gguf_ctx);
+ ggml_free(tensor_ctx);
+}
+
+// Function that loads data from the GGUF file to a temporary buffer
+// and then from the temporary buffer to the GGML backend buffer
+// Copied from the GGML MNIST example
+bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) {
+ FILE * f = ggml_fopen(fname, "rb");
+ if (!f) {
+ return false;
+ }
+
+ const size_t buf_size = 4*1024*1024;
+ void * buf = malloc(buf_size);
+
+ const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+ for (int i = 0; i < n_tensors; i++) {
+ const char * name = gguf_get_tensor_name(ctx_gguf, i);
+
+ struct ggml_tensor * tensor = ggml_get_tensor(ctx_ggml, name);
+ if (!tensor) {
+ // We get here if there is a tensor in the file
+ // that is not being requested for the context
+ // that we are loading into
+ continue;
+ }
+
+ const size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
+
+ if (fseek(f, offs, SEEK_SET) != 0) {
+ fclose(f);
+ free(buf);
+ return false;
+ }
+
+ const size_t nbytes = ggml_nbytes(tensor);
+ for (size_t pos = 0; pos < nbytes; pos += buf_size) {
+ const size_t nbytes_cpy = buf_size < nbytes - pos ? buf_size : nbytes - pos;
+
+ if (fread(buf, 1, nbytes_cpy, f) != nbytes_cpy) {
+ fclose(f);
+ free(buf);
+ return false;
+ }
+
+ ggml_backend_tensor_set(tensor, buf, pos, nbytes_cpy);
+ }
+ }
+
+ fclose(f);
+ free(buf);
+ return true;
+}
+
+int get_input(
+ std::vector<float> &input_data,
+ const char * filename
+) {
+ struct ggml_context * meta_info;
+
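+ // no_alloc = true: gguf_init_from_file only fills meta_info with tensor metadata;
+ // the tensor data itself is read from the file manually below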
+ struct gguf_init_params gguf_params = {
+ true, &meta_info,
+ };
+ struct gguf_context * gguf_ctx = gguf_init_from_file(filename, gguf_params);
+ if (!gguf_ctx) {
+ printf("Failed to initialize GGUF context. Check filename.\n");
+ return -1;
+ }
+
+ // I don't know how to set tensor name when writing a GGUF file
+ // in cpp, so the cross vision output tensor doesn't appear
+ // to have a name when reading the GGUF file
+ struct ggml_tensor * meta_tensor = ggml_get_first_tensor(meta_info);
+ if (meta_tensor->type != GGML_TYPE_F32) {
+ printf("Expected the input image datatype to be float 32.\n");
+ printf("Image loading failed because the datatype is actually %d\n",
+ meta_tensor->type);
+ return -1;
+ }
+
+ size_t tensor_size = ggml_nbytes(meta_tensor);
+ printf("Input tensor size is %ld bytes\n", tensor_size);
+
+ input_data.resize(meta_tensor->ne[0] * meta_tensor->ne[1] *
+ meta_tensor->ne[2]);
+ int num_tokens = meta_tensor->ne[1];
+
+ std::ifstream input_file = std::ifstream(filename, std::ios::binary);
+ const size_t offset = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, 0);
+ input_file.seekg(offset, std::ios::beg);
+ printf("Seeked to input tensor GGUF position %ld\n", offset);
+ input_file.read(reinterpret_cast<char *>(input_data.data()), tensor_size);
+ input_file.close();
+ ggml_free(meta_info);
+ return num_tokens;
+}
\ No newline at end of file
diff --git a/examples/cogagent/cogagent_util.h b/examples/cogagent/cogagent_util.h
new file mode 100644
index 0000000000000..d31dc25db288a
--- /dev/null
+++ b/examples/cogagent/cogagent_util.h
@@ -0,0 +1,39 @@
+#ifndef COGAGENT_UTIL_H
+#define COGAGENT_UTIL_H
+
+#include "llama.h"
+#include "ggml.h"
+#include "gguf.h"
+#include <string>
+#include <vector>
+#include <fstream>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <cstdint>
+#include <cmath>