diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt
index 319effd199aa4..f275ce1ccd003 100644
--- a/examples/llava/CMakeLists.txt
+++ b/examples/llava/CMakeLists.txt
@@ -51,6 +51,13 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
+set(TARGET llama-gemma3-cli)
+add_executable(${TARGET} gemma3-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
 set(TARGET llama-llava-clip-quantize-cli)
 add_executable(${TARGET} clip-quantize-cli.cpp)
 set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli)
diff --git a/examples/llava/README-gemma3.md b/examples/llava/README-gemma3.md
new file mode 100644
index 0000000000000..20bf73fb5c043
--- /dev/null
+++ b/examples/llava/README-gemma3.md
@@ -0,0 +1,30 @@
+# Gemma 3 vision
+
+> [!IMPORTANT]
+>
+> This is very experimental and intended for demo purposes only.
+
+## How to get mmproj.gguf?
+
+```bash
+cd gemma-3-4b-it
+python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
+
+# output file is mmproj.gguf
+```
+
+## How to run it?
+
+What you need:
+- The text model in GGUF format, which can be converted using `convert_hf_to_gguf.py`
+- The mmproj file from the step above
+- An image file
+
+```bash
+# build
+cmake -B build
+cmake --build build --target llama-gemma3-cli
+
+# run it
+./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
+```
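For orientation before the C++ changes: the `llama-gemma3-cli` built above boils down to the following `clip.h` call sequence. This is a minimal sketch stitched together from functions that appear later in this diff, not a drop-in snippet (error handling and the llama.cpp text-model side are omitted; the fixed count of 256 embeddings matches the pooling done in `clip_image_build_graph_siglip` below):

```cpp
#include "clip.h"
#include <vector>

// Sketch: load the mmproj, encode one image, return 256 embeddings for llama_decode().
static std::vector<float> encode_image_gemma3(const char * mmproj_path, const char * image_path, int n_threads) {
    clip_ctx * ctx = clip_model_load(mmproj_path, /*verbosity =*/ 1);

    clip_image_u8 * img = clip_image_u8_init();
    clip_image_load_from_file(image_path, img);

    clip_image_f32_batch batch;
    clip_image_preprocess(ctx, img, &batch); // resize + normalize to the model input size

    std::vector<float> embd(256 * clip_n_mmproj_embd(ctx));
    clip_image_batch_encode(ctx, n_threads, &batch, embd.data());

    clip_image_f32_batch_free(&batch);
    clip_image_u8_free(img);
    clip_free(ctx);
    return embd;
}
```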
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 7f892beb6edb1..a1f050e39a094 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -136,6 +136,8 @@ static std::string format(const char * fmt, ...) {
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE   "model.image_newline"
+#define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
+#define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
 
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
 #define TN_MINICPMV_QUERY      "resampler.query"
@@ -162,6 +164,7 @@ enum projector_type {
     PROJECTOR_TYPE_RESAMPLER,
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -172,6 +175,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_RESAMPLER, "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
     { PROJECTOR_TYPE_MERGER,    "qwen2vl_merger"},
+    { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
 };
 
@@ -298,7 +302,7 @@ static projector_type clip_projector_type_from_string(const std::string & name)
             return kv.first;
         }
     }
-    return PROJECTOR_TYPE_UNKNOWN;
+    throw std::runtime_error(format("Unknown projector type: %s", name.c_str()));
 }
 
 #ifdef CLIP_DEBUG_FUNCTIONS
@@ -555,6 +559,10 @@ struct clip_vision_model {
     struct ggml_tensor * mm_model_ln_kv_b;
     struct ggml_tensor * mm_model_ln_post_w;
     struct ggml_tensor * mm_model_ln_post_b;
+
+    // gemma3
+    struct ggml_tensor * mm_input_proj_w;
+    struct ggml_tensor * mm_soft_emb_norm_w;
 };
 
 struct clip_ctx {
@@ -569,7 +577,7 @@ struct clip_ctx {
     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
 
-    int32_t max_feature_layer;
+    int32_t max_feature_layer; // unused in newer models like gemma3
     float image_mean[3];
     float image_std[3];
     bool use_gelu = false;
@@ -595,7 +603,7 @@ struct clip_ctx {
 
     ggml_backend_sched_ptr sched;
 
-    struct clip_image_size * load_image_size;
+    struct clip_image_size * load_image_size = nullptr;
 
     clip_ctx(clip_context_params & ctx_params) {
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
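One consequence of the `clip_projector_type_from_string` change a few hunks up: a GGUF with an unrecognized `clip.projector_type` now surfaces as an exception out of `clip_init()` / `clip_model_load()` instead of a silent `PROJECTOR_TYPE_UNKNOWN`. A hypothetical caller-side guard (not part of this diff):

```cpp
// Hypothetical guard; the new code throws std::runtime_error on unknown projector names.
try {
    clip_ctx * ctx = clip_model_load("mmproj.gguf", /*verbosity =*/ 1);
    // ... use ctx ...
    clip_free(ctx);
} catch (const std::runtime_error & e) {
    fprintf(stderr, "failed to load mmproj: %s\n", e.what()); // e.g. "Unknown projector type: ..."
}
```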
@@ -631,7 +639,159 @@ struct clip_ctx {
     }
 };
 
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch * imgs) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    const int image_size = hparams.image_size;
+    int image_size_width  = image_size;
+    int image_size_height = image_size;
+
+    const int patch_size  = hparams.patch_size;
+    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int hidden_size = hparams.hidden_size;
+    const int n_head      = hparams.n_head;
+    const int d_head      = hidden_size / n_head;
+    const int n_layer     = hparams.n_layer;
+    const float eps       = hparams.eps;
+
+    GGML_ASSERT(imgs->size == 1); // batch_size == 1
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx0 = ggml_init(params);
+    struct ggml_cgraph  * gf   = ggml_new_graph(ctx0);
+
+    // input raw
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    // patch embedding: conv2d with stride = patch_size, then flatten the patches
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+    inp = ggml_add(ctx0, inp, model.patch_bias);
+
+    // position embeddings
+    struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+
+        // layernorm1
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
+        }
+
+        // self-attention
+        {
+            struct ggml_tensor * Q =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+            struct ggml_tensor * K =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+            struct ggml_tensor * V =
+                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf((float)d_head));
+            KQ = ggml_soft_max_inplace(ctx0, KQ);
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+        }
+
+        // attention output
+        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+
+        // re-add the layer input, i.e. the residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // layernorm2
+        {
+            cur = ggml_norm(ctx0, cur, eps);
+            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        }
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b);
+
+        // siglip uses gelu
+        cur = ggml_gelu(ctx0, cur);
+
+        cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur);
+        cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b);
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // post-layernorm
+    if (ctx->has_post_norm) {
+        embeddings = ggml_norm(ctx0, embeddings, eps);
+        ggml_set_name(embeddings, "post_ln");
+
+        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
+    }
+
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        const int batch_size          = 1;
+        const int mm_tokens_per_image = 256; // default value for gemma3
+        const int tokens_per_side     = sqrt(mm_tokens_per_image);
+        const int patches_per_image   = sqrt(num_patches);
+        const int kernel_size         = patches_per_image / tokens_per_side;
+
+        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size);
+
+        // doing a pool2d to reduce the number of output tokens to 256
+        embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size);
+        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+
+        // apply norm before projection
+        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
+        embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
+
+        // apply projection
+        embeddings = ggml_mul_mat(ctx0,
+            ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+            embeddings);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    ggml_free(ctx0);
+
+    return gf;
+}
+
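The pooling block above is easier to follow with concrete numbers. A worked example, assuming the vision-tower values shipped in the Gemma 3 checkpoints' config.json (image_size = 896, patch_size = 14; these numbers come from the model config, not from this diff):

```cpp
// patches_per_image = 896 / 14  = 64   -> num_patches = 64 * 64 = 4096
// tokens_per_side   = sqrt(256) = 16
// kernel_size       = 64 / 16   = 4    -> 4x4 average pooling, stride 4
// Result: 4096 patch embeddings are pooled down to 16 * 16 = 256 image tokens,
// each of width hidden_size, then RMS-normed and projected to the LLM embedding width.
```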
+static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
     if (!ctx->has_vision_encoder) {
         LOG_ERR("This gguf file seems to have no vision encoder\n");
         return nullptr;
@@ -1177,7 +1337,8 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         } else {
             GGML_ABORT("fatal error");
         }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+    }
+    else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1199,6 +1360,15 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
     return gf;
 }
 
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch * imgs, struct clip_image_size * load_image_size, bool is_inf = false) {
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return clip_image_build_graph_siglip(ctx, imgs);
+    } else {
+        // TODO: we should have one build_* function per model
+        return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+    }
+}
+
 // read and create ggml_context containing the tensors and their data
 struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     return clip_init(fname, clip_context_params{
@@ -1358,8 +1528,12 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
     GGML_ASSERT(new_clip->has_vision_encoder);
     GGML_ASSERT(!new_clip->has_text_encoder);
 
-    idx = get_key_idx(ctx, KEY_USE_GELU);
-    new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+    try {
+        idx = get_key_idx(ctx, KEY_USE_GELU);
+        new_clip->use_gelu = gguf_get_val_bool(ctx, idx);
+    } catch (std::runtime_error & /*e*/) {
+        new_clip->use_gelu = false;
+    }
 
     try {
         idx = get_key_idx(ctx, KEY_USE_SILU);
@@ -1567,11 +1741,17 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
         }
 
         try {
-            vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+            vision_model.patch_embeddings_0 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD);
+        } catch(const std::exception& /*e*/) {
+            vision_model.patch_embeddings_0 = nullptr;
+        }
+
+        try {
             vision_model.position_embeddings = get_tensor(new_clip->ctx_data, format(TN_POS_EMBD, "v"));
         } catch(const std::exception& /*e*/) {
-            LOG_ERR("%s: failed to load vision model tensors\n", __func__);
+            vision_model.position_embeddings = nullptr;
         }
+
         try {
             vision_model.patch_embeddings_1 = get_tensor(new_clip->ctx_data, TN_PATCH_EMBD_1);
         } catch(const std::exception& /*e*/) {
@@ -1682,6 +1862,10 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
             vision_model.mm_1_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "weight"));
             vision_model.mm_1_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 2, "bias"));
         }
+        else if (new_clip->proj_type == PROJECTOR_TYPE_GEMMA3) {
+            vision_model.mm_input_proj_w    = get_tensor(new_clip->ctx_data, TN_MM_INP_PROJ);
+            vision_model.mm_soft_emb_norm_w = get_tensor(new_clip->ctx_data, TN_MM_SOFT_EMB_N);
+        }
         else {
             std::string proj_type = PROJECTOR_TYPE_NAMES[new_clip->proj_type];
             throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
@@ -2223,7 +2407,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli
         return true;
     }
 
-    if (ctx->has_glm_projector) {
+    if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         res_imgs->size = 1;
         res_imgs->data = new clip_image_f32[res_imgs->size];
         clip_image_u8 resized_image;
@@ -2748,6 +2932,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
             free(positions_data);
         }
+        else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+            // do nothing
+        }
         else {
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
@@ -2960,6 +3147,9 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
     if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
         return ctx->vision_model.mm_1_b->ne[0];
     }
+    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return ctx->vision_model.mm_input_proj_w->ne[0];
+    }
 
     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
     throw std::runtime_error(format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
new file mode 100644
index 0000000000000..a07864d4e59f6
--- /dev/null
+++ b/examples/llava/gemma3-cli.cpp
@@ -0,0 +1,341 @@
+#include "arg.h"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "clip.h"
+#include "stb_image.h"
+#include "llama.h"
+#include "ggml.h"
+#include "console.h"
+
+#include <vector>
+#include <climits>
+#include <cinttypes>
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include <signal.h>
+#include <unistd.h>
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#include <signal.h>
+#endif
+
+static bool g_is_generating = false;
+
+/**
+ * Please note that this is NOT production-ready code.
+ * It is a playground for trying Gemma 3 vision capabilities.
+ * For contributors: please keep this code simple and easy to understand.
+ */
+
+static void show_additional_info(int /*argc*/, char ** argv) {
+    LOG(
+        "Experimental CLI for using the Gemma 3 vision model\n\n"
+        "Usage: %s [options] -m <model> --mmproj <mmproj> --image <image> -p <prompt>\n\n"
+        "  -m and --mmproj are required\n"
+        "  --image and -p are optional; if NOT provided, the CLI will run in chat mode\n",
+        argv[0]
+    );
+}
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
+static void sigint_handler(int signo) {
+    if (signo == SIGINT) {
+        if (g_is_generating) {
+            g_is_generating = false;
+        } else {
+            console::cleanup();
+            LOG("\nInterrupted by user\n");
+            _exit(130);
+        }
+    }
+}
+#endif
+
+struct gemma3_context {
+    struct clip_ctx * ctx_clip = NULL;
+    common_init_result llama_init;
+
+    llama_model       * model;
+    llama_context     * lctx;
+    const llama_vocab * vocab;
+    llama_batch         batch;
+
+    int n_threads    = 1;
+    llama_pos n_past = 0;
+
+    gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
+        model = llama_init.model.get();
+        lctx = llama_init.context.get();
+        vocab = llama_model_get_vocab(model);
+        n_threads = params.cpuparams.n_threads;
+        batch = llama_batch_init(params.n_batch, 0, 1);
+        init_clip_model(params);
+    }
+
+    void init_clip_model(common_params & params) {
+        const char * clip_path = params.mmproj.c_str();
+        ctx_clip = clip_model_load(clip_path, params.verbosity > 1);
+    }
+
+    ~gemma3_context() {
+        clip_free(ctx_clip);
+    }
+};
+
+struct decode_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr; // nullptr terminator expected by llama_batch
+        batch = {
+            /*n_tokens       =*/ n_tokens,
+            /*tokens         =*/ nullptr,
+            /*embd           =*/ embd,
+            /*pos            =*/ pos.data(),
+            /*n_seq_id       =*/ n_seq_id.data(),
+            /*seq_id         =*/ seq_ids.data(),
+            /*logits         =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
+static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
+    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
+    common_batch_clear(ctx.batch);
+    for (llama_token & t : tokens) {
+        common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false);
+    }
+    if (logits_last) {
+        ctx.batch.logits[ctx.batch.n_tokens - 1] = true;
+    }
+    // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str());
+    if (llama_decode(ctx.lctx, ctx.batch)) {
+        LOG_ERR("Failed to decode text\n");
+        return 1;
+    }
+    return 0;
+}
+
+static int eval_image(gemma3_context & ctx, std::string & fname) {
+    std::vector<float> image_embd_v;
+    int n_embd = llama_model_n_embd(ctx.model);
+    int n_tokens = 256;
+    image_embd_v.resize(n_tokens * n_embd);
+
+    bool ok;
+    struct clip_image_u8 * img_u8 = clip_image_u8_init();
+    ok = clip_image_load_from_file(fname.c_str(), img_u8);
+    if (!ok) {
+        LOG_ERR("Unable to load image %s\n", fname.c_str());
+        clip_image_u8_free(img_u8);
+        return 2; // non-fatal error
+    }
+
+    clip_image_f32_batch batch_f32;
+    ok = clip_image_preprocess(ctx.ctx_clip, img_u8, &batch_f32);
+    if (!ok) {
+        LOG_ERR("Unable to preprocess image\n");
+        clip_image_f32_batch_free(&batch_f32);
+        clip_image_u8_free(img_u8);
+        return 1;
+    }
+
+    int64_t t0 = ggml_time_ms();
+    LOG("Encoding image %s\n", fname.c_str());
+    ok = clip_image_batch_encode(ctx.ctx_clip, ctx.n_threads, &batch_f32, image_embd_v.data());
+    if (!ok) {
+        LOG_ERR("Unable to encode image\n");
+        clip_image_f32_batch_free(&batch_f32);
+        clip_image_u8_free(img_u8);
+        return 1;
+    }
+    LOG("Image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+
+    clip_image_f32_batch_free(&batch_f32);
+    clip_image_u8_free(img_u8);
+
+    // decode image embeddings, delimited by the Gemma 3 special image tokens
+    int64_t t1 = ggml_time_ms();
+    eval_text(ctx, "<start_of_image>");
+    llama_set_causal_attn(ctx.lctx, false); // image tokens attend to each other bidirectionally
+    decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0);
+    if (llama_decode(ctx.lctx, batch_img.batch)) {
+        LOG_ERR("failed to decode image\n");
+        return 1;
+    }
+    ctx.n_past += n_tokens;
+    llama_set_causal_attn(ctx.lctx, true);
+    eval_text(ctx, "<end_of_image>");
+    LOG("Image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+    return 0;
+}
+
+static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
+    for (int i = 0; i < n_predict; i++) {
+        if (i > n_predict || !g_is_generating) {
+            printf("\n");
+            break;
+        }
+
+        llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        common_sampler_accept(smpl, token_id, true);
+
+        if (llama_vocab_is_eog(ctx.vocab, token_id)) {
+            printf("\n");
+            break; // end of generation
+        }
+
+        printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
+        fflush(stdout);
+
+        // eval the token
+        common_batch_clear(ctx.batch);
+        common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
+        if (llama_decode(ctx.lctx, ctx.batch)) {
+            LOG_ERR("failed to decode token\n");
+            return 1;
+        }
+    }
+    return 0;
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    common_params params;
+    params.sampling.temp = 0.2; // lower temp by default for better quality
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
+        return 1;
+    }
+
+    common_init();
+
+    if (params.mmproj.empty()) {
+        show_additional_info(argc, argv);
+        return 1;
+    }
+
+    gemma3_context ctx(params);
+    printf("%s: %s\n", __func__, params.model.c_str());
+
+    bool is_single_turn = !params.prompt.empty() && !params.image.empty();
+
+    struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling);
+    int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict;
+
+    // ctrl+C handling
+    {
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+        struct sigaction sigint_action;
+        sigint_action.sa_handler = sigint_handler;
+        sigemptyset (&sigint_action.sa_mask);
+        sigint_action.sa_flags = 0;
+        sigaction(SIGINT, &sigint_action, NULL);
+#elif defined (_WIN32)
+        auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
+            return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
+        };
+        SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+#endif
+    }
+
+    if (eval_text(ctx, "<bos>")) {
+        return 1;
+    }
+
+    if (is_single_turn) {
+        g_is_generating = true;
+        if (eval_text(ctx, "<start_of_turn>user\n")) {
+            return 1;
+        }
+        for (auto & fname : params.image) {
+            if (eval_image(ctx, fname)) {
+                return 1;
+            }
+        }
+        if (eval_text(ctx, params.prompt + "<end_of_turn><start_of_turn>model\n", true)) {
+            return 1;
+        }
+        if (generate_response(ctx, smpl, n_predict)) {
+            return 1;
+        }
+
+    } else {
+        LOG("\n Running in chat mode, available commands:");
+        LOG("\n   /image <path>    load an image");
+        LOG("\n   /clear           clear the chat history");
+        LOG("\n   /quit or /exit   exit the program");
+        LOG("\n");
+
+        if (eval_text(ctx, "<start_of_turn>user\n")) {
+            return 1;
+        }
+
+        while (true) {
+            g_is_generating = false;
+            LOG("\n> ");
+            console::set_display(console::user_input);
+            std::string line;
+            console::readline(line, false);
+            console::set_display(console::reset);
+            line = string_strip(line);
+            if (line.empty()) {
+                continue;
+            }
+            if (line == "/quit" || line == "/exit") {
+                break;
+            }
+            if (line == "/clear") {
+                ctx.n_past = 0;
+                llama_kv_cache_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS
+                LOG("Chat history cleared\n\n");
+                continue;
+            }
+            g_is_generating = true;
+            if (line.find("/image") == 0) {
+                std::string image = line.substr(7); // skip the "/image " prefix
+                int res = eval_image(ctx, image);
+                if (res == 2) {
+                    continue; // image not found
+                }
+                if (res) {
+                    return 1;
+                }
+                continue;
+            }
+            if (eval_text(ctx, line + "<end_of_turn><start_of_turn>model\n", true)) {
+                return 1;
+            }
+            if (generate_response(ctx, smpl, n_predict)) {
+                return 1;
+            }
+            if (eval_text(ctx, "<end_of_turn><start_of_turn>user\n")) {
+                return 1;
+            }
+        }
+    }
+
+    return 0;
+}
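Taken together, a single-turn run of the CLI above feeds the model one sequence laid out as follows (a sketch based on the token strings used in `main()` and `eval_image()`; the 256 image embeddings are decoded with causal attention temporarily disabled):

```cpp
// <bos><start_of_turn>user\n
// <start_of_image> [256 pooled image embeddings] <end_of_image>
// {prompt}<end_of_turn><start_of_turn>model\n
// ... tokens sampled until EOG or n_predict ...
```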
diff --git a/examples/llava/gemma3_convert_encoder_to_gguf.py b/examples/llava/gemma3_convert_encoder_to_gguf.py
new file mode 100644
index 0000000000000..241b526b9ede7
--- /dev/null
+++ b/examples/llava/gemma3_convert_encoder_to_gguf.py
@@ -0,0 +1,307 @@
+import gguf
+import argparse
+import logging
+import sys
+import torch
+import json
+import os
+import numpy as np
+from typing import cast, ContextManager, Any, Iterator
+from pathlib import Path
+from torch import Tensor
+
+logger = logging.getLogger("gemma3-mmproj")
+
+
+# (copied from convert_hf_to_gguf.py)
+# tree of lazy tensors
+class LazyTorchTensor(gguf.LazyBase):
+    _tensor_type = torch.Tensor
+    # to keep the type-checker happy
+    dtype: torch.dtype
+    shape: torch.Size
+
+    # only used when converting a torch.Tensor to a np.ndarray
+    _dtype_map: dict[torch.dtype, type] = {
+        torch.float16: np.float16,
+        torch.float32: np.float32,
+    }
+
+    # used for safetensors slices
+    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
+    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
+    _dtype_str_map: dict[str, torch.dtype] = {
+        "F64": torch.float64,
+        "F32": torch.float32,
+        "BF16": torch.bfloat16,
+        "F16": torch.float16,
+        # "U64": torch.uint64,
+        "I64": torch.int64,
+        # "U32": torch.uint32,
+        "I32": torch.int32,
+        # "U16": torch.uint16,
+        "I16": torch.int16,
+        "U8": torch.uint8,
+        "I8": torch.int8,
+        "BOOL": torch.bool,
+        "F8_E4M3": torch.float8_e4m3fn,
+        "F8_E5M2": torch.float8_e5m2,
+    }
+
+    def numpy(self) -> gguf.LazyNumpyTensor:
+        dtype = self._dtype_map[self.dtype]
+        return gguf.LazyNumpyTensor(
+            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
+            args=(self,),
+            func=(lambda s: s.numpy())
+        )
+
+    @classmethod
+    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
+        return torch.empty(size=shape, dtype=dtype, device="meta")
+
+    @classmethod
+    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
+        dtype = cls._dtype_str_map[st_slice.get_dtype()]
+        shape: tuple[int, ...] = tuple(st_slice.get_shape())
+        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
+        return cast(torch.Tensor, lazy)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        del types  # unused
+
+        if kwargs is None:
+            kwargs = {}
+
+        if func is torch.Tensor.numpy:
+            return args[0].numpy()
+
+        return cls._wrap_fn(func)(*args, **kwargs)
+
+
+class Gemma3VisionTower:
+    hparams: dict
+    gguf_writer: gguf.GGUFWriter
+    fname_out: Path
+    ftype: gguf.LlamaFileType
+
+    @staticmethod
+    def load_hparams(dir_model: Path):
+        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    @staticmethod
+    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
+        part_names: list[str] = []
+        for filename in os.listdir(dir_model):
+            if filename.startswith(prefix) and filename.endswith(suffix):
+                part_names.append(filename)
+        part_names.sort()
+        return part_names
+
+    def __init__(self,
+                 dir_model: Path,
+                 fname_out: Path,
+                 ftype: gguf.LlamaFileType,
+                 is_big_endian: bool,):
+        hparams = Gemma3VisionTower.load_hparams(dir_model)
+        self.hparams = hparams
+        self.fname_out = fname_out
+        self.ftype = ftype
+        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
+        self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
+
+        text_config = hparams["text_config"]
+        vision_config = hparams["vision_config"]
+
+        assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
+        assert text_config is not None
+        assert vision_config is not None
+
+        self.gguf_writer.add_string ("clip.projector_type", "gemma3")
+        self.gguf_writer.add_bool   ("clip.has_text_encoder", False)
+        self.gguf_writer.add_bool   ("clip.has_vision_encoder", True)
+        self.gguf_writer.add_bool   ("clip.has_llava_projector", False) # legacy
+        self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"])
+        self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"])
+        self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
+        self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
+        # default values taken from HF transformers code
+        self.gguf_writer.add_array  ("clip.vision.image_mean", [0.5, 0.5, 0.5])
+        self.gguf_writer.add_array  ("clip.vision.image_std",  [0.5, 0.5, 0.5])
+        self.gguf_writer.add_bool   ("clip.use_gelu", True)
+
+        # load tensors
+        for name, data_torch in self.get_tensors(dir_model):
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+            self.add_tensor(name, data_torch)
+
+    def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
+        part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
+        tensor_names_from_parts: set[str] = set()
+        for part_name in part_names:
+            logger.info(f"gguf: loading model part '{part_name}'")
+            from safetensors import safe_open
+            ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
+            with ctx as model_part:
+                tensor_names_from_parts.update(model_part.keys())
+
+                for name in model_part.keys():
+                    data = model_part.get_slice(name)
+                    data = LazyTorchTensor.from_safetensors_slice(data)
+                    yield name, data
+
+    def add_tensor(self, name: str, data_torch: Tensor):
+        is_1d = len(data_torch.shape) == 1
+        is_embd = ".embeddings." in name
+        old_dtype = data_torch.dtype
+        can_quantize = not is_1d and not is_embd
+        data_qtype = gguf.GGMLQuantizationType.F32
+
+        # this is to support old checkpoints
+        # TODO: remove this when we have the final model
+        name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
+        name = name.replace("multimodal_projector.", "multi_modal_projector.")
+
+        # filter only vision tensors
+        if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
+            return
+        # prefix
+        name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
+        name = name.replace("vision_tower.vision_model.", "v.")
+        # projector and input embd
+        name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
+        name = name.replace(".embeddings.position_embedding.", ".position_embd.")
+        name = name.replace(
+            "multi_modal_projector.mm_input_projection_weight",
+            "mm.input_projection.weight"
+        )
+        name = name.replace(
+            "multi_modal_projector.mm_soft_emb_norm.weight",
+            "mm.soft_emb_norm.weight"
+        )
+        name = name.replace("post_layernorm.", "post_ln.")
+        # each block
+        name = name.replace(".self_attn.k_proj.", ".attn_k.")
+        name = name.replace(".self_attn.v_proj.", ".attn_v.")
+        name = name.replace(".self_attn.q_proj.", ".attn_q.")
+        name = name.replace(".self_attn.out_proj.", ".attn_out.")
+        name = name.replace(".layer_norm1.", ".ln1.")
+        name = name.replace(".layer_norm2.", ".ln2.")
+        name = name.replace(".mlp.fc1.", ".ffn_down.")
+        name = name.replace(".mlp.fc2.", ".ffn_up.")
+        # e.g. "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight"
+        #   -> "v.blk.0.attn_q.weight"
+
+        if can_quantize:
+            if self.ftype == gguf.LlamaFileType.ALL_F32:
+                data_qtype = gguf.GGMLQuantizationType.F32
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
+                data_qtype = gguf.GGMLQuantizationType.F16
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                data_qtype = gguf.GGMLQuantizationType.BF16
+            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
+                data_qtype = gguf.GGMLQuantizationType.Q8_0
+            else:
+                raise ValueError(f"Unsupported file type: {self.ftype}")
+
+        # correct the norm value; only this "soft_emb_norm" needs to be corrected as it is part of the Gemma projector
+        # the other norm values are part of the SigLIP model, and they are already correct
+        # ref code: Gemma3RMSNorm
+        if "soft_emb_norm.weight" in name:
+            logger.info(f"Correcting norm value for '{name}'")
+            data_torch = data_torch + 1
+
+        data = data_torch.numpy()
+
+        try:
+            data = gguf.quants.quantize(data, data_qtype)
+        except Exception as e:
+            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
+            data_qtype = gguf.GGMLQuantizationType.F16
+            data = gguf.quants.quantize(data, data_qtype)
+
+        # reverse shape to make it similar to the internal ggml dimension order
+        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
+        logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
+
+        self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
= {shape_str}") + + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Gemma 3 vision tower safetensors to GGUF format",) + parser.add_argument( + "--outfile", type=Path, default="mmproj.gguf", + help="path to write to", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="directory containing model file", + nargs="?", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + if not dir_model.is_dir(): + logger.error(f'Error: {args.model} is not a directory') + sys.exit(1) + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading model: {dir_model.name}") + + with torch.inference_mode(): + gemma3_vision_tower = Gemma3VisionTower( + dir_model=dir_model, + fname_out=args.outfile, + ftype=ftype_map[args.outtype], + is_big_endian=args.bigendian, + ) + gemma3_vision_tower.write() + + +if __name__ == '__main__': + main() +