diff --git a/common/arg.cpp b/common/arg.cpp index f5e9b294f3048..0b814e58c27a3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1406,14 +1406,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.mmproj = value; } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_COGAGENT})); add_opt(common_arg( {"--image"}, "FILE", "path to an image file. use with multimodal models. Specify multiple times for batching", [](common_params & params, const std::string & value) { params.image.emplace_back(value); } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION, LLAMA_EXAMPLE_COGAGENT})); if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", diff --git a/common/common.h b/common/common.h index b208d0c7ece59..535c292f36c48 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,8 @@ enum llama_example { LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, + LLAMA_EXAMPLE_VISION, + LLAMA_EXAMPLE_COGAGENT, LLAMA_EXAMPLE_COUNT, }; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 018a2a588ae9d..08b00a42d473d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast from itertools import chain +from transformers import AutoConfig import math import numpy as np import torch @@ -66,6 +67,13 @@ class Model: metadata_override: Path | None dir_model_card: Path + # for vision model + vision_arch: gguf.MODEL_ARCH | None = None + preprocessor_config: dict[str, Any] | None = None + vparams: dict[str, Any] | None = None + v_tensor_map: gguf.TensorNameMap | None = None + v_tensor_names: set[str] | None + # subclasses should define this! 
model_arch: gguf.MODEL_ARCH @@ -126,6 +134,16 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: return None raise KeyError(f"could not find any of: {keys}") + def find_vparams(self, keys: Iterable[str], optional: bool = False) -> Any: + if self.vparams is None: + raise ValueError("vision model parameters not set") + key = next((k for k in keys if k in self.vparams), None) + if key is not None: + return self.vparams[key] + if optional: + return None + raise KeyError(f"(vision) could not find any of: {keys}") + def set_vocab(self): self._set_vocab_gpt2() @@ -186,9 +204,10 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: f"Missing tensors: {missing}\n" f"Extra tensors: {extra}") - def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: - if key not in gguf.MODEL_TENSORS[self.model_arch]: - raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") + def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight", is_vision = False) -> str: + arch = self.vision_arch if is_vision and self.vision_arch is not None else self.model_arch + if key not in gguf.MODEL_TENSORS[arch]: + raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {arch!r}") name: str = gguf.TENSOR_NAMES[key] if "{bid}" in name: assert bid is not None @@ -210,9 +229,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) - if new_name is None: + new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None + if new_name is not None: + return new_name + elif new_name_vision is not None: + return new_name_vision + else: raise ValueError(f"Can not map tensor {name!r}") - return new_name def set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) @@ -257,6 +280,23 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length(head_dim) self.gguf_writer.add_value_length(head_dim) + # Vision model parameters + if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None: + self.gguf_writer.add_vision_type("vit") + self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) + self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) + self.gguf_writer.add_vision_vit_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch]) + self.gguf_writer.add_vision_vit_block_count(self.vparams["num_hidden_layers"]) + self.gguf_writer.add_vision_vit_embedding_length(self.vparams["hidden_size"]) + self.gguf_writer.add_vision_vit_feed_forward_length(self.vparams["intermediate_size"]) + self.gguf_writer.add_vision_vit_head_count(self.vparams["num_attention_heads"]) + self.gguf_writer.add_vision_vit_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_vit_image_std(self.preprocessor_config["image_std"]) + try: + self.gguf_writer.add_vision_vit_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"])) + except KeyError: + self.gguf_writer.add_vision_vit_select_layer(0) + self.gguf_writer.add_file_type(self.ftype) logger.info(f"gguf: file type = {self.ftype}") @@ -466,7 +506,25 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str] @staticmethod def load_hparams(dir_model: Path): with 
open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + hparams = json.load(f) + if "text_config" in hparams: + text_config = hparams["text_config"] + model_id = text_config.get("_name_or_path", None) + # for example, llava-1.5-7b-hf misses the language model config, need to retrieve it via model ID + if model_id is not None and model_id != "None" and model_id != "": + text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict() + hparams = {**text_config, **hparams} + return hparams + + @staticmethod + def load_preprocessor_config(dir_model: Path): + # TODO: this varies vastly among models, need to handle more cases in the future + file_path = dir_model / "preprocessor_config.json" + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + return json.load(f) + else: + raise Exception(f"Preprocessor config not found at {file_path}") @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -948,6 +1006,29 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0]) +# TODO: maybe merge this with Model in the future +class VisionModelHelper: + model: Model + tok_embd_tensor: Tensor | None = None + + def __init__(self, model: Model): + self.model = model + # TODO: how to do this without reading the whole safetensor file? + for tname, tensor in model.get_tensors(): + if tname.endswith("embed_tokens.weight"): + self.tok_embd_tensor = tensor + + def get_embd_for_tokens(self, map_token_to_tensor_name: Iterable[tuple[str, gguf.MODEL_TENSOR]], tensor_name_postfix = '.weight') -> Iterable[tuple[str, Tensor]]: + if self.tok_embd_tensor is None: + raise ValueError("Token embedding tensor not found") + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.model.dir_model, trust_remote_code=True) + for token, tensor_name in map_token_to_tensor_name: + tok_id = tokenizer.get_vocab()[token] + row = self.tok_embd_tensor[tok_id] + yield gguf.TENSOR_NAMES[tensor_name] + tensor_name_postfix, row + + @Model.register("GPTNeoXForCausalLM") class GPTNeoXModel(Model): model_arch = gguf.MODEL_ARCH.GPTNEOX @@ -1560,10 +1641,38 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM", "Idefics3ForConditionalGeneration") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + model_type = self.hparams.get("model_type", None) + self.vision_arch = None + + # only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf + if "vision_config" in self.hparams and model_type == "llava": + self.vparams = self.hparams["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA + + # only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B + if "mm_vision_tower" in self.hparams and model_type == "mobilevlm": + from transformers import AutoImageProcessor + vision_model_id = self.hparams["mm_vision_tower"] + self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"] + self.preprocessor_config = 
AutoImageProcessor.from_pretrained(vision_model_id).to_dict() + self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM + + if "vision_config" in self.hparams and model_type == "idefics3": + self.vparams = self.hparams["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + self.vision_arch = gguf.MODEL_ARCH.VISION_IDEFICS3 + + if self.vparams is not None and self.vision_arch is not None: + self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1613,6 +1722,24 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + # For vision model + if self.vparams is not None: + max_pos_embd = -1 + self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + # TODO: should not hardcode these, but they are currently missing from config.json + if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.LDPV2) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1 + if self.vision_arch == gguf.MODEL_ARCH.VISION_IDEFICS3: + self.gguf_writer.add_vision_vit_projector_type(gguf.constants.CLIPProjectorType.MLP) + self.gguf_writer.add_vision_vit_scale_factor(self.hparams["scale_factor"]) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-05) + self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) + @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1626,11 +1753,24 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + is_vision_tensor = "vision_tower" in name or "vision_model" in name - if name.endswith(("q_proj.weight", "q_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_head) - if name.endswith(("k_proj.weight", "k_proj.bias")): - data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + if is_vision_tensor: + if name.startswith("model.text_model"): + name = name.replace("text_model.", "") # for SmolVLM + else: + name = name.replace("model.vision_tower.", "") + if "post_layernorm" in name and self.vision_arch != gguf.MODEL_ARCH.VISION_IDEFICS3: + return [] # skip post_layernorm + + if not is_vision_tensor: + if name.startswith("language_model"): + # language model tensors, remove the prefix + name = name.replace("language_model.", "") + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) # process the experts separately if name.find("block_sparse_moe.experts") != -1: @@ -2234,6 +2374,173 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: yield name, data +@Model.register("MiniCPMV") +class 
MiniCPMVModel(Qwen2Model): + # MiniCPM-V 2.5 is Qwen2 and 2.6 is Qwen-2.5 + model_arch = gguf.MODEL_ARCH.QWEN2 + proj_type: gguf.constants.CLIPProjectorType | None + resampler_n_embd = 0 + vhelper: VisionModelHelper | None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + model_type = self.hparams.get("model_type", None) + + # only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6 + if "vision_config" in self.hparams and model_type == "minicpmv": + self.vparams = self.hparams["vision_config"] + self.preprocessor_config = self.load_preprocessor_config(self.dir_model) + self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV + version = str(self.hparams.get("version", "unknown")) + if version == "2.5": + self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5 + elif version == "2.6": + self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6 + else: + raise ValueError(f"Unsupported MiniCPM-V version: {version}") + self.vhelper = VisionModelHelper(self) + # TODO: how to do this without reading the whole safetensor file? + for tname, tensor in self.get_tensors(): + if tname == "resampler.ln_post.bias": + self.resampler_n_embd = tensor.shape[0] + if self.resampler_n_embd < 2: + raise ValueError("Failed to detect resampler embedding size") + else: + raise ValueError("Expected vision_config, but not found") + + assert self.vparams is not None + assert self.vision_arch is not None + assert self.preprocessor_config is not None + self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5] + self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5] + self.hparams["vision_feature_layer"] = 0 + self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"]) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.vparams is not None and self.proj_type is not None + self.gguf_writer.add_vision_vit_patch_merge_type(gguf.CLIPPatchMergeType.FLAT) + self.gguf_writer.add_vision_vit_projector_type(self.proj_type) + self.gguf_writer.add_vision_vit_layer_norm_epsilon(1e-06) + max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + self.gguf_writer.add_vision_vit_max_position_embeddings(max_pos_embd) + + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # because the model operates exclusively on 70x70 patches for now, we should precompute the positional embeddings to gain performance + # in the future, we can do it in cpp if we figure out how to do it efficiently + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True), + torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70))) + ) + assert self.vhelper is not None + added_tokens = [ + ("<image>", gguf.MODEL_TENSOR.V_TOK_EMBD_IMAGE), + ("</image>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_IMAGE), + ("<slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_SLICE), + ("</slice>", gguf.MODEL_TENSOR.V_TOK_EMBD_END_SLICE), + ] + for tensor_name, tensor in self.vhelper.get_embd_for_tokens(added_tokens): + yield tensor_name, tensor + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # for language part + if name.startswith("llm."): + return [(self.map_tensor_name(name.replace("llm.", "")), data_torch)] + + # split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v + if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"): + assert data_torch.shape[0] == 3 * self.resampler_n_embd + split_tensor = data_torch.chunk(3, dim=0) +
name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias) + name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias) + name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias) + return [ + # TODO: permute these + (self.map_tensor_name(name_q), split_tensor[0]), + (self.map_tensor_name(name_k), split_tensor[1]), + (self.map_tensor_name(name_v), split_tensor[2]), + ] + + # append .weight to these tensors + if name == "resampler.proj" or name == "resampler.query": + name += ".weight" + + if name.startswith("resampler.proj"): + data_torch = data_torch.transpose(-1, -2).contiguous() + + if "post_layernorm" in name: + return [] # skip post_layernorm + + return [(self.map_tensor_name(name), data_torch)] + + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + del name, bid # unused + if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name: + return gguf.GGMLQuantizationType.F32 + if "v.resmpl." in new_name: + return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16 + return False + + # utils to work with MiniCPM-V resampler + + # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 + def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + @Model.register("WavTokenizerDec") class WavTokenizerDecModel(Model): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC @@ -4877,6 +5184,58 @@ def _reverse_hf_permute(data_torch, n_heads, hidden_dim): data_torch = data_torch.repeat_interleave(n_heads, 0) return data_torch +@Model.register("CogAgentForCausalLM") +class CogVLMModel(Model): + model_arch = gguf.MODEL_ARCH.COGVLM + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ftype = gguf.LlamaFileType.ALL_F32 + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip boi and eoi tensors for now + if name.endswith("boi"): + return [] + if name.endswith("eoi"): + return [] + if name.startswith("model.vision"): + return [] + if name.startswith("model.cross_vision"): + return [] + + return [(self.map_tensor_name(name), data_torch)] + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5') + vocab_size = len(tokenizer.vocab.items()) + + reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + added_vocab = tokenizer.get_added_vocab() + tokens: list[str] = [] + toktypes: list[int] = [] + + for i in range(vocab_size): + if i not in reverse_vocab: + tokens.append(f"[PAD{i}]") + toktypes.append(gguf.TokenType.UNUSED) + else: + token: str = reverse_vocab[i] + if token in added_vocab: + if tokenizer.added_tokens_decoder[i].special or self.does_token_look_special(token): + toktypes.append(gguf.TokenType.CONTROL) + else: + token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + toktypes.append(gguf.TokenType.USER_DEFINED) + else: + toktypes.append(gguf.TokenType.NORMAL) + tokens.append(token) + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + ###### CONVERSION LOGIC ###### @@ -4949,7 +5308,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Convert a huggingface model to a GGML compatible file") + description="Convert a huggingface model to a GGML compatible file\n\nNote: When converting vision models, this script may use internet connection to download configuration files via Hugging Face.") parser.add_argument( "--vocab-only", action="store_true", help="extract only the vocab", diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 66cfab2c3b796..59973bb3395d9 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -18,6 +18,7 @@ if (EMSCRIPTEN) else() add_subdirectory(batched-bench) add_subdirectory(batched) + add_subdirectory(cogagent) add_subdirectory(embedding) add_subdirectory(eval-callback) @@ -53,6 +54,7 @@ else() add_subdirectory(tokenize) add_subdirectory(tts) add_subdirectory(gen-docs) + add_subdirectory(vision) if (NOT GGML_BACKEND_DL) # these examples use the backends directly and cannot be built with dynamic loading add_subdirectory(convert-llama2c-to-ggml) diff --git a/examples/cogagent/CMakeLists.txt b/examples/cogagent/CMakeLists.txt new file mode 100644 index 
0000000000000..e83b1c3febcfd --- /dev/null +++ b/examples/cogagent/CMakeLists.txt @@ -0,0 +1,18 @@ +set(TARGET llama-cogagent-cli) +add_executable(${TARGET} cogagent-cli.cpp) +add_library(cogagent OBJECT + vision_encoder.cpp + vision_encoder.h + cross_vision.cpp + cross_vision.h + cogagent_util.cpp + cogagent_util.h + image_util.cpp + image_util.h) +set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-cogagent-cli) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common cogagent ggml ${CMAKE_THREAD_LIBS_INIT}) +target_include_directories(cogagent PUBLIC ../../ggml/include) +target_include_directories(cogagent PUBLIC ../../include) +target_include_directories(cogagent PUBLIC ../../common) +target_compile_features(${TARGET} PRIVATE cxx_std_11) \ No newline at end of file diff --git a/examples/cogagent/cogagent-cli.cpp b/examples/cogagent/cogagent-cli.cpp new file mode 100644 index 0000000000000..f195fd4d2b736 --- /dev/null +++ b/examples/cogagent/cogagent-cli.cpp @@ -0,0 +1,285 @@ +#include "arg.h" +#include "base64.hpp" +#include "log.h" +#include "common.h" +#include "sampling.h" +#include "llama.h" + +#include +#include +#include +#include + +#include "cogagent.h" + +cogagent_ctx cogagent_global; + +// This function is mostly copied from cogagent cli +static bool eval_string_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { + int N = (int) tokens.size(); + + //// Processing the input tokens in batches + for (int i = 0; i < N; i += n_batch) { + int n_eval = (int) tokens.size() - i; + if (n_eval > n_batch) { + n_eval = n_batch; + } + + std::vector pos; + pos.resize(n_eval); + for (int i=0; i &img_data, + int n_batch, int * n_past) { + int n_embd = 4096; + int num_tokens = 258; + int positions[258]; + + positions[0] = *n_past; + for (int i=0; i n_batch) { + n_eval = n_batch; + } + llama_batch batch = {int32_t(n_eval), nullptr, data_ptr, positions, nullptr, nullptr, nullptr, nullptr, nullptr, }; + batch.cross_embd = cogagent_global.cross_vision_image_tensor; + if (llama_decode(ctx_llama, batch)) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + data_ptr += i * n_embd; + } + *n_past += 3; + return true; +} + +static void print_usage(int, char ** argv) { + LOG("\n example usage:\n"); + LOG("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); + LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); +} + +static const char * sample(struct common_sampler * smpl, + struct llama_context * ctx_llama, + int * n_past) { + const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); + common_sampler_accept(smpl, id, true); + + const llama_model * model = llama_get_model(ctx_llama); + const llama_vocab * vocab = llama_model_get_vocab(model); + + static std::string ret; + if (llama_vocab_is_eog(vocab, id)) { + ret = ""; + } else { + ret = common_token_to_piece(ctx_llama, id); + } + // Give the new token to the model. I'm not sure how it is stored. + // Perhaps it is stored in the KV cache. 
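+ // The token is stored via eval_string_tokens() below: it runs llama_decode() on the sampled token, which appends the token's key/value entries to the context's KV cache so later decode calls can attend to it.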
+ std::vector tokens; + tokens.push_back(id); + eval_string_tokens(ctx_llama, tokens, 1, n_past); + + return ret.c_str(); +} + +static bool run_vision_encoders(const char* vision_encoder_path, const char* image_path) { + // Load image and resize for the encoders + std::vector small_image_data; // For vision encoder + std::vector large_image_data; // For cross vision encoder + if (!load_and_stretch_image(image_path, cogagent_global.vision_encoder_img_size, + small_image_data, cogagent_global.norm_mean, cogagent_global.norm_deviation)) { + printf("Failed to load the specified image file.\n"); + return false; + } + if (!load_and_stretch_image(image_path, cogagent_global.cross_vision_img_size, + large_image_data, cogagent_global.norm_mean, cogagent_global.norm_deviation)) { + printf("Failed to load the specified image file.\n"); + return false; + } + + // For debugging purposes + const char * vision_encoder_resized_image = "cogagent_encoders/llama_vision_encoder_input.gguf"; + int dims[3] = {cogagent_global.vision_encoder_img_size, + cogagent_global.vision_encoder_img_size, 3}; + save_tensor_from_data(small_image_data, dims, vision_encoder_resized_image); + const char * cross_vision_resized_image = "cogagent_encoders/llama_cross_vision_input.gguf"; + dims[0] = cogagent_global.cross_vision_img_size; + dims[1] = cogagent_global.cross_vision_img_size; + save_tensor_from_data(large_image_data, dims, cross_vision_resized_image); + + // const char * reference_vision_encoder_input = "/home/tianyue/myworkspace" + // "/vlm_intermediate/vision_encoder_input.gguf"; + // const char * reference_cross_vision_input = "/home/tianyue/myworkspace" + // "/vlm_intermediate/cross_vision_input.gguf"; + // // Load the reference input + // if (get_input(small_image_data, reference_vision_encoder_input) < 0) { + // printf("Failed to load small image input\n"); + // return false; + // } + // if (get_input(large_image_data, reference_cross_vision_input) < 0) { + // printf("Failed to load big image input\n"); + // return false; + // } + printf("Loaded and resized the specified image.\n"); + + // Load the vision encoder weights + if (!vision_encoder_init_load(vision_encoder_path)) { + printf("Failed to load vision encoder model file.\n"); + return false; + } + printf("Vision encoder weights loaded.\n"); + + // Run the vision encoder + run_vision_encoder(small_image_data); + printf("Completed vision encoder run on image file.\n"); + + free_vision_encoder_ctx(); + + // Load and run the cross vision encoder + if (!cross_vision_init_load(vision_encoder_path)) { + printf("Failed to load cross vision encoder model file.\n"); + return false; + } + printf("Cross vision encoder weights loaded.\n"); + + run_cross_vision(large_image_data); + printf("Completed cross vision encoder run on image file.\n"); + + free_cross_vision_ctx(); + return true; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + common_params params; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COGAGENT, print_usage)) { + return 1; + } + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + // Initialize a GGML context to store the encoded image tensors + struct ggml_init_params token_ctx_params = { + size_t(40000000), + NULL, + false, + }; + cogagent_global.token_ctx = ggml_init(token_ctx_params); + if (!cogagent_global.token_ctx) { + printf("Failed to initialize token storage context.\n"); + return 1; + } + // Allocate the tensor for cross vision encoded image + cogagent_global.cross_vision_image_tensor = 
ggml_new_tensor_2d( + cogagent_global.token_ctx, GGML_TYPE_F32, 1024, 6400 + ); + + // Load the images and the encoder models + // Then run the encoder models + if (!run_vision_encoders(params.mmproj.c_str(), params.image[0].c_str())) { + return 1; + } + + llama_model_params model_params = common_model_params_to_llama(params); + llama_model * model = llama_model_load_from_file(params.model.c_str(), model_params); + if (model == nullptr) { + printf("Failed to load decoder model\n"); + return 1; + } + + llama_context_params ctx_params = common_context_params_to_llama(params); + printf("Context size is %d tokens\n", ctx_params.n_ctx); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); + + if (ctx_llama == nullptr) { + printf("Failed to create the llama context\n"); + return 1; + } + + cogagent_global.ctx_llama = ctx_llama; + cogagent_global.cogvlm_model = model; + + // Note: the KV cache tensors live in their own backend buffer that is + // allocated when the llama_context is created, separately from the + // per-batch compute graph buffer, so cached K/V data persists across + // llama_decode calls even though the graph is rebuilt for each batch. + // The worst-case graph reserved at init is only used to size the compute buffer. + + // TODO: Check if system prompt is compatible + std::vector<llama_token> begin_token; + const llama_vocab * vocab = llama_model_get_vocab(cogagent_global.cogvlm_model); + begin_token.push_back(llama_vocab_bos(vocab)); + + int n_past = 0; + printf("Run model with bos token.\n"); + eval_string_tokens(cogagent_global.ctx_llama, + begin_token, params.n_batch, &n_past); + printf("Run model with image tokens.\n"); + eval_image_tokens(cogagent_global.ctx_llama, cogagent_global.vision_encoder_image, + params.n_batch, &n_past); + // Tokenize user prompt + // Third option set to false so that the tokenizer doesn't add + // beginning-of-sentence and end-of-sentence tokens + std::vector<llama_token> user_prompt_tokens = common_tokenize( + cogagent_global.ctx_llama, params.prompt, false, true + ); + printf("Run model with user entered text tokens.\n"); + eval_string_tokens(cogagent_global.ctx_llama, user_prompt_tokens, + params.n_batch, &n_past); + + printf("Parsed maximum sampling length %d.\n", params.n_predict); + int max_len = params.n_predict < 0 ?
256 : params.n_predict; + + struct common_sampler * smpl = common_sampler_init(cogagent_global.cogvlm_model, params.sampling); + if (!smpl) { + printf("Failed to initialize sampler.\n"); + return 1; + } + printf("\nReprinting entered prompt.\n %s \n", params.prompt.c_str()); + printf("\n\n Beginning of response.\n"); + std::string response = ""; + for (int i=0; i") == 0) { + if (i < 10) { + continue; + } + break; + } + printf("%s", tmp); + fflush(stdout); + } + common_sampler_free(smpl); + + llama_model_free(model); + ggml_free(cogagent_global.token_ctx); + return 0; +} \ No newline at end of file diff --git a/examples/cogagent/cogagent.h b/examples/cogagent/cogagent.h new file mode 100644 index 0000000000000..d51a87a539a35 --- /dev/null +++ b/examples/cogagent/cogagent.h @@ -0,0 +1,36 @@ +#ifndef COGAGENT_H +#define COGAGENT_H + +#include "vision_encoder.h" +#include "cross_vision.h" +#include "cogagent_util.h" +#include "image_util.h" +#include "ggml.h" +#include "gguf.h" + +struct cogagent_ctx { + // Vision encoder and cross vision encoder models + vision_encoder_ctx vision_encoder; + cross_vision_ctx cross_vision; + + struct llama_context * ctx_llama; + struct llama_model * cogvlm_model; + + // Context for storing vision tokens and cross vision + // embedded picture tensor + ggml_context * token_ctx; + + std::string user_prompt; + std::vector vision_encoder_image; // Image encoded by the vision encoder + struct ggml_tensor * cross_vision_image_tensor; // Image encoded by the cross vision encoder + + int vision_encoder_img_size = 224; + int cross_vision_img_size = 1120; + + float norm_mean[3] = {0.48145466, 0.4578275, 0.40821073}; + float norm_deviation[3] = {0.26862954, 0.26130258, 0.27577711}; +}; + +extern struct cogagent_ctx cogagent_global; + +#endif \ No newline at end of file diff --git a/examples/cogagent/cogagent_util.cpp b/examples/cogagent/cogagent_util.cpp new file mode 100644 index 0000000000000..fce62dda3ddab --- /dev/null +++ b/examples/cogagent/cogagent_util.cpp @@ -0,0 +1,162 @@ +#include "cogagent_util.h" + +void print_dims(struct ggml_tensor * input_tensor, const char * name) { + printf("Tensor %s has shape %ld x %ld x %ld x %ld and data type %d\n", name, + input_tensor->ne[0], input_tensor->ne[1], input_tensor->ne[2], + input_tensor->ne[3], input_tensor->type); +} + +struct ggml_tensor * get_tensor(struct ggml_context * dst_ctx, struct ggml_context * src_ctx, std::string tensor_name, int &count_failed) { + struct ggml_tensor * cur_tensor = ggml_get_tensor(src_ctx, tensor_name.c_str()); + if (!cur_tensor) { + printf("Retrieval of tensor %s from model context failed\n", tensor_name.c_str()); + count_failed++; + return nullptr; + } + struct ggml_tensor * new_tensor = ggml_dup_tensor(dst_ctx, cur_tensor); + ggml_set_name(new_tensor, cur_tensor->name); + return new_tensor; +} + +void save_tensor_filename(struct ggml_tensor * input_tensor, std::string filename) { + std::string prefix = "/home/tianyue/myworkspace/"; + filename = prefix + filename; + gguf_context * gguf_ctx = gguf_init_empty(); + gguf_set_val_str(gguf_ctx, "model.architecture", "cogagent"); + gguf_set_val_u32(gguf_ctx, "general.file_type", GGML_TYPE_F32); + + struct ggml_init_params params = { + ggml_nbytes(input_tensor) + 1000000, // Memory to allocate + nullptr, // Buffer location + false, // no_alloc=false, so that tensor data is allocated + }; + struct ggml_context * tensor_ctx = ggml_init(params); + struct ggml_tensor * tensor_with_data = ggml_dup(tensor_ctx, input_tensor); + 
ggml_backend_tensor_get(input_tensor, tensor_with_data->data, + 0, ggml_nbytes(input_tensor)); + + ggml_set_name(tensor_with_data, "output_tensor"); + gguf_add_tensor(gguf_ctx, tensor_with_data); + gguf_write_to_file(gguf_ctx, filename.c_str(), false); + gguf_free(gguf_ctx); + ggml_free(tensor_ctx); +} + +void save_tensor_from_data(std::vector tensor_data, int* dims, std::string filename) { + std::string prefix = "/home/tianyue/myworkspace/"; + filename = prefix + filename; + gguf_context * gguf_ctx = gguf_init_empty(); + gguf_set_val_str(gguf_ctx, "model.architecture", "cogagent"); + gguf_set_val_u32(gguf_ctx, "general.file_type", GGML_TYPE_F32); + + struct ggml_init_params params = { + tensor_data.size() * sizeof(float) + 1000000, // Memory to allocate + nullptr, // Buffer location + false, // Allocate tensor data + }; + struct ggml_context * tensor_ctx = ggml_init(params); + struct ggml_tensor * tensor_with_data = ggml_new_tensor_3d(tensor_ctx, + GGML_TYPE_F32, dims[0], dims[1], dims[2]); + // copy the data + memcpy(tensor_with_data->data, tensor_data.data(), ggml_nbytes(tensor_with_data)); + + ggml_set_name(tensor_with_data, "output_tensor"); + gguf_add_tensor(gguf_ctx, tensor_with_data); + gguf_write_to_file(gguf_ctx, filename.c_str(), false); + gguf_free(gguf_ctx); + ggml_free(tensor_ctx); +} + +// Function that loads data from the GGUF file to a temporary buffer +// and then from the temporary buffer to the GGML backend buffer +// Copied from the GGML MNIST example +bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf) { + FILE * f = ggml_fopen(fname, "rb"); + if (!f) { + return false; + } + + const size_t buf_size = 4*1024*1024; + void * buf = malloc(buf_size); + + const int n_tensors = gguf_get_n_tensors(ctx_gguf); + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + + struct ggml_tensor * tensor = ggml_get_tensor(ctx_ggml, name); + if (!tensor) { + // We get here if there is a tensor in the file + // that is not being requested for the context + // that we are loading into + continue; + } + + const size_t offs = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i); + + if (fseek(f, offs, SEEK_SET) != 0) { + fclose(f); + free(buf); + return false; + } + + const size_t nbytes = ggml_nbytes(tensor); + for (size_t pos = 0; pos < nbytes; pos += buf_size) { + const size_t nbytes_cpy = buf_size < nbytes - pos ? buf_size : nbytes - pos; + + if (fread(buf, 1, nbytes_cpy, f) != nbytes_cpy) { + fclose(f); + free(buf); + return false; + } + + ggml_backend_tensor_set(tensor, buf, pos, nbytes_cpy); + } + } + + fclose(f); + free(buf); + return true; +} + +int get_input( + std::vector &input_data, + const char * filename +) { + struct ggml_context * meta_info; + + struct gguf_init_params gguf_params = { + true, &meta_info, + }; + struct gguf_context * gguf_ctx = gguf_init_from_file(filename, gguf_params); + if (!gguf_ctx) { + printf("Failed to initialize GGUF context. 
Check filename.\n"); + return false; + } + + // I don't know how to set tensor name when writing a GGUF file + // in cpp, so the cross vision output tensor doesn't appear + // to have a name when reading the GGUF file + struct ggml_tensor * meta_tensor = ggml_get_first_tensor(meta_info); + if (meta_tensor->type != GGML_TYPE_F32) { + printf("Expected the input image datatype to be float 32.\n"); + printf("Image loading failed because the datatype is actually %d\n", + meta_tensor->type); + return -1; + } + + size_t tensor_size = ggml_nbytes(meta_tensor); + printf("Input tensor size is %ld bytes\n", tensor_size); + + input_data.resize(meta_tensor->ne[0] * meta_tensor->ne[1] * + meta_tensor->ne[2]); + int num_tokens = meta_tensor->ne[1]; + + std::ifstream input_file = std::ifstream(filename, std::ios::binary); + const size_t offset = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, 0); + input_file.seekg(offset, std::ios::beg); + printf("Seeked to input tensor GGUF position %ld\n", offset); + input_file.read(reinterpret_cast(input_data.data()), tensor_size); + input_file.close(); + ggml_free(meta_info); + return num_tokens; +} \ No newline at end of file diff --git a/examples/cogagent/cogagent_util.h b/examples/cogagent/cogagent_util.h new file mode 100644 index 0000000000000..d31dc25db288a --- /dev/null +++ b/examples/cogagent/cogagent_util.h @@ -0,0 +1,39 @@ +#ifndef COGAGENT_UTIL_H +#define COGAGENT_UTIL_H + +#include "llama.h" +#include "ggml.h" +#include "gguf.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +extern void set_processing_text(llama_context * ctx, bool value); + +extern void set_cross_input(llama_context * ctx, std::vector &value); + +void print_dims(struct ggml_tensor * input_tensor, const char * name); + +struct ggml_tensor * get_tensor(struct ggml_context * dst_ctx, struct ggml_context * src_ctx, std::string tensor_name, int &count_failed); + +void save_tensor_filename(struct ggml_tensor * input_tensor, std::string filename); + +void save_tensor_from_data(std::vector tensor_data, int* dims, std::string filename); + +bool load_from_gguf(const char * fname, struct ggml_context * ctx_ggml, struct gguf_context * ctx_gguf); + +int get_input( + std::vector &input_data, const char * filename +); + +#endif \ No newline at end of file diff --git a/examples/cogagent/cross_vision.cpp b/examples/cogagent/cross_vision.cpp new file mode 100644 index 0000000000000..d20b76a8336f8 --- /dev/null +++ b/examples/cogagent/cross_vision.cpp @@ -0,0 +1,299 @@ +#include "cogagent.h" +#include "cross_vision.h" + +#include +#include +#include +#include +#include +#include + +bool cross_vision_init_load(const char * filename) { + cross_vision_ctx &model_ctx = cogagent_global.cross_vision; + cross_vision &model = cogagent_global.cross_vision.model; + + model_ctx.backend = ggml_backend_cpu_init(); + ggml_backend_cpu_set_n_threads(model_ctx.backend, 16); + + struct ggml_init_params weight_params { + // Counted 515 tensors in cross vision encoder save file + 10000000, // Memory size + NULL, // Memory buffer + true, // Don't allocate tensor data + }; + + struct ggml_init_params compute_params { + GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + ggml_graph_overhead(), + NULL, + true, + }; + + model_ctx.ctx_weight = ggml_init(weight_params); + model_ctx.ctx_compute = ggml_init(compute_params); + + struct ggml_context * meta; + struct gguf_init_params gguf_params = { + true, &meta, + }; + + struct 
gguf_context * gguf_ctx = gguf_init_from_file(filename, gguf_params); + if (!gguf_ctx) { + printf("Failed to initialize GGUF context. Check filename.\n"); + return false; + } + + // Calculate the ctx size based on tensors in the GGUF file + size_t ctx_size = 0; + int num_tensors = gguf_get_n_tensors(gguf_ctx); + printf("There are %d tensors in the GGUF file\n", num_tensors); + for (int i=0; ine[0]); + + model.cls_embed = get_tensor(model_ctx.ctx_weight, meta, "cross_vision.vit.model.cls_token", failed_count); + model.pos_embed_1 = get_tensor(model_ctx.ctx_weight, meta, "cross_vision.vit.model.pos_embed", failed_count); + + model.rope_freqs_cos = get_tensor(model_ctx.ctx_weight, meta, "cross_vision.vit.model.rope.freqs_cos", failed_count); + model.rope_freqs_sin = get_tensor(model_ctx.ctx_weight, meta, "cross_vision.vit.model.rope.freqs_sin", failed_count); + + for (int i=0; i<24; i++) { + std::string layer_prefix = "cross_vision.vit.model.blocks." + std::to_string(i) + "."; + model.transformer_layers.emplace_back(); + cross_vision_layer &cur_layer = model.transformer_layers.back(); + cur_layer.norm1_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "norm1.weight", failed_count); + cur_layer.norm1_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "norm1.bias", failed_count); + cur_layer.q_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.q_proj.weight", failed_count); + cur_layer.q_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.q_bias", failed_count); + cur_layer.k_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.k_proj.weight", failed_count); + cur_layer.v_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.v_proj.weight", failed_count); + cur_layer.v_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.v_bias", failed_count); + cur_layer.attn_ln_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.inner_attn_ln.weight", failed_count); + cur_layer.attn_ln_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.inner_attn_ln.bias", failed_count); + cur_layer.attn_linear_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.proj.weight", failed_count); + cur_layer.attn_linear_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "attn.proj.bias", failed_count); + cur_layer.norm2_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "norm2.weight", failed_count); + cur_layer.norm2_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "norm2.bias", failed_count); + cur_layer.mlp_linear1_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.w1.weight", failed_count); + cur_layer.mlp_linear1_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.w1.bias", failed_count); + cur_layer.mlp_linear2_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.w2.weight", failed_count); + cur_layer.mlp_linear2_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.w2.bias", failed_count); + cur_layer.mlp_ln_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.ffn_ln.weight", failed_count); + cur_layer.mlp_ln_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.ffn_ln.bias", failed_count); + cur_layer.mlp_linear3_w = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.w3.weight", failed_count); + cur_layer.mlp_linear3_b = get_tensor(model_ctx.ctx_weight, meta, layer_prefix + "mlp.w3.bias", failed_count); + } + + model.pos_embed_2 = get_tensor(model_ctx.ctx_weight, meta, "cross_vision.pos_embed", failed_count); + + if 
(failed_count > 0) { + printf("%d tensors could not be found in the model context. Model loading failed.\n", failed_count); + return false; + } + + // Allocate data storage for the tensors on the backend + model_ctx.weight_data = ggml_backend_alloc_ctx_tensors(model_ctx.ctx_weight, model_ctx.backend); + + if (!load_from_gguf(filename, model_ctx.ctx_weight, gguf_ctx)) { + printf("Loading data from GGUF file failed\n"); + return false; + } + + ggml_free(meta); + + return true; +} + +static struct ggml_tensor * compute_rope(cross_vision_ctx &model_ctx, struct ggml_tensor *input_tensor) { + struct ggml_context * ctx = model_ctx.ctx_compute; + cross_vision model = model_ctx.model; + // Don't really think this should be necessary + // The ggml_is_contiguous_n code in ggml.c doesn't even seem to be using the n variable + input_tensor = ggml_cont(ctx, input_tensor); + struct ggml_tensor * cos = ggml_mul(ctx, input_tensor, model.rope_freqs_cos); + if (!ggml_is_contiguous(input_tensor)) {printf("Not contiguous input tensor\n");} + struct ggml_tensor * rotate_half = ggml_reshape_4d(ctx, input_tensor, 2, input_tensor->ne[0] / 2, + input_tensor->ne[1], input_tensor->ne[2]); + rotate_half = ggml_permute(ctx, rotate_half, 3, 1, 2, 0); + rotate_half = ggml_cont(ctx, rotate_half); + struct ggml_tensor * positive = ggml_view_4d(ctx, rotate_half, rotate_half->ne[0], rotate_half->ne[1], + rotate_half->ne[2], 1, rotate_half->nb[1], rotate_half->nb[2], rotate_half->nb[3], 0); + struct ggml_tensor * negative = ggml_view_4d(ctx, rotate_half, rotate_half->ne[0], rotate_half->ne[1], + rotate_half->ne[2], 1, rotate_half->nb[1], rotate_half->nb[2], rotate_half->nb[3], rotate_half->nb[3]); + negative = ggml_scale(ctx, negative, -1); + rotate_half = ggml_concat(ctx, negative, positive, 3); + rotate_half = ggml_permute(ctx, rotate_half, 3, 1, 2, 0); + rotate_half = ggml_cont(ctx, rotate_half); + rotate_half = ggml_reshape_3d(ctx, rotate_half, 2 * rotate_half->ne[1], rotate_half->ne[2], rotate_half->ne[3]); + struct ggml_tensor * sin = ggml_mul(ctx, rotate_half, model.rope_freqs_sin); + return ggml_add(ctx, cos, sin); +} + +static struct ggml_cgraph * cross_vision_graph() { + struct ggml_context * ctx = cogagent_global.cross_vision.ctx_compute; + cross_vision_ctx &model_ctx = cogagent_global.cross_vision; + cross_vision &model = cogagent_global.cross_vision.model; + + // Set flag to tell allocator to not overwrite the input tensor + // before it is done with computation + ggml_set_input(model.input_image); + + ggml_tensor * patch_embedding = ggml_conv_2d(ctx, model.patch_conv_w, model.input_image, + 14, 14, 0, 0, 1, 1); + // ggml_repeat should be automatically applied + // It is required that the tensor to be repeated is the second one + patch_embedding = ggml_add(ctx, patch_embedding, model.patch_conv_b); + // From w x h x d x b to l x d x b + patch_embedding = ggml_reshape_3d(ctx, patch_embedding, patch_embedding->ne[0] * patch_embedding->ne[1], + patch_embedding->ne[2], patch_embedding->ne[3]); + patch_embedding = ggml_transpose(ctx, patch_embedding); // From l x d x b to d x l x b + patch_embedding = ggml_cont(ctx, patch_embedding); + // Concatenate the class embedding + struct ggml_tensor * cls_embed_shape = ggml_new_tensor_3d(ctx, model.cls_embed->type, + patch_embedding->ne[0], 1, patch_embedding->ne[2]); + // ggml_concat only supports F32 + patch_embedding = ggml_concat(ctx, ggml_repeat(ctx, model.cls_embed, cls_embed_shape), patch_embedding, 1); + patch_embedding = ggml_add(ctx, patch_embedding, 
model.pos_embed_1); + + struct ggml_tensor * layer_in = patch_embedding; + for (int i=0; i<23; i++) { + // d x l x b + cross_vision_layer &cur_layer = model.transformer_layers[i]; + struct ggml_tensor * attention_input = ggml_norm(ctx, layer_in, model.layernorm_eps); + attention_input = ggml_mul(ctx, attention_input, cur_layer.norm1_w); + attention_input = ggml_add(ctx, attention_input, cur_layer.norm1_b); + struct ggml_tensor * qt = ggml_mul_mat(ctx, cur_layer.q_w, attention_input); + qt = ggml_add(ctx, qt, cur_layer.q_b); + struct ggml_tensor * kt = ggml_mul_mat(ctx, cur_layer.k_w, attention_input); + struct ggml_tensor * v = ggml_mul_mat(ctx, cur_layer.v_w, attention_input); + v = ggml_add(ctx, v, cur_layer.v_b); + // reshape and permute to k x l x h x b + qt = ggml_reshape_4d(ctx, qt, qt->ne[0] / model.num_heads, model.num_heads, + qt->ne[1], qt->ne[2]); + qt = ggml_permute(ctx, qt, 0, 2, 1, 3); + qt = ggml_cont(ctx, qt); + kt = ggml_reshape_4d(ctx, kt, kt->ne[0] / model.num_heads, model.num_heads, + kt->ne[1], kt->ne[2]); + kt = ggml_permute(ctx, kt, 0, 2, 1, 3); + kt = ggml_cont(ctx, kt); + // for v, reshape and permute to l x k x h x b + v = ggml_reshape_4d(ctx, v, v->ne[0] / model.num_heads, model.num_heads, + v->ne[1], v->ne[2]); + v = ggml_permute(ctx, v, 1, 2, 0, 3); + v = ggml_cont(ctx, v); + + // At this point, qt and kt are k x l x h x b + // v is l x k x h x b + + // Process rope for Q and K + // Remove class embedding before computing rope + struct ggml_tensor * qtrope = ggml_view_4d(ctx, qt, qt->ne[0], qt->ne[1] - 1, + qt->ne[2], qt->ne[3], qt->nb[1], qt->nb[2], qt->nb[3], qt->nb[1]); + struct ggml_tensor * ktrope = ggml_view_4d(ctx, kt, kt->ne[0], kt->ne[1] - 1, + kt->ne[2], kt->ne[3], kt->nb[1], kt->nb[2], kt->nb[3], kt->nb[1]); + struct ggml_tensor * qtcls = ggml_view_4d(ctx, qt, qt->ne[0], 1, qt->ne[2], + qt->ne[3], qt->nb[1], qt->nb[2], qt->nb[3], 0); + struct ggml_tensor * ktcls = ggml_view_4d(ctx, kt, kt->ne[0], 1, kt->ne[2], + kt->ne[3], kt->nb[1], kt->nb[2], kt->nb[3], 0); + qtrope = compute_rope(model_ctx, qtrope); + ktrope = compute_rope(model_ctx, ktrope); + qtrope = ggml_concat(ctx, qtcls, qtrope, 1); + ktrope = ggml_concat(ctx, ktcls, ktrope, 1); + + struct ggml_tensor * qtscale = ggml_scale(ctx, qtrope, model.attn_scale); + + // Calculate attention score + // L x L x H x B + struct ggml_tensor * attnt = ggml_mul_mat(ctx, ktrope, qtscale); + attnt = ggml_soft_max(ctx, attnt); + attnt = ggml_mul_mat(ctx, v, attnt); // k x l x h x b + attnt = ggml_permute(ctx, attnt, 0, 2, 1, 3); // k x h x l x b + attnt = ggml_cont(ctx, attnt); + attnt = ggml_reshape_3d(ctx, attnt, attnt->ne[0] * attnt->ne[1], + attnt->ne[2], attnt->ne[3]); + attnt = ggml_norm(ctx, attnt, model.layernorm_eps); + attnt = ggml_mul(ctx, attnt, cur_layer.attn_ln_w); + attnt = ggml_add(ctx, attnt, cur_layer.attn_ln_b); + attnt = ggml_mul_mat(ctx, cur_layer.attn_linear_w, attnt); + attnt = ggml_add(ctx, attnt, cur_layer.attn_linear_b); + + layer_in = ggml_add(ctx, layer_in, attnt); + + // MLP calculation + struct ggml_tensor * mlp_tensor = ggml_norm(ctx, layer_in, model.layernorm_eps); + mlp_tensor = ggml_mul(ctx, mlp_tensor, cur_layer.norm2_w); + mlp_tensor = ggml_add(ctx, mlp_tensor, cur_layer.norm2_b); + struct ggml_tensor * w1 = ggml_mul_mat(ctx, cur_layer.mlp_linear1_w, mlp_tensor); + w1 = ggml_add(ctx, w1, cur_layer.mlp_linear1_b); + struct ggml_tensor * w2 = ggml_mul_mat(ctx, cur_layer.mlp_linear2_w, mlp_tensor); + w2 = ggml_add(ctx, w2, cur_layer.mlp_linear2_b); + w1 = ggml_silu(ctx, w1); + 
mlp_tensor = ggml_mul(ctx, w1, w2); // MLP hidden size is 2730 + mlp_tensor = ggml_norm(ctx, mlp_tensor, model.layernorm_eps); + mlp_tensor = ggml_mul(ctx, mlp_tensor, cur_layer.mlp_ln_w); + mlp_tensor = ggml_add(ctx, mlp_tensor, cur_layer.mlp_ln_b); + mlp_tensor = ggml_mul_mat(ctx, cur_layer.mlp_linear3_w, mlp_tensor); + mlp_tensor = ggml_add(ctx, mlp_tensor, cur_layer.mlp_linear3_b); + + layer_in = ggml_add(ctx, layer_in, mlp_tensor); + } + + model.output_tensor = ggml_view_3d(ctx, layer_in, layer_in->ne[0], layer_in->ne[1] - 1, + layer_in->ne[2], layer_in->nb[1], layer_in->nb[2], layer_in->nb[1]); + model.output_tensor = ggml_add(ctx, model.output_tensor, model.pos_embed_2); + + struct ggml_cgraph * gf = ggml_new_graph(ctx); + ggml_build_forward_expand(gf, model.output_tensor); + + // Copy the output tensor to the token context + ggml_build_forward_expand(gf, ggml_cpy(ctx, model.output_tensor, + cogagent_global.cross_vision_image_tensor)); + + return gf; +} + +void run_cross_vision(std::vector img_data) { + cross_vision_ctx &model_ctx = cogagent_global.cross_vision; + cross_vision &model = cogagent_global.cross_vision.model; + + // Declare the input image tensor + model.input_image = ggml_new_tensor_3d(cogagent_global.cross_vision.ctx_compute, GGML_TYPE_F32, + cogagent_global.cross_vision_img_size, cogagent_global.cross_vision_img_size, 3); + + struct ggml_cgraph * gf = cross_vision_graph(); + ggml_graph_print(gf); + printf("Number of nodes in the vision encoder graph is %d\n", ggml_graph_n_nodes(gf)); + + model_ctx.allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model_ctx.backend)); + ggml_gallocr_reserve(model_ctx.allocr, gf); + size_t compute_size = ggml_gallocr_get_buffer_size(model_ctx.allocr, 0); + printf("Allocated %ld bytes of space for graph computation.\n", compute_size); + ggml_gallocr_alloc_graph(model_ctx.allocr, gf); + + ggml_backend_tensor_set(model.input_image, img_data.data(), 0, ggml_nbytes(model.input_image)); + + // Computation result is at model.output_tensor + ggml_backend_graph_compute(model_ctx.backend, gf); + + save_tensor_filename(model.output_tensor, "cogagent_encoders/cross_vision_output.gguf"); +} + +void free_cross_vision_ctx() { + cross_vision_ctx &model_ctx = cogagent_global.cross_vision; + + ggml_gallocr_free(model_ctx.allocr); + ggml_backend_buffer_free(model_ctx.weight_data); + ggml_free(model_ctx.ctx_weight); + ggml_free(model_ctx.ctx_compute); +} \ No newline at end of file diff --git a/examples/cogagent/cross_vision.h b/examples/cogagent/cross_vision.h new file mode 100644 index 0000000000000..5ef98dfc1c385 --- /dev/null +++ b/examples/cogagent/cross_vision.h @@ -0,0 +1,83 @@ +#ifndef CROSS_VISION_H +#define CROSS_VISION_H + +#include "ggml-backend.h" +#include +#include +#include +#include +#include +#include + +#include "cogagent.h" + +struct cross_vision_layer { + struct ggml_tensor * norm1_w; + struct ggml_tensor * norm1_b; + + // No bias for K projection + struct ggml_tensor * q_w; + struct ggml_tensor * q_b; + struct ggml_tensor * k_w; + struct ggml_tensor * v_w; + struct ggml_tensor * v_b; + + struct ggml_tensor * attn_ln_w; + struct ggml_tensor * attn_ln_b; + struct ggml_tensor * attn_linear_w; + struct ggml_tensor * attn_linear_b; + + struct ggml_tensor * norm2_w; + struct ggml_tensor * norm2_b; + + struct ggml_tensor * mlp_linear1_w; + struct ggml_tensor * mlp_linear1_b; + struct ggml_tensor * mlp_linear2_w; + struct ggml_tensor * mlp_linear2_b; + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + struct 
ggml_tensor * mlp_linear3_w; + struct ggml_tensor * mlp_linear3_b; +}; + +struct cross_vision { + struct ggml_tensor * patch_conv_w; + struct ggml_tensor * patch_conv_b; + + struct ggml_tensor * cls_embed; + struct ggml_tensor * pos_embed_1; + + struct ggml_tensor * rope_freqs_cos; // 6400 x 64. In other words, l x k + struct ggml_tensor * rope_freqs_sin; + + std::vector transformer_layers; + + struct ggml_tensor * pos_embed_2; + + struct ggml_tensor * input_image; + struct ggml_tensor * output_tensor; + + float layernorm_eps = 0.000001; // Need to check all of these to confirm + int num_heads = 16; // For sure + int hidden_size = 1024; + int head_hidden_size = hidden_size / num_heads; + int num_layers = 24; + float attn_scale = 1.0 / std::sqrt(head_hidden_size); +}; + +struct cross_vision_ctx { + struct ggml_context * ctx_weight; + struct ggml_context * ctx_compute; + ggml_backend_buffer_t weight_data; + ggml_backend_t backend; + ggml_gallocr_t allocr; + cross_vision model; +}; + +bool cross_vision_init_load(const char * filename); + +void run_cross_vision(std::vector img_data); + +void free_cross_vision_ctx(); + +#endif \ No newline at end of file diff --git a/examples/cogagent/image_util.cpp b/examples/cogagent/image_util.cpp new file mode 100644 index 0000000000000..372359cc6716b --- /dev/null +++ b/examples/cogagent/image_util.cpp @@ -0,0 +1,296 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "image_util.h" +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" + +// RGB uint8 image +struct clip_image_u8 { + int nx; + int ny; + + std::vector buf; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... +struct clip_image_f32 { + int nx; + int ny; + + std::vector buf; +}; + +struct clip_image_size * clip_image_size_init() { + struct clip_image_size * load_image_size = new struct clip_image_size(); + load_image_size->width = 448; + load_image_size->height = 448; + return load_image_size; +} + +struct clip_image_u8 * clip_image_u8_init() { + return new clip_image_u8(); +} + +struct clip_image_f32 * clip_image_f32_init() { + return new clip_image_f32(); +} + +void clip_image_u8_free(struct clip_image_u8 * img) { delete img; } +void clip_image_f32_free(struct clip_image_f32 * img) { delete img; } +void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { + if (batch->size > 0) { + delete[] batch->data; + batch->size = 0; + } +} +void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { + if (batch->size > 0) { + delete[] batch->data; + batch->size = 0; + } +} + +static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) { + img->nx = nx; + img->ny = ny; + img->buf.resize(3 * nx * ny); + memcpy(img->buf.data(), data, img->buf.size()); +} + +bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { + int nx, ny, nc; + auto * data = stbi_load(fname, &nx, &ny, &nc, 3); + if (!data) { + printf("%s: failed to load image '%s'\n", __func__, fname); + return false; + } + build_clip_img_from_data(data, nx, ny, img); + stbi_image_free(data); + return true; +} + +bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) { + int nx, ny, nc; + auto * data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); + if (!data) { + printf("%s: failed to decode image bytes\n", __func__); + return false; + } + build_clip_img_from_data(data, nx, ny, img); + stbi_image_free(data); 
+ return true; +} + +// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not +static void normalize_image_u8_to_f32(const clip_image_u8* src, clip_image_f32* dst, const float mean[3], const float std[3]) { + dst->nx = src->nx; + dst->ny = src->ny; + dst->buf.resize(src->buf.size()); + + for (size_t i = 0; i < src->buf.size(); ++i) { + int c = i % 3; // rgb + dst->buf[i] = (static_cast(src->buf[i]) / 255.0f - mean[c]) / std[c]; + } +} + +inline int clip(int x, int lower, int upper) { + return std::max(lower, std::min(x, upper)); +} + +static bool bicubic_resize(const clip_image_u8 &img, clip_image_u8 &dst, int target_width, int target_height) { + const int nx = img.nx; + const int ny = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); + + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; + + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; + + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = (int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + } + } + } + } + + return true; +} + +static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) { + auto file = fopen(path, "rb"); + if (file == NULL) { + printf("%s: can't read file %s\n", __func__, path); + return false; + } + + fseek(file, 0, SEEK_END); + auto fileSize = ftell(file); + fseek(file, 0, SEEK_SET); + + auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data + if (buffer == NULL) { + printf("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); + perror("Memory allocation error"); + fclose(file); + return false; + } + errno = 0; + size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer + if (ferror(file)) { + printf("read error: %s\n", strerror(errno)); + return false; + } + if (ret != (size_t) 
fileSize) { + printf("unexpectedly reached end of file\n"); + return false; + } + fclose(file); // Close the file + + *bytesOut = buffer; + *sizeOut = fileSize; + return true; +} + +// llava-1.6 type of resize_and_pad (black) +static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair& target_resolution) { + int target_width = target_resolution.first; + int target_height = target_resolution.second; + + float scale_w = static_cast(target_width) / image.nx; + float scale_h = static_cast(target_height) / image.ny; + + int new_width, new_height; + + if (scale_w < scale_h) { + new_width = target_width; + new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); + } else { + new_height = target_height; + new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); + } + + clip_image_u8 resized_image; + // bilinear_resize(image, resized_image, new_width, new_height); + bicubic_resize(image, resized_image, new_width, new_height); + + clip_image_u8 padded_image; + padded_image.nx = target_width; + padded_image.ny = target_height; + padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black + + // Calculate padding offsets + int pad_x = (target_width - new_width) / 2; + int pad_y = (target_height - new_height) / 2; + + // Copy the resized image into the center of the padded buffer + for (int y = 0; y < new_height; ++y) { + for (int x = 0; x < new_width; ++x) { + for (int c = 0; c < 3; ++c) { + padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + } + } + } + image_output = std::move(padded_image); +} + +// The above is copied from llava. +// The following is added for resizing an image +// without padding the image +// Assumes that the output size is a square image + +bool load_and_stretch_image(const char* path, int output_size, + std::vector &output_data, + const float mean[3], const float std[3]) { + unsigned char * image_bytes; // allocation done by the load_file_to_bytes function + long image_size; + if (!load_file_to_bytes(path, &image_bytes, &image_size)) { + printf("Failed to load the specified image file.\n"); + return false; + } + + clip_image_u8 * clip_img = clip_image_u8_init(); + if (!clip_image_load_from_bytes(image_bytes, image_size, clip_img)) { + clip_image_u8_free(clip_img); + free(image_bytes); + printf("Failed to create CLIP image structure from image bytes.\n"); + return false; + } + + clip_image_u8 resized_image; + bicubic_resize(*clip_img, resized_image, output_size, output_size); + clip_image_u8_free(clip_img); + free(image_bytes); + clip_image_f32 float_image; + + normalize_image_u8_to_f32(&resized_image, &float_image, mean, std); + + output_data.resize(3 * output_size * output_size); + for (int c=0; c<3; c++) { + for (int y=0; y +#include +// Copied from the LLAVA example + +struct clip_image_size { + int width; + int height; +}; + +struct clip_image_u8_batch { + struct clip_image_u8 * data; + size_t size; +}; + +struct clip_image_f32_batch { + struct clip_image_f32 * data; + size_t size; +}; + +struct clip_image_size * clip_image_size_init(); +struct clip_image_u8 * clip_image_u8_init (); +struct clip_image_f32 * clip_image_f32_init(); + +void clip_image_u8_free (struct clip_image_u8 * img); +void clip_image_f32_free(struct clip_image_f32 * img); +void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); +void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); + +bool 
clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); + +/** interpret bytes as an image file with length bytes_length, and use the result to populate img */ +bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); +bool load_and_stretch_image(const char* path, int output_size, std::vector &output_data, + const float mean[3], const float std[3]); \ No newline at end of file diff --git a/examples/cogagent/vision_encoder.cpp b/examples/cogagent/vision_encoder.cpp new file mode 100644 index 0000000000000..faffd35dc1d46 --- /dev/null +++ b/examples/cogagent/vision_encoder.cpp @@ -0,0 +1,300 @@ +#include "cogagent.h" +#include "vision_encoder.h" + +#include +#include +#include +#include +#include +#include + +bool vision_encoder_init_load(const char * filename) { + vision_encoder_ctx &model_ctx = cogagent_global.vision_encoder; + vision_encoder &model = cogagent_global.vision_encoder.model; + + model_ctx.backend = ggml_backend_cpu_init(); + ggml_backend_cpu_set_n_threads(model_ctx.backend, 16); + + // Initialize the GGML contexts + struct ggml_init_params weight_params { + 10000000, // Memory size + NULL, // Memory buffer + true, // Don't allocate tensor data + }; + + struct ggml_init_params compute_params { + 6400*ggml_tensor_overhead() + ggml_graph_overhead(), + NULL, + true, + }; + + model_ctx.ctx_weight = ggml_init(weight_params); + model_ctx.ctx_compute = ggml_init(compute_params); + + // Load the model weights + struct ggml_context * meta; + struct gguf_init_params gguf_params = { + true, &meta, + }; + + struct gguf_context * gguf_ctx = gguf_init_from_file(filename, gguf_params); + if (!gguf_ctx) { + printf("Failed to initialize GGUF context. Check filename.\n"); + return false; + } + + // Calculate the ctx size based on tensors in the GGUF file + size_t ctx_size = 0; + int num_tensors = gguf_get_n_tensors(gguf_ctx); + printf("There are %d tensors in the GGUF file\n", num_tensors); + for (int i=0; i 0) { + printf("%d tensors could not be found in the model context. 
Model loading failed.\n", failed_count); + return false; + } + + model_ctx.weight_data = ggml_backend_alloc_ctx_tensors(model_ctx.ctx_weight, model_ctx.backend); + + if (!load_from_gguf(filename, model_ctx.ctx_weight, gguf_ctx)) { + printf("Loading data from GGUF file failed\n"); + return false; + } + + ggml_free(meta); + + return true; +} + +// This is not declared in the header because it is only intended +// to be called from run_vision_encoder +static struct ggml_cgraph * vision_encoder_graph() { + struct ggml_context * ctx = cogagent_global.vision_encoder.ctx_compute; + vision_encoder &model = cogagent_global.vision_encoder.model; + + // Set flag to tell allocator to not overwrite the input tensor + // before it is done with computation + ggml_set_input(model.input_image); + + // Assuming input: h, w, 3, b + // This is different from PyTorch, which is b, 3, h, w + // Confirm later that the height and width are not swapped + // It would appear that the number of out channels in the ggml_conv_2d function + // is implied from the kernel + // 3 -> 1792 + struct ggml_tensor * patch_embedding = ggml_conv_2d(ctx, model.patch_conv_w, model.input_image, + 14, 14, 0, 0, 1, 1); // Don't actually know dilation does :) Matches the torch defaults + // When adding together float16 and float32, float16 has to be first + patch_embedding = ggml_add(ctx, ggml_repeat(ctx, model.patch_conv_b, patch_embedding), patch_embedding); + // after conv: w, h, d, b + // after flatten: w x h = l, d, b + // after transpose: d, l, b + // cls token shape: d, 1, b + // after concatenation: d, l+1, b + patch_embedding = ggml_reshape_3d(ctx, patch_embedding, patch_embedding->ne[0] * patch_embedding->ne[1], patch_embedding->ne[2], patch_embedding->ne[3]); // Flatten + patch_embedding = ggml_cont(ctx, patch_embedding); + // d, l, b shape at this point + // Most layer weights will need reshaping + patch_embedding = ggml_transpose(ctx, patch_embedding); + patch_embedding = ggml_cont(ctx, patch_embedding); + // Assume cls_embed and position_embed are expanded to have 1 more dimension + // than original. 
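// Note on dimension order (reasoning aid only): ggml lists dimensions
// innermost-first, so a tensor described here as (d, l, b) has
//   ne[0] = d (contiguous), ne[1] = l, ne[2] = b
// and corresponds to a PyTorch tensor of shape (b, l, d). For example, a
// hypothetical hidden state of 1792 features over 257 positions for one image
// would be created as:
//   ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1792, 257, 1);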
Assume that these operations can broadcast automatically + struct ggml_tensor * cls_embed_shape = ggml_new_tensor_3d(ctx, model.cls_embed->type, + patch_embedding->ne[0], 1, patch_embedding->ne[2]); + patch_embedding = ggml_concat(ctx, ggml_repeat(ctx, model.cls_embed, cls_embed_shape), patch_embedding, 1); + patch_embedding = ggml_cont(ctx, patch_embedding); + // num_positions in the config is 257, which is 256 + 1 + // 224 x 224 becomes 16 x 16 after convolution with a kernel of 14 x 14 + // 16 x 16 + 1 = 257 + patch_embedding = ggml_add(ctx, ggml_repeat(ctx, model.position_embed_1, patch_embedding), patch_embedding); + + // Original PatchEmbedding complete at this point + // Loop through the transformer layers + struct ggml_tensor * layer_in = patch_embedding; + for (int i=0; ine[0] / 3, qkv->ne[1], qkv->ne[2], + qkv->nb[1], qkv->nb[2], 0); + struct ggml_tensor * kt = ggml_view_3d(ctx, qkv, qkv->ne[0] / 3, qkv->ne[1], qkv->ne[2], + qkv->nb[1], qkv->nb[2], qkv->ne[0] / 3 * qkv->nb[0]); + struct ggml_tensor * vt = ggml_view_3d(ctx, qkv, qkv->ne[0] / 3, qkv->ne[1], qkv->ne[2], + qkv->nb[1], qkv->nb[2], 2 * qkv->ne[0] / 3 * qkv->nb[0]); + qt = ggml_cont(ctx, qt); + kt = ggml_cont(ctx, kt); + vt = ggml_cont(ctx, vt); + qt = ggml_scale(ctx, qt, model.attn_scale); + // Separate into heads + // K x H x L x B + qt = ggml_view_4d(ctx, qt, qt->ne[0] / model.num_heads, model.num_heads, qt->ne[1], qt->ne[2], + qt->ne[0] / model.num_heads * qt->nb[0], qt->nb[1], qt->nb[2], 0); + kt = ggml_view_4d(ctx, kt, kt->ne[0] / model.num_heads, model.num_heads, kt->ne[1], kt->ne[2], + kt->ne[0] / model.num_heads * kt->nb[0], kt->nb[1], kt->nb[2], 0); + vt = ggml_view_4d(ctx, vt, vt->ne[0] / model.num_heads, model.num_heads, vt->ne[1], vt->ne[2], + vt->ne[0] / model.num_heads * vt->nb[0], vt->nb[1], vt->nb[2], 0); + qt = ggml_cont(ctx, qt); + kt = ggml_cont(ctx, kt); + vt = ggml_cont(ctx, vt); + // Switch order of dimensions + // K x L x H x B + qt = ggml_permute(ctx, qt, 0, 2, 1, 3); + kt = ggml_permute(ctx, kt, 0, 2, 1, 3); + qt = ggml_cont(ctx, qt); + kt = ggml_cont(ctx, kt); + // L x K x H x B + struct ggml_tensor * v = ggml_permute(ctx, vt, 1, 2, 0, 3); + v = ggml_cont(ctx, v); + // L x L x H x B + struct ggml_tensor * attnt = ggml_mul_mat(ctx, kt, qt); + attnt = ggml_soft_max(ctx, attnt); // Should be on the first dimension + // attnt = vt x attnt, but we need to give it v instead of vt + // K x L x H x B + attnt = ggml_mul_mat(ctx, v, attnt); + // Switch the dimensions back + attnt = ggml_permute(ctx, attnt, 0, 2, 1, 3); // K, H, L, B + attnt = ggml_cont(ctx, attnt); + attnt = ggml_view_3d(ctx, attnt, attnt->ne[0] * attnt->ne[1], attnt->ne[2], attnt->ne[3], + attnt->nb[2], attnt->nb[3], 0); // D, L, B + attnt = ggml_mul_mat(ctx, cur_layer.attn_dense_w, attnt); + attnt = ggml_add(ctx, ggml_repeat(ctx, cur_layer.attn_dense_b, attnt), attnt); + // Attention calculation is now complete + attnt = ggml_norm(ctx, attnt, model.layernorm_eps); // Config value is 0.000001 + attnt = ggml_mul(ctx, ggml_repeat(ctx, cur_layer.input_norm_w, attnt), attnt); + attnt = ggml_add(ctx, ggml_repeat(ctx, cur_layer.input_norm_b, attnt), attnt); + layer_in = ggml_add(ctx, layer_in, attnt); // D, L, B + // Perform MLP calculation + // Weight has shape intermediate size x D + struct ggml_tensor * fc1 = ggml_mul_mat(ctx, cur_layer.fc1_w, layer_in); + fc1 = ggml_add(ctx, ggml_repeat(ctx, cur_layer.fc1_b, fc1), fc1); + struct ggml_tensor * gelu = ggml_gelu(ctx, fc1); + struct ggml_tensor * fc2 = ggml_mul_mat(ctx, cur_layer.fc2_w, 
gelu); + fc2 = ggml_add(ctx, ggml_repeat(ctx, cur_layer.fc2_b, fc2), fc2); + fc2 = ggml_norm(ctx, fc2, model.layernorm_eps); + fc2 = ggml_mul(ctx, ggml_repeat(ctx, cur_layer.post_attention_norm_w, fc2), fc2); + fc2 = ggml_add(ctx, ggml_repeat(ctx, cur_layer.post_attention_norm_b, fc2), fc2); + layer_in = ggml_add(ctx, layer_in, fc2); + } + struct ggml_tensor * transformer_output = layer_in; + // Drop class embedding + { + long d = transformer_output->ne[0]; + long l = transformer_output->ne[1]; + long b = transformer_output->ne[2]; + transformer_output = ggml_view_3d(ctx, transformer_output, d, l-1, b, + transformer_output->nb[1], transformer_output->nb[2], transformer_output->nb[1]); + } + // Linear projection + struct ggml_tensor * linear_proj = ggml_add(ctx, ggml_repeat(ctx, model.position_embed_2, transformer_output), transformer_output); + struct ggml_tensor * linear_proj_tmp = ggml_mul_mat(ctx, model.linear_proj_w, linear_proj); + linear_proj = ggml_norm(ctx, linear_proj_tmp, model.layernorm_eps); + linear_proj = ggml_mul(ctx, ggml_repeat(ctx, model.linear_proj_norm_w, linear_proj), linear_proj); + linear_proj = ggml_add(ctx, ggml_repeat(ctx, model.linear_proj_norm_b, linear_proj), linear_proj); + linear_proj = ggml_gelu(ctx, linear_proj); + struct ggml_tensor * gate_proj = ggml_mul_mat(ctx, model.gate_proj_w, linear_proj); + gate_proj = ggml_silu(ctx, gate_proj); + struct ggml_tensor * h_4h = ggml_mul_mat(ctx, model.dense_h_to_4h_w, linear_proj); + linear_proj = ggml_mul(ctx, gate_proj, h_4h); + linear_proj = ggml_mul_mat(ctx, model.dense_4h_to_h_w, linear_proj); + // GLU complete + struct ggml_tensor * expanded_size = ggml_new_tensor_3d(ctx, linear_proj->type, linear_proj->ne[0], 1, linear_proj->ne[2]); + model.output_tensor = ggml_concat(ctx, ggml_repeat(ctx, model.boi, expanded_size), linear_proj, 1); + model.output_tensor = ggml_concat(ctx, model.output_tensor, ggml_repeat(ctx, model.eoi, expanded_size), 1); + + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, 4096, false); + ggml_build_forward_expand(gf, model.output_tensor); + return gf; +} + +void run_vision_encoder(std::vector img_data) { + vision_encoder_ctx &model_ctx = cogagent_global.vision_encoder; + vision_encoder &model = cogagent_global.vision_encoder.model; + + // Declare the input image tensor + model.input_image = ggml_new_tensor_3d(cogagent_global.vision_encoder.ctx_compute, GGML_TYPE_F32, + cogagent_global.vision_encoder_img_size, cogagent_global.vision_encoder_img_size, 3); + + struct ggml_cgraph * gf = vision_encoder_graph(); + ggml_graph_print(gf); + printf("Number of nodes in the vision encoder graph is %d\n", ggml_graph_n_nodes(gf)); + + model_ctx.allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model_ctx.backend)); + ggml_gallocr_reserve(model_ctx.allocr, gf); + size_t compute_size = ggml_gallocr_get_buffer_size(model_ctx.allocr, 0); + printf("Allocated %ld bytes of space for graph computation.\n", compute_size); + ggml_gallocr_alloc_graph(model_ctx.allocr, gf); + + ggml_backend_tensor_set(model.input_image, img_data.data(), 0, ggml_nbytes(model.input_image)); + + // Computation result is at model.output_tensor + ggml_backend_graph_compute(model_ctx.backend, gf); + + cogagent_global.vision_encoder_image.resize(model.output_tensor->ne[0] * + model.output_tensor->ne[1]); + ggml_backend_tensor_get(model.output_tensor, cogagent_global.vision_encoder_image.data(), + 0, ggml_nbytes(model.output_tensor)); + // Added for debugging implementation of encoders + 
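// Recap of the graph lifecycle used above (summary only, matching the
// single-graph, single-run pattern of this example):
//   1. build the graph                 -> vision_encoder_graph()
//   2. reserve and allocate buffers    -> ggml_gallocr_reserve / ggml_gallocr_alloc_graph
//   3. upload inputs                   -> ggml_backend_tensor_set(model.input_image, ...)
//   4. run                             -> ggml_backend_graph_compute(model_ctx.backend, gf)
//   5. read results back to the host   -> ggml_backend_tensor_get(model.output_tensor, ...)
// The call below additionally dumps the encoder output to disk for debugging: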
save_tensor_filename(model.output_tensor, "cogagent_encoders/vision_encoder_output.gguf"); +} + +void free_vision_encoder_ctx() { + vision_encoder_ctx &model_ctx = cogagent_global.vision_encoder; + + ggml_gallocr_free(model_ctx.allocr); + ggml_backend_buffer_free(model_ctx.weight_data); + ggml_free(model_ctx.ctx_weight); + ggml_free(model_ctx.ctx_compute); +} \ No newline at end of file diff --git a/examples/cogagent/vision_encoder.h b/examples/cogagent/vision_encoder.h new file mode 100644 index 0000000000000..d507b70c25aee --- /dev/null +++ b/examples/cogagent/vision_encoder.h @@ -0,0 +1,85 @@ +#ifndef VISION_ENCODER_H +#define VISION_ENCODER_H + +#include "ggml-backend.h" +#include +#include +#include +#include +#include +#include + +#include "cogagent.h" + +struct vision_encoder_layer { + struct ggml_tensor * qkv_w; + struct ggml_tensor * qkv_b; + + struct ggml_tensor * attn_dense_w; + struct ggml_tensor * attn_dense_b; + + struct ggml_tensor * input_norm_w; + struct ggml_tensor * input_norm_b; + + struct ggml_tensor * fc1_w; + struct ggml_tensor * fc1_b; + + struct ggml_tensor * fc2_w; + struct ggml_tensor * fc2_b; + + struct ggml_tensor * post_attention_norm_w; + struct ggml_tensor * post_attention_norm_b; +}; + +struct vision_encoder { + struct ggml_tensor * cls_embed; + struct ggml_tensor * patch_conv_w; + struct ggml_tensor * patch_conv_b; + struct ggml_tensor * position_embed_1; + + std::vector transformer_layers; + + struct ggml_tensor * position_embed_2; + + struct ggml_tensor * linear_proj_w; + + struct ggml_tensor * linear_proj_norm_w; + struct ggml_tensor * linear_proj_norm_b; + + struct ggml_tensor * gate_proj_w; + struct ggml_tensor * dense_h_to_4h_w; + struct ggml_tensor * dense_4h_to_h_w; + + struct ggml_tensor * boi; + struct ggml_tensor * eoi; + + int hidden_size = 1792; + int num_heads = 16; + int head_hidden_size = hidden_size / num_heads; + int num_layers = 63; + float layernorm_eps = 0.000001; + float attn_scale = 1.0 / std::sqrt(head_hidden_size); + struct ggml_tensor * input_image; + struct ggml_tensor * output_tensor; +}; + +struct vision_encoder_ctx { + struct ggml_context * ctx_weight; + struct ggml_context * ctx_compute; + ggml_backend_buffer_t weight_data; + ggml_backend_t backend; + ggml_gallocr_t allocr; + vision_encoder model; +}; + +// Assume that overall context is accessible as a static variable +bool vision_encoder_init_load(const char * filename); + +// Defines a graph and runs the vision encoder +// Assumes that the picture is of the correct size +void run_vision_encoder(std::vector img_data); + +// Free the weights stored for the vision encoder +void free_vision_encoder_ctx(); + +#endif \ No newline at end of file diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e0acc47059656..aa0681494f738 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3150,6 +3150,7 @@ struct server_context { batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, + nullptr, }; const int ret = llama_decode(ctx, batch_view); diff --git a/examples/vision/CMakeLists.txt b/examples/vision/CMakeLists.txt new file mode 100644 index 0000000000000..ab009157a957f --- /dev/null +++ b/examples/vision/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-vision) +add_executable(${TARGET} vision.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/vision/README.md b/examples/vision/README.md new 
file mode 100644 index 0000000000000..c2468444caa89 --- /dev/null +++ b/examples/vision/README.md @@ -0,0 +1,3 @@ +# llama.cpp/example/simple-vision + +Minimal demo for vision API diff --git a/examples/vision/vision.cpp b/examples/vision/vision.cpp new file mode 100644 index 0000000000000..d97067bba616f --- /dev/null +++ b/examples/vision/vision.cpp @@ -0,0 +1,216 @@ +#include "llama.h" +#include "common.h" +#include "arg.h" +#include "log.h" +#include "sampling.h" +#include +#include +#include +#include +#include + +#define STB_IMAGE_IMPLEMENTATION +#include "stb_image.h" + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [--image img_path] [-p prompt]\n", argv[0]); + printf("\n"); +} + +static llama_vision_bitmap * load_image_from_file(const char * fname) { + std::ifstream file(fname, std::ios::binary); + if (!file) { + throw std::runtime_error("Unable to open file"); + } + std::vector image_bytes = std::vector( + std::istreambuf_iterator(file), + std::istreambuf_iterator()); + // decode image to byte array + int nx, ny, nc; + auto * bytes = (unsigned char *) image_bytes.data(); + auto * img = stbi_load_from_memory(bytes, image_bytes.size(), &nx, &ny, &nc, 3); + if (!img) { + throw std::runtime_error("failed to decode image bytes"); + } + // printf("nx=%d ny=%d nc=%d\n", nx, ny, nc); + // GGML_ASSERT(nc == 3); + // for (int y = 0; y < ny; y++) { + // for (int x = 0; x < nx; x++) { + // unsigned char * pix = img + x*nc + y*nc*nx; + // printf("%02x%02x%02x ", pix[0], pix[1], pix[2]); + // } + // printf("\n"); + // } + // printf("\n"); + llama_vision_bitmap * result = llama_vision_bitmap_init(nx, ny); + memcpy(result->data, img, nx*ny*3); + stbi_image_free(img); + return result; +} + +// split string by a `std::string delim` instead of `char delim` +static std::vector string_split_str(std::string s, const std::string & delimiter) { + std::vector tokens; + size_t pos = 0; + std::string token; + while ((pos = s.find(delimiter)) != std::string::npos) { + token = s.substr(0, pos); + tokens.push_back(token); + s.erase(0, pos + delimiter.length()); + } + tokens.push_back(s); + return tokens; +} + +struct tokenized_part { + llama_tokens tokens; + bool is_image; +}; + +// TODO: this function is hacky, need to be improved +// static const llama_token TOKEN_IMG_PLACEMENT = -1000; +static const std::string IMG_PLACEMENT = ""; +static std::vector tokenize_with_img_placement( + const llama_vocab * vocab, + const std::string & text, + bool add_special, + bool parse_special) { + std::vector parts = string_split_str(text, IMG_PLACEMENT); + std::vector output; + for (const auto & part : parts) { + //printf("tokenizing part: %s\n", part.c_str()); + bool add_bos = &parts.front() == ∂ + auto tokens = common_tokenize(vocab, part, add_special && add_bos, parse_special); + if (tokens.empty()) { + continue; + } + output.push_back({std::move(tokens), false}); + if (&parts.back() != &part) { + // add image token to middle of 2 parts + output.push_back({{}, true}); + } + } + return output; +} + +int main(int argc, char ** argv) { + common_params params; + + // default prompt for llava 1.5 + //params.prompt = "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:\nwhat did you see?\nASSISTANT:"; + // default prompt for minicpmv 2.6 + params.prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nwhat do you see?<|im_end|>\n<|im_start|>assistant\n"; + params.n_predict = 64; + params.n_batch = 2048; + params.n_ubatch = 1024; + params.n_gpu_layers = 99; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_VISION, print_usage)) { + return 1; + } + + common_init(); + common_init_result llama_init = common_init_from_params(params); + llama_context * ctx = llama_init.context.get(); + const llama_model * model = llama_init.model.get(); + const llama_vocab * vocab = llama_model_get_vocab(model); + if (!model) { + LOG_ERR("failed to load model\n"); + return 1; + } + + struct common_sampler * smpl = common_sampler_init(model, params.sampling); + + llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + int n_past = 0; + int n_prompt = 0; + + // process image + llama_vision_tokens * img_tokens = nullptr; + { + const char * img_path = params.image[0].c_str(); + if (params.image[0].empty()) { + LOG_ERR("no image path provided\n"); + return 1; + } + llama_vision_bitmap * img = load_image_from_file(img_path); + LOG_INF("loaded image %s, size = %d x %d\n", img_path, img->nx, img->ny); + img_tokens = llama_vision_tokenize(ctx, img); + if (!img_tokens) { + LOG_ERR("failed to create image tokens\n"); + return 1; + } + if (llama_vision_encode(ctx, img_tokens)) { + LOG_ERR("failed to encode image\n"); + return 1; + } + LOG_INF("encoded image\n"); + } + + // process prompt + { + std::vector parts = tokenize_with_img_placement(vocab, params.prompt, true, true); + for (const tokenized_part & part : parts) { + if (!part.is_image) { + for (const llama_token & token : part.tokens) { + //LOG_INF("%d -> %s\n", token, common_token_to_piece(ctx, token).c_str()); + common_batch_add(batch, token, n_past++, {0}, &part == &parts.back()); + } + LOG_INF("eval text batch (%d tokens)\n", batch.n_tokens); + if (llama_decode(ctx, batch)) { + LOG_ERR("failed to decode text prompt\n"); + return 1; + } + } else { + auto * img_embd = llama_vision_get_output_tensor(ctx); + // std::vector output_debug(ggml_nelements(img_embd)); + // ggml_backend_tensor_get(img_embd, output_debug.data(), 0, ggml_nbytes(img_embd)); + // for (int row = 0; row < 10; row++) { + // int off = row * img_embd->ne[0]; + // printf("... 
%f %f %f\n", output_debug[off], output_debug[off+1], output_debug[off+2]); + // } + // exit(1); + llama_batch batch_img = llama_batch_get_one_from_tensor(img_embd, n_past, 0); + n_past += batch_img.n_tokens; + LOG_INF("eval image batch (%d embeddings)\n", batch_img.n_tokens); + if (llama_decode(ctx, batch_img)) { + LOG_ERR("failed to decode image prompt\n"); + return 1; + } + llama_batch_free(batch_img); + } + } + n_prompt = n_past; + LOG_INF("prompt processed, %d tokens\n", n_prompt); + } + + // generate response + while (true){ + int n_generated = n_past - n_prompt; + if (n_generated > params.n_predict) { + printf("\n"); + break; + } + + llama_token token_id = common_sampler_sample(smpl, ctx, -1); + common_sampler_accept(smpl, token_id, true); + printf("%s", common_token_to_piece(ctx, token_id).c_str()); + fflush(stdout); + + if (llama_vocab_is_eog(vocab, token_id)) { + printf("\n"); + break; + } + + // eval the token + common_batch_clear(batch); + common_batch_add(batch, token_id, n_past++, {0}, true); + if (llama_decode(ctx, batch)) { + LOG_ERR("failed to decode token\n"); + break; + } + } + + return 0; +} diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ecac5b4bb7f59..7fb495d980902 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -202,6 +202,9 @@ class Tokenizer: FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + # Vision models + IMAGE_START_ID = "tokenizer.ggml.image_start_token_id" + IMAGE_END_ID = "tokenizer.ggml.image_end_token_id" # deprecated: PREFIX_ID = "tokenizer.ggml.prefix_token_id" SUFFIX_ID = "tokenizer.ggml.suffix_token_id" @@ -211,6 +214,32 @@ class Adapter: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class Vision: + # only support vision.type = "vit" for now + TYPE = "vision.type" + IMAGE_SIZE = "vision.image_size" + PATCH_SIZE = "vision.patch_size" + IMAGE_MEAN = "vision.image_mean" + IMAGE_STD = "vision.image_std" + + class Vit: + ARCHITECTURE = "vision.vit.architecture" + CONTEXT_LENGTH = "vision.vit.context_length" + EMBEDDING_LENGTH = "vision.vit.embedding_length" + BLOCK_COUNT = "vision.vit.block_count" + FEED_FORWARD_LENGTH = "vision.vit.feed_forward_length" + PROJECTION_TYPE = "vision.vit.projection_type" + PROJECTION_DIM = "vision.vit.projection_dim" + USE_GELU = "vision.vit.use_gelu" + MAX_POS_EMBEDDING = "vision.vit.max_position_embeddings" + MAX_SLICES = "vision.vit.max_slices" + PROJECTOR_TYPE = "vision.vit.projector_type" + SELECT_LAYER = "vision.vit.select_layer" + PATCH_MERGE_TYPE = "vision.vit.patch_merge_type" + HEAD_COUNT = "vision.vit.attention.head_count" + LAYERNORM_EPS = "vision.vit.attention.layer_norm_epsilon" + SCALE_FACTOR = "vision.vit.scale_factor" # only used by idefics3 for now + # # recommended mapping of model tensor names for storage in gguf # @@ -279,6 +308,12 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + COGVLM = auto() + # vision models + VISION_LLAVA = auto() + VISION_MOBILEVLM = auto() + VISION_MINICPMV = auto() + VISION_IDEFICS3 = auto() class MODEL_TENSOR(IntEnum): @@ -390,6 +425,7 @@ class MODEL_TENSOR(IntEnum): ENC_OUTPUT_NORM = auto() CLS = auto() # classifier CLS_OUT = auto() # classifier output projection + # wavtokenizer CONV1D = auto() CONVNEXT_DW = auto() CONVNEXT_NORM = auto() @@ -406,6 +442,52 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + ATTN_TXT_QKV = 
auto() + ATTN_IMG_QKV = auto() + ATTN_TXT_DENSE = auto() + ATTN_IMG_DENSE = auto() + CROSS_ATTN_Q = auto() + CROSS_ATTN_KV = auto() + CROSS_ATTN_DENSE = auto() + FFN_TXT_UP = auto() + FFN_TXT_GATE = auto() + FFN_TXT_DOWN = auto() + FFN_IMG_UP = auto() + FFN_IMG_GATE = auto() + FFN_IMG_DOWN = auto() + # vision + V_MMPROJ = auto() + V_MMPROJ_FC = auto() + V_MMPROJ_MLP = auto() + V_MMPROJ_PEG = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_V = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_OUTPUT = auto() + V_ENC_OUTPUT_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_DOWN = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() + V_RESMPL_POS_EMBD_K = auto() # minicpmv + V_RESMPL_ATTN_Q = auto() # minicpmv + V_RESMPL_ATTN_K = auto() # minicpmv + V_RESMPL_ATTN_V = auto() # minicpmv + V_RESMPL_ATTN_OUT = auto() # minicpmv + V_RESMPL_KV = auto() # minicpmv + V_RESMPL_KV_NORM = auto() # minicpmv + V_RESMPL_POST_NORM = auto() # minicpmv + V_RESMPL_Q_NORM = auto() # minicpmv + V_RESMPL_PROJ = auto() # minicpmv + V_RESMPL_QUERY = auto() # minicpmv + V_TOK_EMBD_IMAGE = auto() # embedding for token + V_TOK_EMBD_END_IMAGE = auto() # embedding for token + V_TOK_EMBD_SLICE = auto() # embedding for token + V_TOK_EMBD_END_SLICE = auto() # embedding for token MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -465,7 +547,13 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.COGVLM: "cogvlm", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + # vision + MODEL_ARCH.VISION_LLAVA: "llava", + MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm", + MODEL_ARCH.VISION_MINICPMV: "minicpmv", + MODEL_ARCH.VISION_IDEFICS3: "idefics3", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -593,6 +681,52 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + MODEL_TENSOR.ATTN_TXT_QKV: "blk.{bid}.attn_txt_qkv", + MODEL_TENSOR.ATTN_IMG_QKV: "blk.{bid}.attn_img_qkv", + MODEL_TENSOR.ATTN_TXT_DENSE: "blk.{bid}.attn_txt_dense", + MODEL_TENSOR.ATTN_IMG_DENSE: "blk.{bid}.attn_img_dense", + MODEL_TENSOR.CROSS_ATTN_Q: "blk.{bid}.cross_attn_q", + MODEL_TENSOR.CROSS_ATTN_KV: "blk.{bid}.cross_attn_kv", + MODEL_TENSOR.CROSS_ATTN_DENSE: "blk.{bid}.cross_attn_dense", + MODEL_TENSOR.FFN_TXT_UP: "blk.{bid}.ffn_txt_up", + MODEL_TENSOR.FFN_TXT_GATE: "blk.{bid}.ffn_txt_gate", + MODEL_TENSOR.FFN_TXT_DOWN: "blk.{bid}.ffn_txt_down", + MODEL_TENSOR.FFN_IMG_UP: "blk.{bid}.ffn_img_up", + MODEL_TENSOR.FFN_IMG_GATE: "blk.{bid}.ffn_img_gate", + MODEL_TENSOR.FFN_IMG_DOWN: "blk.{bid}.ffn_img_down", + # vision + MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}", + MODEL_TENSOR.V_MMPROJ_FC: "v.mmproj.fc", + MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}", + MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.enc.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_K: "v.enc.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_V: "v.enc.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.enc.blk.{bid}.input_norm", + MODEL_TENSOR.V_ENC_OUTPUT: "v.enc.blk.{bid}.output", + MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm", + MODEL_TENSOR.V_ENC_FFN_UP: "v.enc.blk.{bid}.ffn_up", + 
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down", + MODEL_TENSOR.V_PRE_NORM: "v.pre_norm", + MODEL_TENSOR.V_POST_NORM: "v.post_norm", + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "v.resmpl.pos_embd_k", + MODEL_TENSOR.V_RESMPL_ATTN_Q: "v.resmpl.attn_q", + MODEL_TENSOR.V_RESMPL_ATTN_K: "v.resmpl.attn_k", + MODEL_TENSOR.V_RESMPL_ATTN_V: "v.resmpl.attn_v", + MODEL_TENSOR.V_RESMPL_ATTN_OUT: "v.resmpl.attn_out", + MODEL_TENSOR.V_RESMPL_KV: "v.resmpl.kv", + MODEL_TENSOR.V_RESMPL_KV_NORM: "v.resmpl.kv_norm", + MODEL_TENSOR.V_RESMPL_POST_NORM: "v.resmpl.post_norm", + MODEL_TENSOR.V_RESMPL_Q_NORM: "v.resmpl.q_norm", + MODEL_TENSOR.V_RESMPL_PROJ: "v.resmpl.proj", + MODEL_TENSOR.V_RESMPL_QUERY: "v.resmpl.query", + MODEL_TENSOR.V_TOK_EMBD_IMAGE: "v.tok_embd.image", + MODEL_TENSOR.V_TOK_EMBD_END_IMAGE: "v.tok_embd.end_image", + MODEL_TENSOR.V_TOK_EMBD_SLICE: "v.tok_embd.slice", + MODEL_TENSOR.V_TOK_EMBD_END_SLICE: "v.tok_embd.end_slice", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1515,6 +1649,27 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.COGVLM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_TXT_QKV, + MODEL_TENSOR.ATTN_IMG_QKV, + MODEL_TENSOR.ATTN_TXT_DENSE, + MODEL_TENSOR.ATTN_IMG_DENSE, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.CROSS_ATTN_Q, + MODEL_TENSOR.CROSS_ATTN_KV, + MODEL_TENSOR.CROSS_ATTN_DENSE, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_TXT_UP, + MODEL_TENSOR.FFN_TXT_GATE, + MODEL_TENSOR.FFN_TXT_DOWN, + MODEL_TENSOR.FFN_IMG_UP, + MODEL_TENSOR.FFN_IMG_GATE, + MODEL_TENSOR.FFN_IMG_DOWN, + ], MODEL_ARCH.WAVTOKENIZER_DEC: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, @@ -1537,6 +1692,80 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_V, MODEL_TENSOR.POSNET_ATTN_OUT, ], + MODEL_ARCH.VISION_LLAVA: [ + MODEL_TENSOR.V_MMPROJ, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], + MODEL_ARCH.VISION_MOBILEVLM: [ + MODEL_TENSOR.V_MMPROJ_MLP, + MODEL_TENSOR.V_MMPROJ_PEG, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], + MODEL_ARCH.VISION_MINICPMV: [ + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_RESMPL_POS_EMBD_K, + MODEL_TENSOR.V_RESMPL_ATTN_Q, + MODEL_TENSOR.V_RESMPL_ATTN_K, + MODEL_TENSOR.V_RESMPL_ATTN_V, + MODEL_TENSOR.V_RESMPL_ATTN_OUT, + MODEL_TENSOR.V_RESMPL_KV, + MODEL_TENSOR.V_RESMPL_KV_NORM, + MODEL_TENSOR.V_RESMPL_POST_NORM, + MODEL_TENSOR.V_RESMPL_Q_NORM, + MODEL_TENSOR.V_RESMPL_PROJ, + MODEL_TENSOR.V_RESMPL_QUERY, + MODEL_TENSOR.V_TOK_EMBD_IMAGE, + 
MODEL_TENSOR.V_TOK_EMBD_END_IMAGE, + MODEL_TENSOR.V_TOK_EMBD_SLICE, + MODEL_TENSOR.V_TOK_EMBD_END_SLICE, + ], + MODEL_ARCH.VISION_IDEFICS3: [ + MODEL_TENSOR.V_MMPROJ_FC, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_POST_NORM, + ], # TODO } @@ -1618,6 +1847,18 @@ class PoolingType(IntEnum): CLS = 2 +class CLIPProjectorType(Enum): + MLP = 'mlp' + LDPV2 = 'ldpv2' + MINICPMV_2_5 = 'minicpmv-2.5' # resampler + MINICPMV_2_6 = 'minicpmv-2.6' # resampler + + +class CLIPPatchMergeType(Enum): + FLAT = 'flat' + SPATIAL_UNPAD = 'spatial_unpad' + + class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 080d2b9dce5cb..a31ab736bc20a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -27,6 +27,8 @@ PoolingType, TokenType, ExpertGatingFuncType, + CLIPPatchMergeType, + CLIPProjectorType, ) from .quants import quant_shape_from_byte_shape @@ -875,6 +877,60 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_vision_type(self, value: str) -> None: + self.add_string(Keys.Vision.TYPE, value) + + def add_vision_image_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.IMAGE_SIZE, value) + + def add_vision_patch_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.PATCH_SIZE, value) + + def add_vision_vit_architecture(self, value: str) -> None: + self.add_string(Keys.Vision.Vit.ARCHITECTURE, value) + + def add_vision_vit_context_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.CONTEXT_LENGTH, value) + + def add_vision_vit_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.EMBEDDING_LENGTH, value) + + def add_vision_vit_block_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.BLOCK_COUNT, value) + + def add_vision_vit_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.FEED_FORWARD_LENGTH, value) + + def add_vision_vit_head_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.HEAD_COUNT, value) + + def add_vision_vit_max_position_embeddings(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.MAX_POS_EMBEDDING, value) + + def add_vision_vit_projector_type(self, value: CLIPProjectorType) -> None: + self.add_string(Keys.Vision.Vit.PROJECTOR_TYPE, value.value) + + def add_vision_vit_max_slices(self, value: int) -> None: + self.add_uint32(Keys.Vision.Vit.MAX_SLICES, value) + + def add_vision_vit_select_layer(self, value: int) -> None: + self.add_int32(Keys.Vision.Vit.SELECT_LAYER, value) + + def add_vision_vit_patch_merge_type(self, value: CLIPPatchMergeType) -> None: + self.add_string(Keys.Vision.Vit.PATCH_MERGE_TYPE, value.value) + + def add_vision_vit_layer_norm_epsilon(self, value: float) -> None: + self.add_float32(Keys.Vision.Vit.LAYERNORM_EPS, value) + + def add_vision_vit_image_mean(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_MEAN, value) + + def add_vision_vit_image_std(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_STD, value) + + def add_vision_vit_scale_factor(self, value: int) -> None: + 
self.add_int32(Keys.Vision.Vit.SCALE_FACTOR, value) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 617791e240b60..8dfbe97e2dd6d 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -28,6 +28,7 @@ class TensorNameMap: "transformer.token_embeddings", # openelm "shared", # t5 "rwkv.embeddings", # rwkv + "model.embed_tokens", # cogvlm ), # Token type embeddings @@ -55,7 +56,7 @@ class TensorNameMap: # Output MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe cogvlm "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon "lm_head.linear", # phi2 @@ -68,7 +69,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT_NORM: ( "gpt_neox.final_layer_norm", # gptneox "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe nemotron cogvlm "norm", # llama-pth "transformer.norm_f", # mpt dbrx "ln_f", # refact bloom qwen gpt2 @@ -80,7 +81,6 @@ class TensorNameMap: "transformer.rms_norm", # Grok "encoder.final_layernorm", # chatglm "transformer.norm", # openelm - "model.norm", # nemotron "rwkv.ln_out", # rwkv "backbone.final_layer_norm", # wavtokenizer ), @@ -108,7 +108,7 @@ class TensorNameMap: "transformer.h.{bid}.input_layernorm", # falcon7b "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe cogvlm "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon "model.layers.{bid}.ln1", # yi @@ -127,9 +127,10 @@ class TensorNameMap: # Attention norm 2 MODEL_TENSOR.ATTN_NORM_2: ( - "transformer.h.{bid}.ln_attn", # falcon40b - "encoder.layer.{bid}.layer_norm_1", # jina-v2-code - "rwkv.blocks.{bid}.ln2", # rwkv + "transformer.h.{bid}.ln_attn", # falcon40b + "encoder.layer.{bid}.layer_norm_1", # jina-v2-code + "rwkv.blocks.{bid}.ln2", # rwkv + "model.layers.{bid}.post_cross_attention_layernorm", # cogvlm ), # Attention query-key-value @@ -242,7 +243,7 @@ class TensorNameMap: "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe cogvlm "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon "model.layers.{bid}.ln2", # yi @@ -787,6 +788,209 @@ class TensorNameMap: MODEL_TENSOR.POSNET_ATTN_OUT: ( "backbone.posnet.{bid}.proj_out", # wavtokenizer ), + + MODEL_TENSOR.ATTN_TXT_QKV: ( + "model.layers.{bid}.self_attn.language_expert_query_key_value", #cogvlm + ), + + MODEL_TENSOR.ATTN_IMG_QKV: ( + "model.layers.{bid}.self_attn.vision_expert_query_key_value", #cogvlm + ), + + MODEL_TENSOR.ATTN_TXT_DENSE: ( + "model.layers.{bid}.self_attn.language_expert_dense", #cogvlm + ), + + MODEL_TENSOR.ATTN_IMG_DENSE: ( + "model.layers.{bid}.self_attn.vision_expert_dense", 
#cogvlm + ), + + MODEL_TENSOR.CROSS_ATTN_Q: ( + "model.layers.{bid}.cross_attn.query", # cogvlm + ), + + MODEL_TENSOR.CROSS_ATTN_KV: ( + "model.layers.{bid}.cross_attn.key_value", # cogvlm + ), + + MODEL_TENSOR.CROSS_ATTN_DENSE: ( + "model.layers.{bid}.cross_attn.dense", # cogvlm + ), + + MODEL_TENSOR.FFN_TXT_UP: ( + "model.layers.{bid}.mlp.language_mlp.up_proj", # cogvlm + ), + + MODEL_TENSOR.FFN_TXT_GATE: ( + "model.layers.{bid}.mlp.language_mlp.gate_proj", # cogvlm + ), + + MODEL_TENSOR.FFN_TXT_DOWN: ( + "model.layers.{bid}.mlp.language_mlp.down_proj", # cogvlm + ), + + MODEL_TENSOR.FFN_IMG_UP: ( + "model.layers.{bid}.mlp.vision_mlp.up_proj", # cogvlm + ), + + MODEL_TENSOR.FFN_IMG_GATE: ( + "model.layers.{bid}.mlp.vision_mlp.gate_proj", # cogvlm + ), + + MODEL_TENSOR.FFN_IMG_DOWN: ( + "model.layers.{bid}.mlp.vision_mlp.down_proj", # cogvlm + ), + + ############################################################################# + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + ), + + MODEL_TENSOR.V_MMPROJ_FC: ( + "model.connector.modality_projection.proj", # SmolVLM + ), + + MODEL_TENSOR.V_MMPROJ_MLP: ( + "model.mm_projector.mlp.mlp.{bid}", + ), + + MODEL_TENSOR.V_MMPROJ_PEG: ( + "model.mm_projector.peg.peg.{bid}", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + "vpm.embeddings.patch_embedding", + "model.vision_model.embeddings.patch_embedding", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + "vpm.embeddings.position_embedding", + "model.vision_model.embeddings.position_embedding", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + "vpm.encoder.layers.{bid}.self_attn.q_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "vpm.encoder.layers.{bid}.self_attn.k_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "vpm.encoder.layers.{bid}.self_attn.v_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + "vpm.encoder.layers.{bid}.layer_norm1", + "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_OUTPUT: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + "vpm.encoder.layers.{bid}.self_attn.out_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + "vpm.encoder.layers.{bid}.layer_norm2", + "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + "vpm.encoder.layers.{bid}.mlp.fc1", + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + "vpm.encoder.layers.{bid}.mlp.fc2", + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM + ), + + MODEL_TENSOR.V_PRE_NORM: ( + 
"vision_tower.vision_model.pre_layrnorm", + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + "model.vision_model.post_layernorm", # SmolVLM + ), + + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( + "resampler.pos_embed_k", + ), + + MODEL_TENSOR.V_RESMPL_ATTN_Q: ( + "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_K: ( + "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_V: ( + "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( + "resampler.attn.out_proj", + ), + + MODEL_TENSOR.V_RESMPL_KV: ( + "resampler.kv_proj", + ), + + MODEL_TENSOR.V_RESMPL_POST_NORM: ( + "resampler.ln_post", + ), + + MODEL_TENSOR.V_RESMPL_KV_NORM: ( + "resampler.ln_kv", + ), + + MODEL_TENSOR.V_RESMPL_Q_NORM: ( + "resampler.ln_q", + ), + + MODEL_TENSOR.V_RESMPL_PROJ: ( + "resampler.proj", + ), + + MODEL_TENSOR.V_RESMPL_QUERY: ( + "resampler.query", + ), + + MODEL_TENSOR.V_TOK_EMBD_IMAGE:( + "v.tok_embd.image", # tensor generated from token embeddings + ), + + MODEL_TENSOR.V_TOK_EMBD_END_IMAGE:( + "v.tok_embd.end_image", # tensor generated from token embeddings + ), + + MODEL_TENSOR.V_TOK_EMBD_SLICE:( + "v.tok_embd.slice", # tensor generated from token embeddings + ), + + MODEL_TENSOR.V_TOK_EMBD_END_SLICE:( + "v.tok_embd.end_slice", # tensor generated from token embeddings + ), } # architecture-specific block mappings diff --git a/include/llama.h b/include/llama.h index 61907ed404dbf..456be8db4d3cb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -229,6 +229,18 @@ extern "C" { bool sorted; } llama_token_data_array; + // Structure represents the basic input unit of vision model + // This can be a processed image or slices of images under the hood + struct llama_vision_tokens; + + // represent an RGB image + // size of data must be equal to 3*nx*ny + typedef struct llama_vision_bitmap { + uint32_t nx; + uint32_t ny; + unsigned char * data; + } llama_vision_bitmap; + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode @@ -253,6 +265,9 @@ extern "C" { int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" + + struct ggml_tensor * embd_tensor; + struct ggml_tensor * cross_embd; } llama_batch; enum llama_model_kv_override_type { @@ -529,6 +544,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) 
LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model has a image attention KV cache + LLAMA_API bool llama_model_has_cross_kv(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, @@ -845,6 +863,10 @@ extern "C" { int32_t embd, int32_t n_seq_max); + // Allocates a batch based on a tensor, only used by vision API for now + // Unlike llama_batch_get_one, this will need to be freed after use + LLAMA_API struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id); + // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); @@ -1275,6 +1297,25 @@ extern "C" { // TODO: extend in the future //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); + // + // Vision API + // + + // Container for RGB bitmap + LLAMA_API struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny); + LLAMA_API void llama_vision_bitmap_free(struct llama_vision_bitmap * bmp); + + // Create image tokens from the RGB bitmap + LLAMA_API struct llama_vision_tokens * llama_vision_tokenize(struct llama_context * ctx, llama_vision_bitmap * bmp); + LLAMA_API void llama_vision_tokens_free(struct llama_vision_tokens * img_tokens); + + // User must reserve N number of tokens in tokenized text prompt for each image + // LLAMA_API int32_t llama_vision_get_n_tokens(const llama_vision_img_tokens * img_tokens); + + // Encode patches into embeddings + LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, struct llama_vision_tokens * img_tokens); + LLAMA_API struct ggml_tensor * llama_vision_get_output_tensor(struct llama_context * ctx); + // // Model split // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e1b02e4c08f07..25b375c134a1d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -24,6 +24,7 @@ add_library(llama llama-quant.cpp llama-sampling.cpp llama-vocab.cpp + llama-vision.cpp unicode.h unicode.cpp unicode-data.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 97a1e7e5e01ef..b7d37b0d515ce 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -3,6 +3,7 @@ #include "llama-impl.h" #include +#include static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, @@ -61,7 +62,12 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_COGVLM, "cogvlm" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_VISION_LLAVA, "llava" }, + { LLM_ARCH_VISION_MOBILEVLM, "mobilevlm" }, + { LLM_ARCH_VISION_MINICPMV, "minicpmv" }, + { LLM_ARCH_VISION_IDEFICS3, "idefics3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -190,6 +196,28 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_VISION_TYPE, "vision.type" }, + { LLM_KV_VISION_IMAGE_SIZE, "vision.image_size" }, + { LLM_KV_VISION_PATCH_SIZE, "vision.patch_size" }, + { LLM_KV_VISION_IMAGE_MEAN, "vision.image_mean" }, + { LLM_KV_VISION_IMAGE_STD, "vision.image_std" }, + { LLM_KV_VISION_VIT_ARCHITECTURE, "vision.vit.architecture" }, + { LLM_KV_VISION_VIT_CONTEXT_LENGTH, "vision.vit.context_length" }, + { LLM_KV_VISION_VIT_EMBEDDING_LENGTH, "vision.vit.embedding_length" }, + { LLM_KV_VISION_VIT_BLOCK_COUNT, 
"vision.vit.block_count" }, + { LLM_KV_VISION_VIT_FEED_FORWARD_LENGTH, "vision.vit.feed_forward_length" }, + { LLM_KV_VISION_VIT_PROJECTION_TYPE, "vision.vit.projection_type" }, + { LLM_KV_VISION_VIT_PROJECTION_DIM, "vision.vit.projection_dim" }, + { LLM_KV_VISION_VIT_USE_GELU, "vision.vit.use_gelu" }, + { LLM_KV_VISION_VIT_MAX_POS_EMBD, "vision.vit.max_position_embeddings" }, + { LLM_KV_VISION_VIT_MAX_SLICES, "vision.vit.max_slices" }, + { LLM_KV_VISION_VIT_PROJECTOR_TYPE, "vision.vit.projector_type" }, + { LLM_KV_VISION_VIT_SELECT_LAYER, "vision.vit.select_layer" }, + { LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, "vision.vit.patch_merge_type" }, + { LLM_KV_VISION_VIT_HEAD_COUNT, "vision.vit.attention.head_count" }, + { LLM_KV_VISION_VIT_LAYERNORM_EPS, "vision.vit.attention.layer_norm_epsilon" }, + { LLM_KV_VISION_VIT_SCALE_FACTOR, "vision.vit.scale_factor" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -1271,6 +1299,30 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, + { + LLM_ARCH_COGVLM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, // input_norm_w + { LLM_TENSOR_ATTN_TXT_QKV, "blk.%d.attn_txt_qkv" }, // language_qkv_w + { LLM_TENSOR_ATTN_IMG_QKV, "blk.%d.attn_img_qkv" }, // vision_qkv_w + { LLM_TENSOR_ATTN_TXT_DENSE, "blk.%d.attn_txt_dense" }, // language_dense_w + { LLM_TENSOR_ATTN_IMG_DENSE, "blk.%d.attn_img_dense" }, // vision_dense_w + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, // self_attn_norm_w + { LLM_TENSOR_CROSS_ATTN_Q, "blk.%d.cross_attn_q" }, // cross_query_w + { LLM_TENSOR_CROSS_ATTN_KV, "blk.%d.cross_attn_kv" }, // cross_query_kv + { LLM_TENSOR_CROSS_ATTN_DENSE, "blk.%d.cross_attn_dense" }, // cross_dense_w + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, // attn_norm_w + { LLM_TENSOR_FFN_TXT_UP, "blk.%d.ffn_txt_up" }, // language_up_proj_w + { LLM_TENSOR_FFN_TXT_GATE, "blk.%d.ffn_txt_gate" }, // language_gate_proj_w + { LLM_TENSOR_FFN_TXT_DOWN, "blk.%d.ffn_txt_down" }, // language_down_proj_w + { LLM_TENSOR_FFN_IMG_UP, "blk.%d.ffn_img_up" }, // vision_up_proj_w + { LLM_TENSOR_FFN_IMG_GATE, "blk.%d.ffn_img_gate" }, // vision_gate_proj_w + { LLM_TENSOR_FFN_IMG_DOWN, "blk.%d.ffn_img_down" } // vision_down_proj_w + }, + }, { LLM_ARCH_WAVTOKENIZER_DEC, { @@ -1296,6 +1348,95 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + // vision + { + LLM_ARCH_VISION_LLAVA, + { + { LLM_TENSOR_V_MMPROJ, "v.mmproj_%d" }, + { LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, + { + LLM_ARCH_VISION_MOBILEVLM, + { + { LLM_TENSOR_V_MMPROJ_MLP, "v.mmproj.mlp.%d" }, + { LLM_TENSOR_V_MMPROJ_PEG, "v.mmproj.peg.%d" }, + { 
LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, + { + LLM_ARCH_VISION_MINICPMV, + { + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_RESMPL_POS_EMBD_K, "v.resmpl.pos_embd_k" }, + { LLM_TENSOR_V_RESMPL_ATTN_Q, "v.resmpl.attn_q" }, + { LLM_TENSOR_V_RESMPL_ATTN_K, "v.resmpl.attn_k" }, + { LLM_TENSOR_V_RESMPL_ATTN_V, "v.resmpl.attn_v" }, + { LLM_TENSOR_V_RESMPL_ATTN_OUT, "v.resmpl.attn_out" }, + { LLM_TENSOR_V_RESMPL_KV, "v.resmpl.kv" }, + { LLM_TENSOR_V_RESMPL_KV_NORM, "v.resmpl.kv_norm" }, + { LLM_TENSOR_V_RESMPL_POST_NORM, "v.resmpl.post_norm" }, + { LLM_TENSOR_V_RESMPL_Q_NORM, "v.resmpl.q_norm" }, + { LLM_TENSOR_V_RESMPL_PROJ, "v.resmpl.proj" }, + { LLM_TENSOR_V_RESMPL_QUERY, "v.resmpl.query" }, + { LLM_TENSOR_V_TOK_EMBD_IMAGE, "v.tok_embd.image" }, + { LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "v.tok_embd.end_image" }, + { LLM_TENSOR_V_TOK_EMBD_SLICE, "v.tok_embd.slice" }, + { LLM_TENSOR_V_TOK_EMBD_END_SLICE, "v.tok_embd.end_slice" }, + } + }, + { + LLM_ARCH_VISION_IDEFICS3, + { + { LLM_TENSOR_V_MMPROJ_FC, "v.mmproj.fc" }, + { LLM_TENSOR_V_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { LLM_TENSOR_V_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { LLM_TENSOR_V_ENC_EMBD_POS, "v.enc.embd.pos" }, + { LLM_TENSOR_V_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { LLM_TENSOR_V_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { LLM_TENSOR_V_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { LLM_TENSOR_V_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { LLM_TENSOR_V_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { LLM_TENSOR_V_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { LLM_TENSOR_V_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { LLM_TENSOR_V_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { LLM_TENSOR_V_PRE_NORM, "v.pre_norm" }, + { LLM_TENSOR_V_POST_NORM, "v.post_norm" }, + } + }, { LLM_ARCH_UNKNOWN, { @@ -1445,6 +1586,52 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_ATTN_TXT_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_IMG_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_TXT_DENSE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_IMG_DENSE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + 
{LLM_TENSOR_CROSS_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CROSS_ATTN_KV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CROSS_ATTN_DENSE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_TXT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_TXT_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_TXT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_IMG_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_IMG_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_IMG_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + // vision + {LLM_TENSOR_V_MMPROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ_MLP, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_MMPROJ_PEG, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_EMBD_CLS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, + {LLM_TENSOR_V_ENC_EMBD_PATCH, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, + {LLM_TENSOR_V_ENC_EMBD_POS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_ADD}}, + {LLM_TENSOR_V_ENC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_INPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_V_ENC_OUTPUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_V_ENC_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_ENC_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_V_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_POS_EMBD_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_ADD}}, + {LLM_TENSOR_V_RESMPL_ATTN_Q, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_K, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_V, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_ATTN_OUT, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_KV, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_KV_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_POST_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_Q_NORM, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL}}, + {LLM_TENSOR_V_RESMPL_PROJ, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_V_RESMPL_QUERY, {LLM_TENSOR_LAYER_PROJECTION, GGML_OP_MUL_MAT}}, + // special token embeddings for image + {LLM_TENSOR_V_TOK_EMBD_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_END_IMAGE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, + {LLM_TENSOR_V_TOK_EMBD_END_SLICE, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_CONCAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 122fdcebe0af6..e5eafb1138bfe 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,12 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_COGVLM, + // vision + LLM_ARCH_VISION_LLAVA, + LLM_ARCH_VISION_MOBILEVLM, + LLM_ARCH_VISION_MINICPMV, + 
LLM_ARCH_VISION_IDEFICS3, LLM_ARCH_UNKNOWN, }; @@ -194,6 +200,28 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_VISION_TYPE, + LLM_KV_VISION_IMAGE_SIZE, + LLM_KV_VISION_PATCH_SIZE, + LLM_KV_VISION_IMAGE_MEAN, + LLM_KV_VISION_IMAGE_STD, + LLM_KV_VISION_VIT_ARCHITECTURE, + LLM_KV_VISION_VIT_CONTEXT_LENGTH, + LLM_KV_VISION_VIT_EMBEDDING_LENGTH, + LLM_KV_VISION_VIT_BLOCK_COUNT, + LLM_KV_VISION_VIT_FEED_FORWARD_LENGTH, + LLM_KV_VISION_VIT_PROJECTION_TYPE, + LLM_KV_VISION_VIT_PROJECTION_DIM, + LLM_KV_VISION_VIT_USE_GELU, + LLM_KV_VISION_VIT_MAX_POS_EMBD, + LLM_KV_VISION_VIT_MAX_SLICES, + LLM_KV_VISION_VIT_PROJECTOR_TYPE, + LLM_KV_VISION_VIT_SELECT_LAYER, + LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, + LLM_KV_VISION_VIT_HEAD_COUNT, + LLM_KV_VISION_VIT_LAYERNORM_EPS, + LLM_KV_VISION_VIT_SCALE_FACTOR, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -327,11 +355,59 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_ATTN_TXT_QKV, + LLM_TENSOR_ATTN_IMG_QKV, + LLM_TENSOR_ATTN_TXT_DENSE, + LLM_TENSOR_ATTN_IMG_DENSE, + LLM_TENSOR_CROSS_ATTN_Q, + LLM_TENSOR_CROSS_ATTN_KV, + LLM_TENSOR_CROSS_ATTN_DENSE, + LLM_TENSOR_FFN_TXT_UP, + LLM_TENSOR_FFN_TXT_GATE, + LLM_TENSOR_FFN_TXT_DOWN, + LLM_TENSOR_FFN_IMG_UP, + LLM_TENSOR_FFN_IMG_GATE, + LLM_TENSOR_FFN_IMG_DOWN, + // vision + LLM_TENSOR_V_MMPROJ, + LLM_TENSOR_V_MMPROJ_FC, + LLM_TENSOR_V_MMPROJ_MLP, + LLM_TENSOR_V_MMPROJ_PEG, + LLM_TENSOR_V_ENC_EMBD_CLS, + LLM_TENSOR_V_ENC_EMBD_PATCH, + LLM_TENSOR_V_ENC_EMBD_POS, + LLM_TENSOR_V_ENC_ATTN_Q, + LLM_TENSOR_V_ENC_ATTN_K, + LLM_TENSOR_V_ENC_ATTN_V, + LLM_TENSOR_V_ENC_INPUT_NORM, + LLM_TENSOR_V_ENC_OUTPUT, + LLM_TENSOR_V_ENC_OUTPUT_NORM, + LLM_TENSOR_V_ENC_FFN_UP, + LLM_TENSOR_V_ENC_FFN_DOWN, + LLM_TENSOR_V_PRE_NORM, + LLM_TENSOR_V_POST_NORM, + // vision - minicpmv + LLM_TENSOR_V_RESMPL_POS_EMBD_K, + LLM_TENSOR_V_RESMPL_ATTN_Q, + LLM_TENSOR_V_RESMPL_ATTN_K, + LLM_TENSOR_V_RESMPL_ATTN_V, + LLM_TENSOR_V_RESMPL_ATTN_OUT, + LLM_TENSOR_V_RESMPL_KV, + LLM_TENSOR_V_RESMPL_KV_NORM, + LLM_TENSOR_V_RESMPL_POST_NORM, + LLM_TENSOR_V_RESMPL_Q_NORM, + LLM_TENSOR_V_RESMPL_PROJ, + LLM_TENSOR_V_RESMPL_QUERY, + LLM_TENSOR_V_TOK_EMBD_IMAGE, + LLM_TENSOR_V_TOK_EMBD_END_IMAGE, + LLM_TENSOR_V_TOK_EMBD_SLICE, + LLM_TENSOR_V_TOK_EMBD_END_SLICE, }; enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, + LLM_TENSOR_LAYER_PROJECTION, LLM_TENSOR_LAYER_OUTPUT, }; diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..aef83863007b9 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -31,6 +31,8 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { /*n_seq_id =*/ ubatch_n_seq_id.data(), /*seq_id =*/ ubatch_seq_id.data(), /*output =*/ ubatch_output.data(), + /*embd_tensor =*/ nullptr, + /*cross_embd =*/ nullptr, }; return ubatch; } @@ -55,7 +57,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s } else { ubatch.token = nullptr; } - if (batch->embd) { + if (batch->embd_tensor) { + ubatch.embd_tensor = batch->embd_tensor; + } else if (batch->embd) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { memcpy( @@ -71,6 +75,9 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s } else { ubatch.embd = nullptr; } + if (batch->cross_embd) { + ubatch.cross_embd = batch->cross_embd; + } if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { ubatch.pos[ubatch.n_tokens + i] = 
batch->pos[ids[seq.offset + i]]; @@ -139,7 +146,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); ubatch.equal_seqs = false; if (!seq.empty()) { llama_sbatch_seq & s = seq[0]; @@ -152,7 +159,7 @@ llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); if (!seq.empty()) { size_t length = 0; size_t n_tokens_in_ubatch = 0; @@ -179,7 +186,7 @@ llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); + llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr || batch->embd_tensor != nullptr); if (!seq.empty()) { llama_sbatch_seq & s = seq[seq.size() - 1]; size_t length = s.length < n_ubatch ? s.length : n_ubatch; @@ -320,6 +327,8 @@ struct llama_batch llama_batch_get_one( /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, + /*embd_tensor =*/ nullptr, + /*cross_embd =*/ nullptr, }; } @@ -332,6 +341,8 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, + /*embd_tensor =*/ nullptr, + /*cross_embd =*/ nullptr, }; if (embd) { @@ -353,6 +364,36 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ return batch; } +struct llama_batch llama_batch_get_one_from_tensor(struct ggml_tensor * tensor, int32_t p0, int32_t seq_id) { + GGML_ASSERT(tensor->ne[2] == 1 && tensor->ne[3] == 1); + int32_t n_tokens = tensor->ne[1]; + llama_batch batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + /*embd_tensor =*/ tensor, + /*cross_embd =*/ nullptr, + }; + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.pos [i] = p0 + i; + batch.seq_id [i] = (llama_seq_id *) malloc(sizeof(llama_seq_id)); + batch.seq_id [i][0] = seq_id; + batch.n_seq_id[i] = 1; + } + batch.seq_id[n_tokens] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); diff --git a/src/llama-batch.h b/src/llama-batch.h index 773c3808b770f..3a2877dde777e 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -21,6 +21,9 @@ struct llama_ubatch { int32_t * n_seq_id; // [n_seqs] llama_seq_id ** seq_id; // [n_seqs] int8_t * output; // [n_tokens] + + struct ggml_tensor * embd_tensor; + struct ggml_tensor * cross_embd; }; struct 
llama_sbatch_seq { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 671d2a81adabf..47cb701a3b05f 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -73,7 +73,7 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens)); } - if (ubatch.embd) { + if (ubatch.embd && !ubatch.embd_tensor) { const int64_t n_embd = hparams.n_embd; const int64_t n_tokens = ubatch.n_tokens; diff --git a/src/llama-context.h b/src/llama-context.h index a9268b2920908..f4fcc4003a70f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -6,6 +6,7 @@ #include "llama-model.h" #include "llama-kv-cache.h" #include "llama-adapter.h" +#include "llama-vision.h" #include "ggml-cpp.h" @@ -26,6 +27,7 @@ struct llama_context { struct llama_sbatch sbatch; // TODO: revisit if needed struct llama_kv_cache kv_self; struct llama_adapter_cvec cvec; + struct llama_cross_kv_cache kv_cross; std::unordered_map lora; @@ -107,6 +109,11 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + + struct ggml_tensor * debug_intermediate; + + // vision + llama_vision_context vctx; }; // TODO: make these methods of llama_context diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 1fe45410371b9..ca48ea7c614d6 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -41,6 +41,7 @@ struct llama_hparams { uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + uint32_t n_embd_cross = 1024; // For cross attention with different hidden size // for WavTokenizer struct llama_hparams_posnet posnet; @@ -96,7 +97,7 @@ struct llama_hparams { float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; - // Additional scale factors (Granite/Granite MoE) + // Additional scale factors (Granite/Granite MoE/MiniCPM) float f_residual_scale = 0.0f; float f_embedding_scale = 0.0f; float f_attention_scale = 0.0f; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index feffdf0de52cf..360a0d672ec19 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -716,3 +716,75 @@ void llama_kv_cache_view_update(struct llama_kv_cache_view * view, const struct __func__, kv.used, used_cells); } } + +// Cross attention KV cache +bool llama_cross_kv_cache_init(struct llama_cross_kv_cache & cache, + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + uint32_t n_elements, + bool offload) { + const struct llama_hparams & hparams = model.hparams; + const int32_t n_layer = hparams.n_layer; + cache.cache_filled = false; + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + struct ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + ctx_map[buft] = ctx; + cache.ctxs.emplace_back(ctx); + return ctx; + } + return it->second; + }; + + for (int i = 0; i < n_layer; i++) { + ggml_backend_buffer_type_t buft; + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + } else { + buft = ggml_backend_cpu_buffer_type(); + } + ggml_context 
* ctx = ctx_for_buft(buft); + + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to initialize cross KV cache", __func__); + return false; + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_elements); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_elements); + ggml_format_name(k, "cross_cache_k_l%d", i); + ggml_format_name(v, "cross_cache_v_l%d", i); + cache.k_l.push_back(k); + cache.v_l.push_back(v); + } + + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for cross kv cache\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s cross KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + cache.bufs.emplace_back(buf); + } + + return true; +} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index dca6f3998c645..d5026a6df59f3 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -216,3 +216,21 @@ struct llama_kv_slot_restorer { } }; +// Simple cache that holds the computed K and V tensors +// for each layer's cross attention calculation +struct llama_cross_kv_cache { + std::vector k_l; + std::vector v_l; + + std::vector ctxs; + std::vector bufs; + + bool cache_filled; +}; + +bool llama_cross_kv_cache_init(struct llama_cross_kv_cache & cache, + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + uint32_t n_elements, + bool offload); diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 05d58ad90eba9..4ddffcfd84f22 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -375,6 +375,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_key (enum llm_kv kid, bool & result, bool required); template bool llama_model_loader::get_key (enum llm_kv kid, float & result, bool required); + template bool llama_model_loader::get_key (enum llm_kv kid, int32_t & result, bool required); template bool llama_model_loader::get_key (enum llm_kv kid, uint32_t & result, bool required); template bool llama_model_loader::get_key(enum llm_kv kid, std::string & result, bool required); @@ -439,6 +440,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0487c978b5e77..a49c47dbe40f2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2,6 +2,7 @@ #include "llama-impl.h" #include "llama-mmap.h" +#include "llama-vision.h" #include "llama-model-loader.h" #include "ggml-cpp.h" @@ -216,6 +217,11 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); } break; + case GGML_OP_CONCAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_concat(ctx, w, b, 0); + } break; default: 
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); } @@ -1238,6 +1244,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_COGVLM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + }break; case LLM_ARCH_WAVTOKENIZER_DEC: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1257,6 +1272,56 @@ void llama_model::load_hparams(llama_model_loader & ml) { } hparams.rope_type = llama_model_rope_type(this); + + // vision model + auto & vparams = vit.hparams; + std::string vision_type; + ml.get_key(LLM_KV_VISION_TYPE, vision_type, false); + if (vision_type == "vit") { + LLAMA_LOG_INFO("%s: loading ViT vision model\n", __func__); + has_vision = true; + ml.get_key(LLM_KV_VISION_IMAGE_SIZE, vparams.image_size, true); + ml.get_key(LLM_KV_VISION_PATCH_SIZE, vparams.patch_size, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, vparams.image_mean, 3, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, vparams.image_std, 3, true); + ml.get_key(LLM_KV_VISION_VIT_EMBEDDING_LENGTH, vparams.hidden_size, true); + ml.get_key(LLM_KV_VISION_VIT_BLOCK_COUNT, vparams.n_layer, true); + ml.get_key(LLM_KV_VISION_VIT_FEED_FORWARD_LENGTH, vparams.n_intermediate, true); + ml.get_key(LLM_KV_VISION_VIT_HEAD_COUNT, vparams.n_head, true); + ml.get_key(LLM_KV_VISION_VIT_LAYERNORM_EPS, vparams.eps, true); + ml.get_key(LLM_KV_VISION_VIT_SELECT_LAYER, vparams.select_layer, true); + ml.get_key(LLM_KV_VISION_VIT_MAX_POS_EMBD, vparams.max_pos_embd, true); + ml.get_key(LLM_KV_VISION_VIT_SCALE_FACTOR, vparams.scale_factor, false); + { + std::string name; + ml.get_key(LLM_KV_VISION_VIT_PROJECTOR_TYPE, name, true); + vparams.proj_type = vision_projector_type_from_name(name); + if (vparams.proj_type == VISION_PROJECTOR_TYPE_UNKNOWN) { + throw std::runtime_error(format("unsupported clip projector type: %s", name.c_str())); + } + } + { + std::string name; + ml.get_key(LLM_KV_VISION_VIT_PATCH_MERGE_TYPE, name, false); + vparams.mm_patch_merge_type = mm_patch_merge_from_name(name); + } + { + std::string arch; + ml.get_key(LLM_KV_VISION_VIT_ARCHITECTURE, arch, true); + vparams.arch = llm_arch_from_string(arch); + if (vparams.arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error(format("unsupported vision arch: %s", arch.c_str())); + } + } + } else if (!vision_type.empty()) { + throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); + } + + // arch-specific CLIP hparams + // switch (vparams.arch) { + // case VISION_ARCH_LLAVA: + // default: (void)0; + // } } void llama_model::load_vocab(llama_model_loader & ml) { @@ -1387,6 +1452,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; + const int64_t n_embd_cross = hparams.n_embd_cross; if (n_expert > 0 && hparams.n_expert_used == 0) { throw std::runtime_error("model has expert layers but no expert layers are used"); @@ -1432,7 +1498,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // sanity checks - if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (info.layer == LLM_TENSOR_LAYER_PROJECTION) { + // nothing to check + } else if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == 
LLM_TENSOR_LAYER_OUTPUT) { if (tn.bid != -1) { GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); } @@ -1454,6 +1522,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_TENSOR_LAYER_REPEATING: buft_list = pimpl->dev_layer.at(tn.bid).buft_list; break; + case LLM_TENSOR_LAYER_PROJECTION: + buft_list = pimpl->dev_layer.back().buft_list; + break; default: GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); } @@ -3311,6 +3382,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_COGVLM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // Not supporting ctx_split + for (int i=0; i < n_layer; i++) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv_txt = create_tensor(tn(LLM_TENSOR_ATTN_TXT_QKV, "weight", i), {n_embd, n_embd * 3}, 0); + layer.wqkv_img = create_tensor(tn(LLM_TENSOR_ATTN_IMG_QKV, "weight", i), {n_embd, n_embd * 3}, 0); + layer.wdense_txt = create_tensor(tn(LLM_TENSOR_ATTN_TXT_DENSE, "weight", i), {n_embd, n_embd}, 0); + layer.wdense_img = create_tensor(tn(LLM_TENSOR_ATTN_IMG_DENSE, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + + layer.wq_cross = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_cross}, 0); + // The input dimension is the number of dimensions from the cross vision encoder + // it might not be guaranteed that this is the same as the number of dimensions + // in the cogvlm attention calculation + layer.wkv_cross = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_KV, "weight", i), {n_embd_cross, n_embd_cross * 2}, 0); + layer.wdense_cross = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_DENSE, "weight", i), {n_embd_cross, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_txt = create_tensor(tn(LLM_TENSOR_FFN_TXT_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down_txt = create_tensor(tn(LLM_TENSOR_FFN_TXT_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up_txt = create_tensor(tn(LLM_TENSOR_FFN_TXT_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_img = create_tensor(tn(LLM_TENSOR_FFN_IMG_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down_img = create_tensor(tn(LLM_TENSOR_FFN_IMG_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up_img = create_tensor(tn(LLM_TENSOR_FFN_IMG_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_WAVTOKENIZER_DEC: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); @@ -3423,6 +3532,179 @@ bool llama_model::load_tensors(llama_model_loader & ml) { __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); } + + // load tensors for vision model + auto & vparams = vit.hparams; + if (has_vision) { + // language params + const int64_t n_embd = hparams.n_embd; + // vision params + const int64_t n_vlayer = vparams.n_layer; + const int64_t n_vembd = vparams.hidden_size; + const int64_t n_vff = 
vparams.n_intermediate; + const int64_t max_pos_embd = vparams.max_pos_embd; + const int64_t n_channel = 3; // always RGB + const int64_t patch_size = vparams.patch_size; + const auto tn = LLM_TN(vparams.arch); + + // TODO: vit is cpu only for now + vit.buft = ggml_backend_cpu_buffer_type(); + vit.layers.resize(n_vlayer); + + switch (vparams.arch) { + case LLM_ARCH_VISION_LLAVA: + case LLM_ARCH_VISION_MOBILEVLM: + { + if (vparams.arch == LLM_ARCH_VISION_LLAVA) { + vit.mm_1_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 1), {n_vembd, n_vff}, 0); + vit.mm_1_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 1), {n_vff}, 0); + vit.mm_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "weight", 2), {n_vff, n_vff}, 0); + vit.mm_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ, "bias" , 2), {n_vff}, 0); + } else if (vparams.arch == LLM_ARCH_VISION_MOBILEVLM) { + vit.mm_model_mlp_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 0), {n_vembd, n_embd}, 0); + vit.mm_model_mlp_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 0), {n_embd}, 0); + vit.mm_model_mlp_2_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "weight", 2), {n_embd, n_embd}, 0); + vit.mm_model_mlp_2_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_MLP, "bias", 2), {n_embd}, 0); + vit.mm_model_peg_0_w = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "weight", 0), {n_channel, n_channel, 1, n_embd}, 0); + vit.mm_model_peg_0_b = create_tensor(tn(LLM_TENSOR_V_MMPROJ_PEG, "bias", 0), {n_embd}, 0); + } + + vit.class_embedding = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_CLS ), {n_vembd}, 0); + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); + + vit.pre_norm_w = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "weight"), {n_vembd}, 0); + vit.pre_norm_b = create_tensor(tn(LLM_TENSOR_V_PRE_NORM, "bias" ), {n_vembd}, 0); + vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_vlayer; ++i) { + auto & layer = vit.layers[i]; + + layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0); + layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0); + layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0); + layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0); + layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0); + layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0); + + layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0); + layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0); + + layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0); + layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_out_b = 
create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0); + + layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0); + layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0); + } + } break; + case LLM_ARCH_VISION_MINICPMV: + { + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); + + // tok embd + vit.mm_tok_embd_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_IMAGE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_end_image = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_IMAGE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_SLICE, "weight"), {n_embd}, 0); + vit.mm_tok_embd_end_slice = create_tensor(tn(LLM_TENSOR_V_TOK_EMBD_END_SLICE, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_vlayer; ++i) { + auto & layer = vit.layers[i]; + + layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0); + layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0); + layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0); + layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0); + layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0); + layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0); + + layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0); + layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0); + + layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_in_b = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0); + layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0); + + layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0); + layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0); + } + + // resampler, we consider it as one layer on top of the encoder + int il = n_vlayer - 1; + int rs_n_embd = llama_vision_n_mmproj_embd(vit); + vit.mm_model_pos_embed_k = create_tensor(tn(LLM_TENSOR_V_RESMPL_POS_EMBD_K, "weight", il), {rs_n_embd, max_pos_embd}, 0); + vit.mm_model_query = create_tensor(tn(LLM_TENSOR_V_RESMPL_QUERY, "weight", il), {rs_n_embd, 64}, 0); // why 64? 
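For context on the resampler width used above: rs_n_embd comes from llama_vision_n_mmproj_embd(), which this patch adds later in llama-vision.cpp. The MiniCPM-V branches of that helper are hard-coded, so the resampler tensors are created at a fixed width per projector version; an abridged copy is reproduced here for reference:

// abridged from llama_vision_n_mmproj_embd() as added in llama-vision.cpp (this patch)
if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) {
    return 4096; // MiniCPM-V 2.5 resampler width
} else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) {
    return 3584; // MiniCPM-V 2.6 resampler width
}

With the 2.6 width, mm_model_query above comes out as a {3584, 64} tensor, which suggests 64 learned query vectors; the patch itself leaves the 64 unexplained (hence the "why 64?" comment).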
+ vit.mm_model_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_PROJ, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_kv_proj = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV, "weight", il), {n_vembd, rs_n_embd}, 0); + vit.mm_model_attn_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_Q, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_k_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_k_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_K, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_v_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_v_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_V, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_attn_o_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "weight", il), {rs_n_embd, rs_n_embd}, 0); + vit.mm_model_attn_o_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_ATTN_OUT, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_q_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_q_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_Q_NORM, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_kv_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_kv_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_KV_NORM, "bias" , il), {rs_n_embd}, 0); + vit.mm_model_ln_post_w = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "weight", il), {rs_n_embd}, 0); + vit.mm_model_ln_post_b = create_tensor(tn(LLM_TENSOR_V_RESMPL_POST_NORM, "bias" , il), {rs_n_embd}, 0); + + } break; + case LLM_ARCH_VISION_IDEFICS3: + { + int scale_factor = vit.hparams.scale_factor; + vit.projection = create_tensor(tn(LLM_TENSOR_V_MMPROJ_FC, "weight"), {n_vembd * scale_factor * scale_factor, n_embd}, 0); + + vit.patch_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_vembd}, 0); + vit.patch_bias = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_PATCH, "bias" ), {n_vembd}, 0); + vit.position_embeddings = create_tensor(tn(LLM_TENSOR_V_ENC_EMBD_POS, "weight"), {n_vembd, max_pos_embd}, 0); + + vit.post_norm_w = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "weight"), {n_vembd}, 0); + vit.post_norm_b = create_tensor(tn(LLM_TENSOR_V_POST_NORM, "bias" ), {n_vembd}, 0); + + for (int i = 0; i < n_vlayer; ++i) { + auto & layer = vit.layers[i]; + + layer.k_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "weight", i), {n_vembd, n_vembd}, 0); + layer.k_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_K, "bias" , i), {n_vembd}, 0); + layer.v_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "weight", i), {n_vembd, n_vembd}, 0); + layer.v_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_V, "bias" , i), {n_vembd}, 0); + layer.q_w = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "weight", i), {n_vembd, n_vembd}, 0); + layer.q_b = create_tensor(tn(LLM_TENSOR_V_ENC_ATTN_Q, "bias" , i), {n_vembd}, 0); + + layer.ffn_up_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "weight", i), {n_vembd, n_vff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_UP, "bias" , i), {n_vff}, 0); + layer.ffn_down_w = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "weight", i), {n_vff, n_vembd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_V_ENC_FFN_DOWN, "bias" , i), {n_vembd}, 0); + + layer.norm_in_w = create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_in_b = 
create_tensor(tn(LLM_TENSOR_V_ENC_INPUT_NORM, "bias" , i), {n_vembd}, 0); + layer.norm_out_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "weight", i), {n_vembd}, 0); + layer.norm_out_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT_NORM, "bias" , i), {n_vembd}, 0); + + layer.output_w = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "weight", i), {n_vembd, n_vembd}, 0); + layer.output_b = create_tensor(tn(LLM_TENSOR_V_ENC_OUTPUT, "bias" , i), {n_vembd}, 0); + } + } break; + default: + throw std::runtime_error("unknown vision architecture"); + } + + if (llama_vision_n_mmproj_embd(vit) != hparams.n_embd) { + std::runtime_error("model has vision, but n_mmproj_embd != n_embd"); + } + } } ml.done_getting_tensors(); @@ -3886,6 +4168,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: + case LLM_ARCH_COGVLM: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -3918,6 +4201,12 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_QWEN2VL: return LLAMA_ROPE_TYPE_MROPE; + case LLM_ARCH_VISION_LLAVA: + case LLM_ARCH_VISION_MOBILEVLM: + case LLM_ARCH_VISION_MINICPMV: + case LLM_ARCH_VISION_IDEFICS3: + GGML_ABORT("vision arch does not use RoPE"); + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); @@ -4019,3 +4308,10 @@ bool llama_model_is_recurrent(const struct llama_model * model) { default: return false; } } + +bool llama_model_has_cross_kv(const struct llama_model * model) { + switch (model->arch) { + case LLM_ARCH_COGVLM: return true; + default: return false; + } +} diff --git a/src/llama-model.h b/src/llama-model.h index a7c30444786fd..ec725f35b31e3 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -4,6 +4,7 @@ #include "llama-arch.h" #include "llama-hparams.h" #include "llama-vocab.h" +#include "llama-vision.h" #include #include @@ -164,11 +165,18 @@ struct llama_layer { struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; + // Added this here to reuse the T5 variables + struct ggml_tensor * wkv_cross; struct ggml_tensor * wo_cross = nullptr; struct ggml_tensor * wq_enc = nullptr; struct ggml_tensor * wk_enc = nullptr; struct ggml_tensor * wv_enc = nullptr; struct ggml_tensor * wo_enc = nullptr; + struct ggml_tensor * wqkv_txt; + struct ggml_tensor * wqkv_img; + struct ggml_tensor * wdense_txt; + struct ggml_tensor * wdense_img; + struct ggml_tensor * wdense_cross; // attention bias struct ggml_tensor * bq = nullptr; @@ -198,6 +206,12 @@ struct llama_layer { struct ggml_tensor * ffn_gate_enc = nullptr; struct ggml_tensor * ffn_down_enc = nullptr; struct ggml_tensor * ffn_up_enc = nullptr; + struct ggml_tensor * ffn_gate_txt = nullptr; + struct ggml_tensor * ffn_down_txt = nullptr; + struct ggml_tensor * ffn_up_txt = nullptr; + struct ggml_tensor * ffn_gate_img = nullptr; + struct ggml_tensor * ffn_down_img = nullptr; + struct ggml_tensor * ffn_up_img = nullptr; // ff MoE struct ggml_tensor * ffn_gate_inp = nullptr; @@ -362,6 +376,10 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + // vision + bool has_vision = false; + llama_vision_model vit; + private: struct impl; std::unique_ptr pimpl; diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp new file mode 100644 index 0000000000000..bb6ffcf32bf1c --- /dev/null +++ b/src/llama-vision.cpp @@ -0,0 
+1,1259 @@ +#include "llama.h" +#include "llama-vision.h" +#include "llama-impl.h" +#include "llama-context.h" + +#include // memcpy +#include +#include + +#ifndef NDEBUG +// for debugging +#include +#include +#include + +// export llama_image_u8 to bmp file for debugging +// https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c +static int bmp_export(const struct llama_image_u8 &img, const std::string &location); +#endif + +struct img_size { + int width; + int height; + img_size(int w, int h) : width(w), height(h) {} +}; + +// RGB uint8 image +// Memory layout: RGBRGBRGB... +struct llama_image_u8 { + int nx; + int ny; + std::vector buf; + llama_image_u8() {} + llama_image_u8(const llama_vision_bitmap & bmp) { + nx = bmp.nx; + ny = bmp.ny; + buf.resize(nx*ny*3); + memcpy(buf.data(), bmp.data, buf.size()); + } +}; + +uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel) { + auto & proj_type = vmodel.hparams.proj_type; + if (proj_type == VISION_PROJECTOR_TYPE_MLP) { + return vmodel.mm_2_b + ? vmodel.mm_2_b->ne[0] + : vmodel.projection->ne[1]; // idefics3 + } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) { + return vmodel.mm_model_peg_0_b->ne[0]; + } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) { + return 4096; // resampler + } else if (proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) { + return 3584; // resampler + } else { + GGML_ASSERT(false && "invalid proj type"); + } +} + + +// +// internal utils +// + +static int get_n_patches_x(const llama_vision_context & ctx) { + auto & hparams = ctx.model->hparams; + return hparams.image_size / hparams.patch_size; +} + +static int get_n_patches_y(const llama_vision_context & ctx) { + return get_n_patches_x(ctx); +} + +static int get_n_patches(const llama_vision_context & ctx) { + return get_n_patches_x(ctx) * get_n_patches_y(ctx); +} + +// +// bitmap utils +// + +/** + * Selects the best resolution from a list of possible resolutions based on the original size. + * + * @param original_size The original size of the image in the format (width, height). + * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + * @return The best fit resolution in the format (width, height). 
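 *
 * Worked example (illustrative numbers, not taken from the patch): for an
 * 800x600 original and candidates {672x672, 336x1344, 1344x336}, the scales
 * are 0.84, 0.42 and 0.56, giving downscaled sizes of 672x504, 336x252 and
 * 448x336; 672x672 keeps the largest effective area (~339k px) and is returned.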
+ */ +static img_size select_best_resolution(const img_size & original_size, const std::vector& possible_resolutions) { + int original_width = original_size.width; + int original_height = original_size.height; + + img_size best_fit(0, 0); + int max_effective_resolution = 0; + int min_wasted_resolution = std::numeric_limits::max(); + + for (const auto& resolution : possible_resolutions) { + int width = resolution.width; + int height = resolution.height; + float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); + int downscaled_width = static_cast(original_width * scale); + int downscaled_height = static_cast(original_height * scale); + int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); + int wasted_resolution = (width * height) - effective_resolution; + // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + best_fit = resolution; + } + } + + return best_fit; +} + +static bool bicubic_resize(const llama_image_u8 & img, llama_image_u8 & dst, int target_width, int target_height) { + auto clip = [](int x, int lower, int upper) -> int { + return std::max(lower, std::min(x, upper)); + }; + + const int nx = img.nx; + const int ny = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); + + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; + + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; + + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = (int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = 
float(Cc2); + } + } + } + } + + return true; +} + +static std::vector divide_to_patches_u8(const llama_image_u8 & image, int patch_size) { + std::vector patches; + int width = image.nx; + int height = image.ny; + for (int i = 0; i < height; i += patch_size) { + for (int j = 0; j < width; j += patch_size) { + llama_image_u8 patch; + patch.nx = std::min(patch_size, width - j); + patch.ny = std::min(patch_size, height - i); + patch.buf.resize(3 * patch.nx * patch.ny); + for (int y = 0; y < patch.ny; ++y) { + for (int x = 0; x < patch.nx; ++x) { + for (int c = 0; c < 3; ++c) { + patch.buf[3 * (y * patch.nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; + } + } + } + patches.push_back(patch); + } + } + return patches; +} + +// llava-1.6 type of resize_and_pad (black) +static llama_image_u8 resize_and_pad_image(const llama_image_u8 & image, const img_size & target_resolution) { + int target_width = target_resolution.width; + int target_height = target_resolution.height; + + float scale_w = static_cast(target_width) / image.nx; + float scale_h = static_cast(target_height) / image.ny; + + int new_width, new_height; + + if (scale_w < scale_h) { + new_width = target_width; + new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); + } else { + new_height = target_height; + new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); + } + + llama_image_u8 resized_image; + // bilinear_resize(image, resized_image, new_width, new_height); + bicubic_resize(image, resized_image, new_width, new_height); + + llama_image_u8 padded_image; + padded_image.nx = target_width; + padded_image.ny = target_height; + padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black + + // Calculate padding offsets + int pad_x = (target_width - new_width) / 2; + int pad_y = (target_height - new_height) / 2; + + // Copy the resized image into the center of the padded buffer + for (int y = 0; y < new_height; ++y) { + for (int x = 0; x < new_width; ++x) { + for (int c = 0; c < 3; ++c) { + padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + } + } + } + return padded_image; +} + +static void normalize_image_u8_to_f32(const llama_image_u8 & src, std::vector & dst, const std::array & mean, const std::array & std) { + dst.resize(src.buf.size()); + + for (size_t i = 0; i < src.buf.size(); ++i) { + int c = i % 3; // rgb + dst[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c]; + } +} + + +// +// processor +// + +struct llama_vision_processor { + const llama_vision_context & ctx; + llama_vision_processor(const llama_vision_context & ctx) : ctx(ctx) {} + virtual llama_vision_tokens tokenize(const llama_image_u8 & img) = 0; + virtual ~llama_vision_processor() = default; +}; + +// inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava/processing_llava.py +struct llama_vision_processor_llava : llama_vision_processor { + llama_vision_processor_llava(const llama_vision_context & ctx) : llama_vision_processor(ctx) {} + + virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override { + bool pad_to_square = true; + auto & params = ctx.model->hparams; + // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing + if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) { + pad_to_square = false; + } + + llama_vision_tokens 
output_slices; + output_slices.n_px = get_n_patches_x(ctx); + output_slices.n_py = get_n_patches_y(ctx); + output_slices.px = params.patch_size; + output_slices.py = params.patch_size; + + // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) + // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 + + llama_image_u8 temp; + if (pad_to_square && img.nx != img.ny) { + // if the image is not square, pad it to a square + int longer_side = std::max(img.nx, img.ny); + temp.nx = longer_side; + temp.ny = longer_side; + temp.buf.resize(3 * longer_side * longer_side); + const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) + + // fill with background color + for (size_t i = 0; i < temp.buf.size(); i++) { + temp.buf[i] = bc[i % 3]; + } + + // copy from the input image + for (int y = 0; y < img.ny; y++) { + for (int x = 0; x < img.nx; x++) { + const int i = 3 * (y * img.nx + x); + const int j = 3 * (y * temp.nx + x); + temp.buf[j] = img.buf[i]; + temp.buf[j+1] = img.buf[i+1]; + temp.buf[j+2] = img.buf[i+2]; + } + } + } else if (params.image_grid_pinpoints[0] != 0) { + // "spatial_unpad" with "anyres" processing for llava-1.6 + std::vector possible_resolutions; + for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { + img_size s(0, 0); + s.width = params.image_grid_pinpoints[i]; + s.height = params.image_grid_pinpoints[i+1]; + possible_resolutions.push_back(s); + } + img_size best_resolution = select_best_resolution(img_size(img.nx, img.ny), possible_resolutions); + // debug_image_save_to_bmp(*img, "input.bmp"); + temp = resize_and_pad_image(img, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 + // debug_image_save_to_bmp(*temp, "resized.bmp"); + + std::vector patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + + llama_image_u8 image_original_resize; + // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + patches.insert(patches.begin(), image_original_resize); + output_slices.buf.resize(patches.size()); + int num = 0; + for (auto & patch : patches) { + normalize_image_u8_to_f32(patch, output_slices.buf[num], params.image_mean, params.image_std); + num++; + } + return output_slices; + } else { + temp.nx = img.nx; + temp.ny = img.ny; + temp.buf.resize(img.buf.size()); + memcpy(temp.buf.data(), img.buf.data(), temp.buf.size()); + } + + const int nx = temp.nx; + const int ny = temp.ny; + // bmp_export(temp, "resized_vanilla.bmp"); + + const int nx2 = params.image_size; + const int ny2 = params.image_size; + std::vector res; + res.resize(3 * nx2 * ny2); + + const float scale = std::max(nx, ny) / (float)params.image_size; + + const int nx3 = int(nx / scale + 0.5f); + const int ny3 = int(ny / scale + 0.5f); + + const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; + const auto & s3 = params.image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; + + for (int y = 0; y < ny3; y++) { + for (int x = 0; x < nx3; x++) { + for (int c = 0; c < 3; c++) { + // linear interpolation + const float sx = (x + 0.5f) * scale - 0.5f; + const float sy = 
(y + 0.5f) * scale - 0.5f;
+
+                const int x0 = std::max(0, (int)std::floor(sx));
+                const int y0 = std::max(0, (int)std::floor(sy));
+
+                const int x1 = std::min(x0 + 1, nx - 1);
+                const int y1 = std::min(y0 + 1, ny - 1);
+
+                const float dx = sx - x0;
+                const float dy = sy - y0;
+
+                const int j00 = 3 * (y0 * nx + x0) + c;
+                const int j01 = 3 * (y0 * nx + x1) + c;
+                const int j10 = 3 * (y1 * nx + x0) + c;
+                const int j11 = 3 * (y1 * nx + x1) + c;
+
+                const float v00 = temp.buf[j00];
+                const float v01 = temp.buf[j01];
+                const float v10 = temp.buf[j10];
+                const float v11 = temp.buf[j11];
+
+                const float v0 = v00 * (1.0f - dx) + v01 * dx;
+                const float v1 = v10 * (1.0f - dx) + v11 * dx;
+
+                const float v = v0 * (1.0f - dy) + v1 * dy;
+
+                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
+
+                const int i = 3 * (y * nx3 + x) + c;
+
+                res[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
+            }
+        }
+    }
+
+        output_slices.buf.resize(1);
+        output_slices.buf[0] = std::move(res);
+
+        return output_slices;
+    }
+};
+
+struct llama_vision_processor_uhd : llama_vision_processor {
+    llama_vision_processor_uhd(const llama_vision_context & ctx) : llama_vision_processor(ctx) {}
+
+    int ensure_divide(int length, int patch_size) {
+        return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
+    }
+
+    img_size find_best_resize(const img_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+            float r = static_cast<float>(width) / height;
+            height  = static_cast<int>(scale_resolution / std::sqrt(r));
+            width   = static_cast<int>(height * r);
+        }
+        int best_width  = ensure_divide(width, patch_size);
+        int best_height = ensure_divide(height, patch_size);
+        return img_size(best_width, best_height);
+    }
+
+    img_size get_refine_size(const img_size & original_size, const img_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        int grid_x = grid.width;
+        int grid_y = grid.height;
+
+        int refine_width  = ensure_divide(width,  grid_x);
+        int refine_height = ensure_divide(height, grid_y);
+
+        int grid_width  = refine_width  / grid_x;
+        int grid_height = refine_height / grid_y;
+
+        // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
+        auto best_grid = find_best_resize({grid_width, grid_height}, scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
+
+        // img_size refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
+        img_size refine_size = img_size(best_grid.width * grid_x, best_grid.height * grid_y); // (new line)
+        return refine_size;
+    }
+
+    img_size find_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+        std::vector<int> candidate_split_grids_nums;
+        for (int i : {multiple - 1, multiple, multiple + 1}) {
+            if (i == 1 || i > max_slice_nums) {
+                continue;
+            }
+            candidate_split_grids_nums.push_back(i);
+        }
+
+        std::vector<img_size> candidate_grids;
+        for (int split_grids_nums : candidate_split_grids_nums) {
+            int m = 1;
+            while (m <= split_grids_nums) {
+                if (split_grids_nums % m == 0) {
+                    candidate_grids.emplace_back(m, split_grids_nums / m);
+                }
+                ++m;
+            }
+        }
+
+        img_size best_grid = img_size(1, 1);
+        float min_error = std::numeric_limits<float>::infinity();
+        for (const auto& grid : candidate_grids) {
+            float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
+            if (error < min_error) {
+                best_grid = grid;
+                min_error = error;
+            }
+        }
+        return best_grid;
+    }
+
+    std::vector<std::vector<llama_image_u8>> slice_image(
+            const llama_image_u8 & img,
+            const int max_slice_nums = 9,
+            const int scale_resolution = 448,
+            const int patch_size = 14) {
+        const img_size original_size = img_size(img.nx, img.ny);
+        const int original_width  = img.nx;
+        const int original_height = img.ny;
+        const float log_ratio = log(1.0*original_width/original_height);
+        const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
+        const int multiple = fmin(ceil(ratio), max_slice_nums);
+
+        std::vector<std::vector<llama_image_u8>> images;
+        LLAMA_LOG_DEBUG("%s: multiple %d\n", __func__, multiple);
+        images.push_back(std::vector<llama_image_u8>());
+
+        if (multiple <= 1) {
+            auto best_size = find_best_resize(original_size, scale_resolution, patch_size, true);
+            llama_image_u8 source_image;
+            bicubic_resize(img, source_image, best_size.width, best_size.height);
+            // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
+            images.back().push_back(source_image);
+        } else if (multiple > 1) {
+            auto best_size = find_best_resize(original_size, scale_resolution, patch_size);
+            llama_image_u8 source_image;
+            bicubic_resize(img, source_image, best_size.width, best_size.height);
+            // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
+            LLAMA_LOG_DEBUG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img.nx, img.ny, best_size.width, best_size.height);
+            images.back().push_back(source_image);
+
+            img_size best_grid = find_best_grid(max_slice_nums, multiple, log_ratio);
+            LLAMA_LOG_DEBUG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img.nx, img.ny, best_grid.width, best_grid.height);
+
+            auto refine_size = get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
+            llama_image_u8 refine_image;
+            // TODO: so far, we spend most of the time in bicubic_resize, we should optimize it
+            bicubic_resize(img, refine_image, refine_size.width, refine_size.height);
+
+            LLAMA_LOG_DEBUG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image.nx, refine_image.ny, refine_size.width, refine_size.height);
+
+            // split_to_patches
+            int width  = refine_image.nx;
+            int height = refine_image.ny;
+            int grid_x = int(width  / best_grid.width);
+            int grid_y = int(height / best_grid.height);
+            for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.height; patches_i += grid_y, ic += 1){
+                std::vector<llama_image_u8> patches_out;
+                images.push_back(std::vector<llama_image_u8>());
+                for (int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.width; patches_j += grid_x, jc += 1) {
+                    llama_image_u8 patch;
+                    patch.nx = grid_x;
+                    patch.ny = grid_y;
+                    patch.buf.resize(3 * patch.nx * patch.ny);
+                    for (int y = patches_i; y < patches_i + grid_y; ++y) {
+                        for (int x = patches_j; x < patches_j + grid_x; ++x) {
+                            const int i = 3 * (y * refine_image.nx + x);
+                            const int j = 3 * ((y-patches_i) * patch.nx + (x-patches_j));
+                            patch.buf[j]   = refine_image.buf[i];
+                            patch.buf[j+1] = refine_image.buf[i+1];
+                            patch.buf[j+2] = refine_image.buf[i+2];
+                        }
+                    }
+                    patches_out.push_back(std::move(patch));
+                }
+                images.push_back(std::move(patches_out));
+            }
+        }
+        return images;
+    }
+
+    virtual llama_vision_tokens tokenize(const llama_image_u8 & img) override {
+        auto & params = ctx.model->hparams;
+
+        std::vector<std::vector<llama_image_u8>> imgs = slice_image(img);
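+        // imgs[0] holds the single resized overview image; the remaining entries are
+        // the rows of grid patches produced by slice_image() (some of them may be
+        // empty), and the loop below flattens every patch into output.buf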
+ + llama_vision_tokens output; + output.n_px = get_n_patches_x(ctx); + output.n_py = get_n_patches_y(ctx); + output.px = params.patch_size; + output.py = params.patch_size; + + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + std::vector res; + normalize_image_u8_to_f32(imgs[i][j], res, params.image_mean, params.image_std); + output.buf.push_back(res); + } + } + + return output; + } +}; + +// +// cgraph builder +// + +// TODO: move this to llm_build_context in llama.cpp +struct llama_vision_graph_builder { + llama_vision_context & ctx; + const llama_vision_model & model; + struct ggml_context * ctx0; + int batch_size; + int hidden_size; + int n_head; + int d_head; + int patch_size; + float eps; + int num_patches; + int num_positions; + int img_w; + int img_h; + bool use_gelu; + int n_layers; + int rs_n_embd; + vision_projector_type proj_type; + + llama_vision_graph_builder(llama_vision_context & ctx, const llama_vision_tokens & inp) : ctx(ctx), model(*ctx.model) { + struct ggml_init_params params = { + /*.mem_size =*/ ctx.buf_compute_meta.size(), + /*.mem_buffer =*/ ctx.buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + ctx0 = ggml_init(params); + + auto & hparams = ctx.model->hparams; + + batch_size = inp.buf.size(); + hidden_size = hparams.hidden_size; + n_head = hparams.n_head; + d_head = hidden_size / n_head; + patch_size = hparams.patch_size; + eps = hparams.eps; + num_patches = inp.n_px * inp.n_py; + num_positions = num_patches + (model.class_embedding ? 1 : 0); + img_w = inp.px * inp.n_px; + img_h = inp.py * inp.n_py; + use_gelu = hparams.use_gelu; + n_layers = (int)hparams.n_layer + hparams.select_layer; + proj_type = hparams.proj_type; + } + + ~llama_vision_graph_builder() { + ggml_free(ctx0); + } + + struct ggml_tensor * build_inp() { + struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, img_w, img_h, 3, batch_size); + ggml_set_name(inp_raw, "inp_raw"); + ggml_set_input(inp_raw); + + struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings, inp_raw, patch_size, patch_size, 0, 0, 1, 1); + + inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); + inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + // auto * ne = inp->ne; printf("%d %d %d %d\n", ne[0], ne[1], ne[2], ne[3]); + + struct ggml_tensor * embd = inp; + if (model.class_embedding) { + embd = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); + ggml_set_name(embd, "inp_embd"); + ggml_set_input(embd); + + embd = ggml_acc(ctx0, embd, model.class_embedding, + embd->nb[1], embd->nb[2], embd->nb[3], 0); + embd = ggml_acc(ctx0, embd, inp, + embd->nb[1], embd->nb[2], embd->nb[3], model.class_embedding->nb[1]); + } + + struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions); + ggml_set_name(positions, "inp_pos"); + ggml_set_input(positions); + + embd = ggml_add(ctx0, + embd, + ggml_get_rows(ctx0, model.position_embeddings, positions)); + + return embd; + } + + struct ggml_tensor * build_pre_norm(struct ggml_tensor * cur) { + if (model.pre_norm_w) { + cur = ggml_norm(ctx0, cur, eps); + ggml_set_name(cur, "pre_ln"); + + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.pre_norm_w), model.pre_norm_b); + } + return cur; + } + + struct ggml_tensor * build_post_norm(struct ggml_tensor * cur) { + if (model.post_norm_w) { + cur = ggml_norm(ctx0, cur, eps); + ggml_set_name(cur, "post_ln"); + + cur = 
ggml_add(ctx0, ggml_mul(ctx0, cur, model.post_norm_w), model.post_norm_b); + } + return cur; + } + + struct ggml_tensor * build_layer(struct ggml_tensor * inpL, int il) { + struct ggml_tensor * cur = inpL; + + // layernorm1 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_in_w), + model.layers[il].norm_in_b); + } + + // self-attention + { + struct ggml_tensor * Q = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].q_w, cur), + model.layers[il].q_b); + + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); + Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * K = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].k_w, cur), + model.layers[il].k_b); + + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); + + struct ggml_tensor * V = ggml_add(ctx0, + ggml_mul_mat(ctx0, model.layers[il].v_w, cur), + model.layers[il].v_b); + + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); + } + + // attention output + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].output_w, cur), model.layers[il].output_b); + + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.layers[il].norm_out_w), + model.layers[il].norm_out_b); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_up_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_up_b); + + if (use_gelu) { + cur = ggml_gelu_inplace(ctx0, cur); + } else { + cur = ggml_gelu_quick_inplace(ctx0, cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down_w, cur); + cur = ggml_add(ctx0, cur, model.layers[il].ffn_down_b); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + return cur; + } + + struct ggml_tensor * build_vit() { + struct ggml_tensor * cur = build_inp(); + cur = build_pre_norm(cur); + for (int il = 0; il < n_layers; il++) { + cur = build_layer(cur, il); + } + cur = build_post_norm(cur); + return cur; + } + + // graph for each vision arch + + struct ggml_cgraph * build_llava() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false); + struct ggml_tensor * cur = build_vit(); + + // llava projector + { + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1]); + + struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); + ggml_set_name(patches, "inp_patches"); + ggml_set_input(patches); + + // shape [1, 576, 1024] + // ne is whcn, ne = [1024, 576, 1, 1] + cur = ggml_get_rows(ctx0, cur, patches); + + if (proj_type == VISION_PROJECTOR_TYPE_MLP) { + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + 
cur = ggml_add(ctx0, cur, model.mm_1_b); + + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = ggml_add(ctx0, cur, model.mm_2_b); + + } else if (proj_type == VISION_PROJECTOR_TYPE_LDPV2) { + int n_patch = 24; + struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, cur); + mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); + mlp_0 = ggml_gelu(ctx0, mlp_0); + struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); + mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); + // mlp_2 ne = [2048, 576, 1, 1] + // // AVG Pool Layer 2*2, strides = 2 + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); + // mlp_2 ne = [576, 2048, 1, 1] + mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); + // mlp_2 ne [24, 24, 2048, 1] + mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); + // weight ne = [3, 3, 2048, 1] + struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); + peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); + mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); + peg_0 = ggml_add(ctx0, peg_0, mlp_2); + peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); + cur = ggml_cont(ctx0, peg_0); + + } else { + GGML_ASSERT(false && "unsupported proj type"); + } + } + + ggml_set_name(cur, "output"); + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_minicpmv() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false); + struct ggml_tensor * cur = build_vit(); + + // minicpmv resampler projector + { + int hidden_size = llama_vision_n_mmproj_embd(*ctx.model); + struct ggml_tensor * q = model.mm_model_query; + // layernorm + { + q = ggml_norm(ctx0, q, eps); + q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + } + + struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, cur); + // layernorm + { + v = ggml_norm(ctx0, v, eps); + v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); + } + + // position + struct ggml_tensor * k = ggml_add(ctx0, v, model.mm_model_pos_embed_k); + + // attention + { + const int d_head = 128; + int n_head = hidden_size/d_head; + int num_query = -1; + if (model.hparams.proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_5) { + num_query = 96; + } else if (model.hparams.proj_type == VISION_PROJECTOR_TYPE_MINICPMV_2_6) { + num_query = 64; + } + + struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); + Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); + struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); + struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); + // permute + Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); + Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // TODO: do this when converting the model + Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); + K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); + K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // TODO: do this when converting the model + K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * 
batch_size); + V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); + V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); // TODO: do this when converting the model + V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + KQ = ggml_soft_max_inplace(ctx0, KQ); + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); + KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); + KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); // TODO: do this when converting the model + KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); + } + // layernorm + { + cur = ggml_norm(ctx0, cur, eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.mm_model_ln_post_w), model.mm_model_ln_post_b); + } + cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur); + } + + // add and token embeddings + cur = ggml_concat(ctx0, model.mm_tok_embd_image, cur, 1); + cur = ggml_concat(ctx0, cur, model.mm_tok_embd_end_image, 1); + + ggml_set_name(cur, "output"); + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_idefics3() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, VISION_GRAPH_MAX_NODE, false); + struct ggml_tensor * cur = build_vit(); + + // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578 + { + const int scale_factor = model.hparams.scale_factor; + const int n_embd = cur->ne[0]; + const int seq = cur->ne[1]; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = std::sqrt(seq); + const int width = std::sqrt(seq); + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + height / scale_factor, + width / scale_factor, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + seq / (scale_factor * scale_factor), + bsz); + + cur = ggml_mul_mat(ctx0, model.projection, cur); + } + + ggml_set_name(cur, "output"); + ggml_build_forward_expand(gf, cur); + + return gf; + } +}; + +static int32_t llama_vision_encode_impl(llama_vision_context & ctx, const llama_vision_tokens & inp) { + int batch_size = inp.buf.size(); + auto & model = *ctx.model; + auto & hparams = ctx.model->hparams; + + if (hparams.arch == LLM_ARCH_VISION_LLAVA) { + GGML_ASSERT(batch_size == 1); // TODO: support multiple images + } + + img_size image_size = img_size((int)hparams.image_size, (int)hparams.image_size); + const int patch_size = hparams.patch_size; + const int num_patches = ((image_size.width / patch_size) * (image_size.height / patch_size)); + const int num_positions = num_patches + (model.class_embedding ? 
1 : 0);
+
+    LLAMA_LOG_DEBUG("%s: image_size = %d\n", __func__, hparams.image_size);
+    LLAMA_LOG_DEBUG("%s: num_positions = %d\n", __func__, num_positions);
+
+    // build the inference graph
+    llama_vision_graph_builder builder(ctx, inp);
+    ggml_cgraph * gf;
+    switch(hparams.arch) {
+        case LLM_ARCH_VISION_LLAVA:
+        case LLM_ARCH_VISION_MOBILEVLM:
+            gf = builder.build_llava();
+            break;
+        case LLM_ARCH_VISION_MINICPMV:
+            gf = builder.build_minicpmv();
+            break;
+        case LLM_ARCH_VISION_IDEFICS3:
+            gf = builder.build_idefics3();
+            break;
+        default:
+            GGML_ASSERT(false && "unsupported vision arch");
+    }
+
+    // alloc memory for graph
+    bool ok = ggml_backend_sched_alloc_graph(ctx.sched, gf);
+    if (!ok) {
+        LLAMA_LOG_ERROR("failed to alloc memory for graph\n");
+        return -1;
+    }
+
+    // set raw input
+    {
+        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
+        std::vector<float> inp_buf(ggml_nelements(inp_raw));
+
+        for (int i = 0; i < batch_size; i++) {
+            const int nx = inp.px * inp.n_px;
+            const int ny = inp.py * inp.n_py;
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                for (int k = 0; k < 3; k++) {
+                    for (int y = 0; y < ny; y++) {
+                        for (int x = 0; x < nx; x++) {
+                            inp_buf[(b * 3 * n) + k * n + y * nx + x] = inp.buf[b][3 * (y * nx + x) + k];
+                        }
+                    }
+                }
+            }
+        }
+        ggml_backend_tensor_set(inp_raw, inp_buf.data(), 0, ggml_nbytes(inp_raw));
+    }
+
+    if (model.class_embedding) {
+        struct ggml_tensor * inp_embd = ggml_graph_get_tensor(gf, "inp_embd");
+        ggml_set_zero(inp_embd);
+    }
+
+    if (hparams.arch == LLM_ARCH_VISION_MINICPMV) {
+        // inspired from siglip:
+        //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+        //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos");
+        std::vector<int32_t> buf(ggml_nelements(positions));
+        GGML_ASSERT(num_positions == (int)buf.size());
+
+        int bucket_coords_h[70];
+        int bucket_coords_w[70];
+        size_t h = inp.py;
+        size_t w = inp.py;
+        for (size_t i = 0; i < h; i++) {
+            bucket_coords_h[i] = std::floor(70.0*i/h);
+        }
+        for (size_t i = 0; i < w; i++) {
+            bucket_coords_w[i] = std::floor(70.0*i/w);
+        }
+        for (size_t i = 0, id = 0; i < h; i++){
+            for (size_t j = 0; j < w; j++){
+                buf[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+            }
+        }
+        ggml_backend_tensor_set(positions, buf.data(), 0, ggml_nbytes(positions));
+
+    } else {
+        struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "inp_pos");
+        std::vector<int32_t> pos_buf(ggml_nelements(positions));
+        GGML_ASSERT(num_positions == (int)pos_buf.size());
+        for (int i = 0; i < num_positions; i++) {
+            pos_buf[i] = i;
+        }
+        ggml_backend_tensor_set(positions, pos_buf.data(), 0, ggml_nbytes(positions));
+    }
+
+    struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "inp_patches");
+    if (patches) {
+        std::vector<int32_t> patches_buf(ggml_nelements(patches));
+        GGML_ASSERT(num_patches == (int)patches_buf.size());
+        for (int i = 0; i < num_patches; i++) {
+            patches_buf[i] = i + 1;
+        }
+        ggml_backend_tensor_set(patches, patches_buf.data(), 0, ggml_nbytes(patches));
+    }
+
+    // compute
+    LLAMA_LOG_DEBUG("%s: compute start\n", __func__);
+    int64_t t_start = ggml_time_ms();
+    ggml_backend_sched_graph_compute(ctx.sched, gf);
+
+    // the last node is the embedding tensor
+    struct ggml_tensor * output_node = ggml_graph_node(gf, -1);
+    //LLAMA_LOG_INFO("%s: output tensor shape = %lld %lld %lld %lld\n", __func__, output->ne[0],
output->ne[1], output->ne[2], output->ne[3]); + LLAMA_LOG_DEBUG("%s: compute time = %lld ms\n", __func__, ggml_time_ms() - t_start); + + // copy output node to context + if (ctx.ctx_ggml) { + ggml_free(ctx.ctx_ggml); + } + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ctx.ctx_ggml = ggml_init(params); + ctx.output = ggml_dup_tensor(ctx.ctx_ggml, output_node); + ggml_backend_alloc_ctx_tensors_from_buft(ctx.ctx_ggml, ctx.model->buft); + ggml_backend_tensor_copy(output_node, ctx.output); + + return 0; +} + +//////////////////////////////////////////////////////////////////////////////////////// +// public API + +struct llama_vision_bitmap * llama_vision_bitmap_init(uint32_t nx, uint32_t ny) { + llama_vision_bitmap * bmp = new llama_vision_bitmap; + bmp->nx = nx; + bmp->ny = ny; + bmp->data = (unsigned char *)malloc(3 * nx * ny); + return bmp; +} + +void llama_vision_bitmap_free(llama_vision_bitmap * bmp) { + free(bmp->data); + delete bmp; +} + +struct llama_vision_tokens * llama_vision_tokenize( + struct llama_context * ctx, + llama_vision_bitmap * bmp) { + llama_vision_context & vctx = ctx->vctx; + switch (vctx.model->hparams.arch) { + case LLM_ARCH_VISION_LLAVA: + case LLM_ARCH_VISION_MOBILEVLM: + case LLM_ARCH_VISION_IDEFICS3: + return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp)); + case LLM_ARCH_VISION_MINICPMV: + return new llama_vision_tokens(llama_vision_processor_llava(vctx).tokenize(*bmp)); + default: + GGML_ASSERT(false && "unsupported arch"); + } +} + +void llama_vision_tokens_free(llama_vision_tokens * p) { + delete p; +} + +int32_t llama_vision_encode(struct llama_context * ctx, llama_vision_tokens * p) { + if (p->buf.empty()) { + LLAMA_LOG_ERROR("%s: nothing to encode\n", __func__); + return -1; + } + + llama_vision_context & vctx = ctx->vctx; + auto & hparams = vctx.model->hparams; + switch (hparams.mm_patch_merge_type) { + case MM_PATCH_MERGE_FLAT: + { + // flat / default llava-1.5 type embedding + int32_t encoded = llama_vision_encode_impl(vctx, *p); + if (encoded != 0) { + LLAMA_LOG_ERROR("Unable to encode image\n"); + return encoded; + } + } break; + case MM_PATCH_MERGE_SPATIAL_UNPAD: + { + // TODO: support llava-1.6 + (void)0; + } break; + default: + GGML_ASSERT(false && "unsupported mm_patch_merge_type"); + } + + return 0; +} + +struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx) { + return ctx->vctx.output; +} + +//////////////////////////////////////////////////////////////////////////////////////// +// for debugging +#ifndef NDEBUG + +static int bmp_export(const struct llama_image_u8 &img, const std::string &location) { + const uint32_t width = img.nx; + const uint32_t height = img.ny; + // swap red and blue channel + std::vector buffer(width*height*3); + for (uint32_t y = 0; y < height; y++) { + for (uint32_t x = 0; x < width; x++) { + size_t base = x*3 + y*3*width; + buffer[base+2] = img.buf[base]; + buffer[base+1] = img.buf[base+1]; + buffer[base] = img.buf[base+2]; + } + } + const bool hasAlphaChannel = false; + + std::ofstream fout(location, std::ios::out | std::ios::binary); + + if (fout.fail()) { + return 0; + } + + //Padding + const uint8_t padding = hasAlphaChannel ? 0 : (4 - (width * 3) % 4) % 4; + + //Bitmap file header. 
+    const char signature[2] = { 'B', 'M' };
+    const uint32_t fileSize = buffer.size() * sizeof(uint8_t) + padding * (height - 1) + 14 + 124;
+    const uint32_t offset = 14 + 124;
+
+    //Bitmap information header file
+    const uint32_t DIBSize = 124;
+    const int32_t bitmapWidth = width;
+    const int32_t bitmapHeight = height;
+    const uint16_t numPlanes = 1;
+    const uint16_t bitsPerPixel = (hasAlphaChannel) ? 32 : 24;
+    const uint32_t compressionMethod = (hasAlphaChannel) ? 3 : 0; //BI_RGB = 0, BI_BITFIELDS = 3
+    const uint32_t bitmapSize = buffer.size() * sizeof(uint8_t);
+    const int32_t horizontalResolution = 2834;
+    const int32_t verticalResolution = 2834;
+    const uint32_t numColors = 0;
+    const uint32_t impColorCount = 0;
+    const uint32_t redBitmask = (hasAlphaChannel) ? 0x0000FF00 : 0; //ARGB32 pixel format
+    const uint32_t greenBitmask = (hasAlphaChannel) ? 0x00FF0000 : 0;
+    const uint32_t blueBitmask = (hasAlphaChannel) ? 0xFF000000 : 0;
+    const uint32_t alphaBitmask = (hasAlphaChannel) ? 0x000000FF : 0;
+
+    //Writing the file header and information header to the file
+    std::vector<uint8_t> header(offset, 0);
+    header[0] = signature[0];
+    header[1] = signature[1];
+
+#define BMP_HEADERS(i, variableName) header[i] = variableName; header[i+1] = variableName >> 8; header[i+2] = variableName >> 16; header[i+3] = variableName >> 24;
+
+    BMP_HEADERS(2, fileSize);
+    BMP_HEADERS(6, 0);
+    BMP_HEADERS(10, offset);
+    BMP_HEADERS(14, DIBSize);
+    BMP_HEADERS(18, bitmapWidth);
+    BMP_HEADERS(22, bitmapHeight);
+
+    header[26] = (uint8_t)numPlanes;
+    header[27] = (uint8_t)(numPlanes >> 8);
+    header[28] = (uint8_t)bitsPerPixel;
+    header[29] = (uint8_t)(bitsPerPixel >> 8);
+
+    BMP_HEADERS(30, compressionMethod);
+    BMP_HEADERS(34, (unsigned char)bitmapSize);
+    BMP_HEADERS(38, (unsigned char)horizontalResolution);
+    BMP_HEADERS(42, (unsigned char)verticalResolution);
+    BMP_HEADERS(46, (unsigned char)numColors);
+    BMP_HEADERS(50, (unsigned char)impColorCount);
+    BMP_HEADERS(54, (unsigned char)redBitmask);
+    BMP_HEADERS(58, (unsigned char)greenBitmask);
+    BMP_HEADERS(62, (unsigned char)blueBitmask);
+    BMP_HEADERS(66, alphaBitmask);
+
+#undef BMP_HEADERS
+
+    fout.write((char *)header.data(), sizeof(uint8_t) * header.size());
+
+    //Writing the pixel array
+    const uint32_t bWidth = bitsPerPixel / 8 * width;
+
+    for (int i = height - 1; i >= 0; i--) {
+        std::vector<uint8_t> row(buffer.begin() + i * bWidth, buffer.begin() + i * bWidth + bWidth);
+        fout.write((char *)row.data(), row.size() * sizeof(uint8_t));
+        fout.seekp(padding * sizeof(uint8_t), std::ios::cur);
+    }
+
+    fout.close();
+    return 1;
+}
+
+#endif
+
diff --git a/src/llama-vision.h b/src/llama-vision.h
new file mode 100644
index 0000000000000..953ec57953079
--- /dev/null
+++ b/src/llama-vision.h
@@ -0,0 +1,192 @@
+#pragma once
+
+#include "ggml.h"
+#include "llama.h"
+#include "llama-arch.h"
+
+#include <vector>
+#include <array>
+
+#define VISION_GRAPH_MAX_NODE 2048
+
+enum vision_projector_type {
+    VISION_PROJECTOR_TYPE_UNKNOWN,
+    VISION_PROJECTOR_TYPE_MLP,
+    VISION_PROJECTOR_TYPE_LDPV2,
+    VISION_PROJECTOR_TYPE_MINICPMV_2_5,
+    VISION_PROJECTOR_TYPE_MINICPMV_2_6,
+};
+
+enum mm_patch_merge {
+    MM_PATCH_MERGE_UNKNOWN,
+    MM_PATCH_MERGE_FLAT,
+    MM_PATCH_MERGE_SPATIAL_UNPAD,
+};
+
+struct llama_vision_model {
+    struct vision_hparams {
+        llm_arch arch = LLM_ARCH_UNKNOWN;
+
+        uint32_t image_size;
+        uint32_t patch_size;
+        uint32_t hidden_size;
+        uint32_t n_intermediate;
+        uint32_t projection_dim;
+        uint32_t n_head;
+        uint32_t n_layer;
+        uint32_t max_pos_embd;
+        int32_t select_layer = 0;
+        bool use_gelu = false;
+
+        float eps;
+
+        vision_projector_type proj_type = VISION_PROJECTOR_TYPE_UNKNOWN;
+        mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_UNKNOWN;
+
+        std::array<float, 3> image_mean;
+        std::array<float, 3> image_std;
+
+        std::array<int32_t, 32> image_grid_pinpoints; // TODO: should this be array of (x, y) pairs?
+        int32_t image_crop_resolution;
+
+        // idefics3
+        int scale_factor = 0;
+    };
+    struct vision_hparams hparams;
+    ggml_backend_buffer_type_t buft;
+
+    // embeddings
+    struct ggml_tensor * class_embedding = nullptr;
+    struct ggml_tensor * patch_embeddings = nullptr;
+    struct ggml_tensor * patch_bias = nullptr;
+    struct ggml_tensor * position_embeddings = nullptr;
+
+    struct ggml_tensor * pre_norm_w = nullptr;
+    struct ggml_tensor * pre_norm_b = nullptr;
+
+    struct vision_layer {
+        // attention
+        struct ggml_tensor * k_w = nullptr;
+        struct ggml_tensor * k_b = nullptr;
+        struct ggml_tensor * q_w = nullptr;
+        struct ggml_tensor * q_b = nullptr;
+        struct ggml_tensor * v_w = nullptr;
+        struct ggml_tensor * v_b = nullptr;
+
+        struct ggml_tensor * output_w = nullptr;
+        struct ggml_tensor * output_b = nullptr;
+
+        // layernorm 1
+        struct ggml_tensor * norm_in_w = nullptr;
+        struct ggml_tensor * norm_in_b = nullptr;
+
+        // ff
+        struct ggml_tensor * ffn_up_w = nullptr;
+        struct ggml_tensor * ffn_up_b = nullptr;
+
+        struct ggml_tensor * ffn_down_w = nullptr;
+        struct ggml_tensor * ffn_down_b = nullptr;
+
+        // layernorm 2
+        struct ggml_tensor * norm_out_w = nullptr;
+        struct ggml_tensor * norm_out_b = nullptr;
+    };
+    std::vector<vision_layer> layers;
+
+    struct ggml_tensor * post_norm_w = nullptr;
+    struct ggml_tensor * post_norm_b = nullptr;
+
+    struct ggml_tensor * projection = nullptr;
+
+    // LLaVA projection
+    struct ggml_tensor * mm_1_w = nullptr;
+    struct ggml_tensor * mm_1_b = nullptr;
+    struct ggml_tensor * mm_2_w = nullptr;
+    struct ggml_tensor * mm_2_b = nullptr;
+
+    // MobileVLM_V2 projection
+    struct ggml_tensor * mm_model_mlp_0_w = nullptr;
+    struct ggml_tensor * mm_model_mlp_0_b = nullptr;
+    struct ggml_tensor * mm_model_mlp_2_w = nullptr;
+    struct ggml_tensor * mm_model_mlp_2_b = nullptr;
+    struct ggml_tensor * mm_model_peg_0_w = nullptr;
+    struct ggml_tensor * mm_model_peg_0_b = nullptr;
+
+    // MINICPMV projection
+    struct ggml_tensor * mm_model_pos_embed_k = nullptr;
+    struct ggml_tensor * mm_model_query = nullptr;
+    struct ggml_tensor * mm_model_proj = nullptr;
+    struct ggml_tensor * mm_model_kv_proj = nullptr;
+    struct ggml_tensor * mm_model_attn_q_w = nullptr;
+    struct ggml_tensor * mm_model_attn_q_b = nullptr;
+    struct ggml_tensor * mm_model_attn_k_w = nullptr;
+    struct ggml_tensor * mm_model_attn_k_b = nullptr;
+    struct ggml_tensor * mm_model_attn_v_w = nullptr;
+    struct ggml_tensor * mm_model_attn_v_b = nullptr;
+    struct ggml_tensor * mm_model_attn_o_w = nullptr;
+    struct ggml_tensor * mm_model_attn_o_b = nullptr;
+    struct ggml_tensor * mm_model_ln_q_w = nullptr;
+    struct ggml_tensor * mm_model_ln_q_b = nullptr;
+    struct ggml_tensor * mm_model_ln_kv_w = nullptr;
+    struct ggml_tensor * mm_model_ln_kv_b = nullptr;
+    struct ggml_tensor * mm_model_ln_post_w = nullptr;
+    struct ggml_tensor * mm_model_ln_post_b = nullptr;
+
+    // special tokens
+    struct ggml_tensor * mm_tok_embd_image = nullptr;
+    struct ggml_tensor * mm_tok_embd_end_image = nullptr;
+    struct ggml_tensor * mm_tok_embd_slice = nullptr;
+    struct ggml_tensor * mm_tok_embd_end_slice = nullptr;
+};
+
+struct llama_vision_context {
+    // memory buffers used to evaluate the model
+    std::vector<uint8_t> buf_compute_meta;
+    ggml_backend_sched_t sched =
nullptr; + struct ggml_context * ctx_ggml = nullptr; + + const llama_vision_model * model; + + // temporary output data, to be picked up by llama_decode() + struct ggml_tensor * output; +}; + +// for now, this only contains: +// - the instruction for ggml_conv_2d to break the image into patches +// - the pre-processed image data in f32 +struct llama_vision_tokens { + uint32_t px; // size of patch + uint32_t py; // size of patch + size_t n_px; // number of patches in x direction + size_t n_py; // number of patches in y direction + // RGB float32 image (NHWC) + // Memory layout: RGBRGBRGB... + std::vector> buf; // preprocessed image data +}; + +inline mm_patch_merge mm_patch_merge_from_name(std::string & name) { + if (name == "flat") { + return MM_PATCH_MERGE_FLAT; + } else if (name == "spatial_unpad") { + return MM_PATCH_MERGE_SPATIAL_UNPAD; + } + return MM_PATCH_MERGE_UNKNOWN; +} + +inline vision_projector_type vision_projector_type_from_name(std::string & name) { + if (name == "mlp") { + return VISION_PROJECTOR_TYPE_MLP; + } else if (name == "ldpv2") { + return VISION_PROJECTOR_TYPE_LDPV2; + } else if (name == "minicpmv-2.5") { + return VISION_PROJECTOR_TYPE_MINICPMV_2_5; + } else if (name == "minicpmv-2.6") { + return VISION_PROJECTOR_TYPE_MINICPMV_2_6; + } + return VISION_PROJECTOR_TYPE_UNKNOWN; +} + +// only for sanity check: must be equal to n_embd of language model +uint32_t llama_vision_n_mmproj_embd(const llama_vision_model & vmodel); + +struct ggml_tensor * llama_vision_get_output_tensor(llama_context * ctx); diff --git a/src/llama.cpp b/src/llama.cpp index 5760017e0d9cb..1a71286ec1d61 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -138,6 +138,9 @@ static struct ggml_tensor * llm_build_inp_embd( ), scale); inpL = ggml_add(ctx, inpL, inpL_delta); } + } else if (ubatch.embd_tensor) { + inpL = ubatch.embd_tensor; + ggml_set_input(ubatch.embd_tensor); } else { lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens); inpL = lctx.inp_embd; @@ -154,6 +157,22 @@ static struct ggml_tensor * llm_build_inp_embd( return inpL; } +static struct ggml_tensor * llm_build_cross_embd( + struct ggml_context * ctx, + const llama_ubatch & ubatch +) { + struct ggml_tensor * cross_embd; + if (ubatch.cross_embd) { + cross_embd = ubatch.cross_embd; + } else { + printf("ubatch does not have cross_embd tensor, " + "building graph with placeholder instead\n"); + cross_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 6400); + } + ggml_set_input(cross_embd); + return cross_embd; +} + static void llm_build_kv_store( struct ggml_context * ctx, const llama_hparams & hparams, @@ -697,6 +716,67 @@ static struct ggml_tensor * llm_build_kv( return cur; } +// Function that computes the cross attention score +// and stores / retrieves K and V values in the +// cross attention KV cache +static struct ggml_tensor * llm_build_cross_kv( + struct ggml_context * ctx, + struct llama_context & lctx, + struct ggml_tensor * qcur, + struct ggml_tensor * kcur, + struct ggml_tensor * vcur, + struct ggml_cgraph * graph, + int64_t il +) { + llama_cross_kv_cache & kv = lctx.kv_cross; + + // Q has dimensions K, H, L, B + // K = hidden dimension per head + // H = number of heads + // L = number of tokens + // B = batch size + const int64_t num_heads = lctx.model.hparams.n_head(); + const float cross_attn_scale = 1.0f / sqrtf(float(qcur->ne[0] / num_heads)); + // Only add the computation of K and V if + // the cache doesn't already have the data + if (!kv.cache_filled) { + // Add computation of K, V to the 
graph + ggml_build_forward_expand(graph, kcur); + ggml_build_forward_expand(graph, vcur); + // Copy K and V to the cross KV cache + ggml_build_forward_expand(graph, ggml_cpy(ctx, kcur, kv.k_l[il])); + ggml_build_forward_expand(graph, ggml_cpy(ctx, vcur, kv.v_l[il])); + if (il == 0) { + printf("Copying KV values to the cross KV cache\n"); + } + } + struct ggml_tensor * k = kv.k_l[il]; + struct ggml_tensor * v = kv.v_l[il]; + // Compute cross attention score + struct ggml_tensor * q = ggml_reshape_4d(ctx, qcur, qcur->ne[0] / num_heads, + num_heads, qcur->ne[1], qcur->ne[2]); + k = ggml_reshape_3d(ctx, k, 1024 / num_heads, num_heads, 6400); + v = ggml_reshape_3d(ctx, v, 1024 / num_heads, num_heads, 6400); + // K x L x H x B + q = ggml_permute(ctx, q, 0, 2, 1, 3); + k = ggml_permute(ctx, k, 0, 2, 1, 3); + // L x K x H x B + v = ggml_permute(ctx, v, 1, 2, 0, 3); + q = ggml_cont(ctx, q); + k = ggml_cont(ctx, k); + v = ggml_cont(ctx, v); + + q = ggml_scale(ctx, q, cross_attn_scale); + struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); + kq = ggml_soft_max(ctx, kq); + struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq); + kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); + kqv = ggml_cont(ctx, kqv); + kqv = ggml_reshape_3d(ctx, kqv, kqv->ne[0] * kqv->ne[1], + kqv->ne[2], kqv->ne[3]); + return kqv; +} + static struct ggml_tensor * llm_build_copy_mask_state( struct ggml_context * ctx, struct ggml_cgraph * graph, @@ -8108,6 +8188,152 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_cogvlm() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + + // Hidden dimension per head and number of heads + const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t num_heads = hparams.n_head(); + + // Multiplied directly to Q + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // Get the cross vision encoder embedded picture + struct ggml_tensor * cross_embd; + cross_embd = llm_build_cross_embd(ctx0, ubatch); + + // Assuming text tokens are in ubatch.token, and image tokens are in ubatch.embd_tensor + bool batch_is_text; + if (ubatch.token) { + batch_is_text = true; + } else { + batch_is_text = false; + } + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ mask is given values in llama_set_inputs + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + struct ggml_tensor * inpSA = inpL; + for (int il = 0; il < n_layer; ++il) { + const llama_layer &cur_layer = model.layers[il]; + + struct ggml_tensor * self_attn_in = ggml_rms_norm(ctx0, inpSA, hparams.f_norm_rms_eps); + self_attn_in = ggml_mul(ctx0, self_attn_in, cur_layer.attn_norm); + + struct ggml_tensor * self_attn_qkv; + struct ggml_tensor * self_attn_dense; + // CogVLM uses different weights depending on + // whether tokens are text or image + if (batch_is_text) { + self_attn_qkv = cur_layer.wqkv_txt; + self_attn_dense = cur_layer.wdense_txt; + } else { + self_attn_qkv = cur_layer.wqkv_img; + self_attn_dense = cur_layer.wdense_img; + } + + // Calculate the Q, K, V values for self attention + struct ggml_tensor * qkv = ggml_mul_mat(ctx0, self_attn_qkv, self_attn_in); + struct ggml_tensor * qt = ggml_view_3d(ctx0, qkv, qkv->ne[0] / 3, qkv->ne[1], + qkv->ne[2], qkv->nb[1], qkv->nb[2], 0); + struct ggml_tensor * kt = ggml_view_3d(ctx0, qkv, qkv->ne[0] / 3, qkv->ne[1], + qkv->ne[2], qkv->nb[1], qkv->nb[2], 
qkv->ne[0] / 3 * qkv->nb[0]); + struct ggml_tensor * vt = ggml_view_3d(ctx0, qkv, qkv->ne[0] / 3, qkv->ne[1], + qkv->ne[2], qkv->nb[1], qkv->nb[2], 2 * qkv->ne[0] / 3 * qkv->nb[0]); + qt = ggml_cont(ctx0, qt); + kt = ggml_cont(ctx0, kt); + vt = ggml_cont(ctx0, vt); + + // Separate into heads + // K x H x L x B + qt = ggml_view_4d(ctx0, qt, qt->ne[0] / num_heads, num_heads, qt->ne[1], qt->ne[2], + qt->ne[0] / num_heads * qt->nb[0], qt->nb[1], qt->nb[2], 0); + kt = ggml_view_4d(ctx0, kt, kt->ne[0] / num_heads, num_heads, kt->ne[1], kt->ne[2], + kt->ne[0] / num_heads * kt->nb[0], kt->nb[1], kt->nb[2], 0); + qt = ggml_cont(ctx0, qt); + kt = ggml_cont(ctx0, kt); + + // Mode=2 for NEOX + qt = ggml_rope(ctx0, qt, inp_pos, n_embd_head, 2); + kt = ggml_rope(ctx0, kt, inp_pos, n_embd_head, 2); + + // The logic for the variables given to llm_build_kv is not ready yet + cur = llm_build_kv(ctx0, lctx, lctx.kv_self, gf, nullptr, nullptr, + kt, vt, qt, KQ_mask, this->n_tokens, kv_head, n_kv, kq_scale, cb, il); + cur = ggml_mul_mat(ctx0, self_attn_dense, cur); + + // Enter cross attention + inpSA = ggml_add(ctx0, inpSA, cur); + + struct ggml_tensor * cross_attn_in = ggml_rms_norm(ctx0, inpSA, hparams.f_norm_rms_eps); + cross_attn_in = ggml_mul(ctx0, cross_attn_in, cur_layer.attn_norm_2); + + qt = ggml_mul_mat(ctx0, cur_layer.wq_cross, cross_attn_in); + struct ggml_tensor * kvt = ggml_mul_mat(ctx0, cur_layer.wkv_cross, cross_embd); + kt = ggml_view_3d(ctx0, kvt, kvt->ne[0] / 2, kvt->ne[1], kvt->ne[2], + kvt->nb[1], kvt->nb[2], 0); + vt = ggml_view_3d(ctx0, kvt, kvt->ne[0] / 2, kvt->ne[1], kvt->ne[2], + kvt->nb[1], kvt->nb[2], kvt->nb[0] * kvt->ne[0] / 2); + kt = ggml_cont(ctx0, kt); + vt = ggml_cont(ctx0, vt); + + // Calculate cross attention score + struct ggml_tensor * cross_attn = llm_build_cross_kv(ctx0, lctx, + qt, kt, vt, gf, il); + + cross_attn = ggml_mul_mat(ctx0, cur_layer.wdense_cross, cross_attn); + + inpSA = ggml_add(ctx0, inpSA, cross_attn); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * mlp_input = ggml_rms_norm(ctx0, inpSA, hparams.f_norm_rms_eps); + mlp_input = ggml_mul(ctx0, mlp_input, cur_layer.ffn_norm); + + struct ggml_tensor * up; + struct ggml_tensor * gate; + struct ggml_tensor * down; + if (batch_is_text) { + up = cur_layer.ffn_up_txt; + gate = cur_layer.ffn_gate_txt; + down = cur_layer.ffn_down_txt; + } else { + up = cur_layer.ffn_up_img; + gate = cur_layer.ffn_gate_img; + down = cur_layer.ffn_down_img; + } + + struct ggml_tensor * gate_result = ggml_mul_mat(ctx0, gate, mlp_input); + struct ggml_tensor * up_result = ggml_mul_mat(ctx0, up, mlp_input); + gate_result = ggml_silu(ctx0, gate_result); + struct ggml_tensor * mlp_inter = ggml_mul(ctx0, gate_result, up_result); + cur = ggml_mul_mat(ctx0, down, mlp_inter); + + inpSA = ggml_add(ctx0, inpSA, cur); + } + + cur = ggml_rms_norm(ctx0, inpSA, hparams.f_norm_rms_eps); + cur = ggml_mul(ctx0, cur, model.output_norm); + + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -8400,6 +8626,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_wavtokenizer_dec(); } break; + case LLM_ARCH_COGVLM: + { + result = llm.build_cogvlm(); + } break; default: GGML_ABORT("fatal 
error"); } @@ -8569,6 +8799,30 @@ static int llama_prepare_ubatch( return 0; } +void save_tensor_desperate(struct ggml_tensor * input_tensor, std::string filename) { + std::string prefix = "/home/tianyue/myworkspace/"; + filename = prefix + filename; + gguf_context * gguf_ctx = gguf_init_empty(); + gguf_set_val_str(gguf_ctx, "model.architecture", "cogagent"); + gguf_set_val_u32(gguf_ctx, "general.file_type", GGML_TYPE_F32); + + struct ggml_init_params params = { + ggml_nbytes(input_tensor) + 1000000, // Memory to allocate + nullptr, // Buffer location + false, // Allocate tensor data + }; + struct ggml_context * tensor_ctx = ggml_init(params); + struct ggml_tensor * tensor_with_data = ggml_dup(tensor_ctx, input_tensor); + ggml_backend_tensor_get(input_tensor, tensor_with_data->data, + 0, ggml_nbytes(input_tensor)); + + ggml_set_name(tensor_with_data, "output_tensor"); + gguf_add_tensor(gguf_ctx, tensor_with_data); + gguf_write_to_file(gguf_ctx, filename.c_str(), false); + gguf_free(gguf_ctx); + ggml_free(tensor_ctx); +} + // decode a batch of tokens by evaluating the transformer // in case of unsuccessful decoding (error or warning), // the kv_cache state will be returned to its original state @@ -8684,6 +8938,10 @@ static int llama_decode_impl( } } + if (llama_model_has_cross_kv(&lctx.model)) { + lctx.kv_cross.cache_filled = true; + } + // update the kv ring buffer { kv_self.head += ubatch.n_tokens; @@ -9278,7 +9536,7 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) { uint32_t n_seqs = 1; // TODO: worst-case number of sequences uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); // initialize scheduler with the worst-case graph @@ -9746,6 +10004,15 @@ struct llama_context * llama_init_from_model( return nullptr; } + if (llama_model_has_cross_kv(model)) { + // TODO: Add parameter for cross kv cache size + if (!llama_cross_kv_cache_init(ctx->kv_cross, ctx->model, type_k, type_v, 1024 * 6400, cparams.offload_kqv)) { + LLAMA_LOG_ERROR("%s: llama_cross_kv_cache_init() failed\n", __func__); + llama_free(ctx); + return nullptr; + } + } + { size_t memory_size_k = 0; size_t memory_size_v = 0; @@ -9841,7 +10108,7 @@ struct llama_context * llama_init_from_model( uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true); // reserve pp graph first so that buffers are only allocated once @@ -9850,7 +10117,7 @@ struct llama_context * llama_init_from_model( int n_nodes_pp = ggml_graph_n_nodes(gf_pp); // reserve with tg graph to get the number of splits and nodes - llama_ubatch ubatch_tg = { true, 1, 1, 
n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; ggml_cgraph * gf_tg = llama_build_graph(*ctx, ubatch_tg, true); ggml_backend_sched_reserve(ctx->sched.get(), gf_tg); int n_splits_tg = ggml_backend_sched_get_n_splits(ctx->sched.get()); @@ -9888,6 +10155,13 @@ struct llama_context * llama_init_from_model( } } + if (model->has_vision) { + ctx->vctx.model = &model->vit; + ctx->vctx.sched = ctx->sched.get(); + const size_t max_nodes = VISION_GRAPH_MAX_NODE; // TODO: make it dynamic + ctx->vctx.buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + } + return ctx; }
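
Note for reviewers: a minimal sketch of how the new public vision API is meant to be called, using only the functions declared in llama-vision.h above. The pixel loading and error handling are illustrative, `encode_image` is a hypothetical helper (not part of this patch), and wiring the returned tensor into the language-model batch (ubatch.embd_tensor) is left to the caller:

    // sketch only: encode one RGB image with the new vision API
    // assumes <cstring> for memcpy and a llama_context whose model has a vision tower
    static struct ggml_tensor * encode_image(struct llama_context * lctx,
                                             const unsigned char * rgb, // 3 * nx * ny bytes, RGBRGBRGB...
                                             uint32_t nx, uint32_t ny) {
        struct llama_vision_bitmap * bmp = llama_vision_bitmap_init(nx, ny);
        memcpy(bmp->data, rgb, 3 * nx * ny);

        struct llama_vision_tokens * toks = llama_vision_tokenize(lctx, bmp);
        struct ggml_tensor * embd = nullptr;
        if (toks && llama_vision_encode(lctx, toks) == 0) {
            // embedding tensor produced by the vision encoder
            embd = llama_vision_get_output_tensor(lctx);
        }

        llama_vision_tokens_free(toks);
        llama_vision_bitmap_free(bmp);
        return embd;
    }

Since llama_vision_encode() copies the last graph node into ctx->vctx.output at the end of llama_vision_encode_impl(), the returned tensor should remain valid after the temporary bitmap and token objects are freed.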