
Commit 8c6aefa

Author: Ivan Chikish

examples/main: basic multimodal support ported from llava-cli

<image> keyword gets replaced with image embed within prompt.

1 parent 869ca4d
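For reference, a hypothetical invocation of the new code path (the model, projector, and image file names below are placeholders; --mmproj and --image are the existing llava options parsed by the common argument handling, and -e turns the \n escapes in the prompt into real newlines):

./main -m models/llava-v1.5-7b/ggml-model-q4_k.gguf \
    --mmproj models/llava-v1.5-7b/mmproj-model-f16.gguf \
    --image cat.jpg -e \
    -p "USER: <image>\nDescribe this image.\nASSISTANT:"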

5 files changed: +82 −21 lines

Makefile

Lines changed: 4 additions & 2 deletions
@@ -745,9 +745,11 @@ clean:
 # Helper function that replaces .c, .cpp, and .cu file endings with .o:
 GET_OBJ_FILE = $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1))))
 
-main: examples/main/main.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
+main: examples/main/main.cpp examples/llava/clip.h examples/llava/clip.cpp examples/llava/llava.h examples/llava/llava.cpp ggml.o llama.o $(COMMON_DEPS) console.o grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
-	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/llava/clip.cpp -o $(call GET_OBJ_FILE, examples/llava/clip.cpp) -Wno-cast-qual
+	$(CXX) $(CXXFLAGS) -c examples/llava/llava.cpp -o $(call GET_OBJ_FILE, examples/llava/llava.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $< examples/llava/clip.cpp examples/llava/llava.cpp,$^) $(call GET_OBJ_FILE, examples/llava/clip.cpp) $(call GET_OBJ_FILE, examples/llava/llava.cpp) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 	@echo
 	@echo '==== Run ./main -h for help. ===='
 	@echo

build.zig

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ pub fn build(b: *std.build.Builder) !void {
     const clip = make.obj("clip", "examples/llava/clip.cpp");
     const llava = make.obj("llava", "examples/llava/llava.cpp");
 
-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo, sampling, console, grammar_parser, clip, llava });
     _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
     _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });
     _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, sgemm, ggml_alloc, ggml_backend, ggml_quants, llama, unicode, unicode_data, common, json_schema_to_grammar, buildinfo });

examples/llava/clip.cpp

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,9 @@
 #include "ggml-metal.h"
 #endif
 
+#if defined(__clang__) || defined(__GNUC__) || defined(__GNUG__)
+#pragma GCC diagnostic ignored "-Wcast-qual"
+#endif
 #define STB_IMAGE_IMPLEMENTATION
 #include "stb_image.h"
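The pragma silences -Wcast-qual for the remainder of the translation unit, since the stb_image.h implementation contains casts that drop const qualifiers (the Makefile change above passes -Wno-cast-qual when compiling clip.cpp for the same reason). A more narrowly scoped alternative, sketched here rather than taken from the commit, would push and pop the diagnostic state around the offending include:

#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic push            // save the current diagnostic state
#pragma GCC diagnostic ignored "-Wcast-qual"
#endif
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#if defined(__clang__) || defined(__GNUC__)
#pragma GCC diagnostic pop             // restore -Wcast-qual afterwards
#endif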

examples/main/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 set(TARGET main)
 add_executable(${TARGET} main.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/main/main.cpp

Lines changed: 73 additions & 17 deletions
@@ -1,5 +1,7 @@
 #include "common.h"
 
+#include "../llava/clip.h"
+#include "../llava/llava.h"
 #include "console.h"
 #include "llama.h"
 
@@ -194,6 +196,9 @@ int main(int argc, char ** argv) {
     g_model = &model;
     g_ctx = &ctx;
 
+    clip_ctx* ctx_clip = nullptr;
+    llava_image_embed* image_embed = nullptr;
+
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
@@ -207,6 +212,27 @@
         return 1;
     }
 
+    if (!params.image.empty() && params.mmproj.empty()) {
+        LOG_TEE("%s: error: image specified without mmproj\n", __func__);
+        return 1;
+    }
+
+    if (!params.mmproj.empty()) {
+        ctx_clip = clip_model_load(params.mmproj.c_str(), /*verbosity=*/1);
+        if (!ctx_clip) {
+            LOG_TEE("%s: error: failed to load mmproj (CLIP)\n", __func__);
+            return 1;
+        }
+
+        if (!params.image.empty()) {
+            image_embed = llava_image_embed_make_with_filename(ctx_clip, params.n_threads, params.image.c_str());
+            if (!image_embed) {
+                LOG_TEE("%s: error: failed to load image\n", __func__);
+                return 1;
+            }
+        }
+    }
+
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
     LOG("n_ctx: %d\n", n_ctx);
@@ -250,13 +276,22 @@
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
+    int embd_img_pos = -1;
 
     if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        const auto epos = params.prompt.find("<image>");
+        if (epos + 1 && image_embed) {
+            embd_inp = ::llama_tokenize(ctx, params.prompt.substr(0, epos), true, true);
+            embd_img_pos = embd_inp.size();
+            auto end = ::llama_tokenize(ctx, params.prompt.substr(epos + 7), false, true);
+            embd_inp.insert(embd_inp.end(), end.begin(), end.end());
+        } else {
+            embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        }
     } else {
         LOG("use session tokens\n");
        embd_inp = session_tokens;
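The `epos + 1` test works because std::string::npos is the largest value of size_t, so npos + 1 wraps around to 0 and the condition is equivalent to epos != std::string::npos. A standalone sketch of the splitting idiom (illustrative only, not part of the commit):

#include <cstdio>
#include <string>

int main() {
    const std::string prompt = "USER: <image>\nDescribe this image.\nASSISTANT:";
    const auto epos = prompt.find("<image>");
    if (epos + 1) { // true iff "<image>" was found
        const std::string prefix = prompt.substr(0, epos);  // tokenized with BOS
        const std::string suffix = prompt.substr(epos + 7); // 7 == strlen("<image>")
        std::printf("prefix: '%s'\nsuffix: '%s'\n", prefix.c_str(), suffix.c_str());
    }
    return 0;
}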
@@ -333,8 +368,10 @@
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
+    bool n_keep_full = false;
+    if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
+        n_keep_full = true;
     } else {
         params.n_keep += add_bos; // always keep the BOS token
     }
@@ -454,6 +491,10 @@
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    // Extend n_keep with embedded image size (there is an edge case with
+    // explicit n_keep that it must include at least 1 token after img)
+    if (embd_img_pos >= 0 && (params.n_keep > embd_img_pos || n_keep_full))
+        params.n_keep += image_embed->n_image_pos;
 
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
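A worked example with illustrative numbers, ignoring the BOS adjustment for simplicity: suppose the prompt tokenizes to 12 text tokens with embd_img_pos == 4, and the projector emits n_image_pos == 576 embedding positions (the typical count for a 336px LLaVA-1.5 CLIP encoder). An explicit --keep 6 satisfies n_keep > embd_img_pos, so n_keep becomes 6 + 576 = 582 and a context reset preserves the image plus at least one text token after it; --keep 3 leaves n_keep untouched, and the image positions are not protected.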
@@ -659,26 +700,36 @@
             }
         }
 
-        for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
-            int n_eval = (int) embd.size() - i;
-            if (n_eval > params.n_batch) {
-                n_eval = params.n_batch;
-            }
+        auto decode_tokens = [&](int start, int count) -> void {
+            if (count == -1)
+                count = embd.size() - start;
+            for (int i = start; i < count; i += params.n_batch) {
+                int n_eval = count - i;
+                if (n_eval > params.n_batch) {
+                    n_eval = params.n_batch;
+                }
 
-            LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
-                LOG_TEE("%s : failed to eval\n", __func__);
-                return 1;
-            }
+                llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0));
 
-            n_past += n_eval;
+                n_past += n_eval;
 
-            LOG("n_past = %d\n", n_past);
-            // Display total tokens alongside total time
-            if (params.n_print > 0 && n_past % params.n_print == 0) {
-                LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                LOG("n_past = %d\n", n_past);
+                // Display total tokens alongside total time
+                if (params.n_print > 0 && n_past % params.n_print == 0) {
+                    LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                }
             }
+        };
+
+        if (embd_img_pos >= 0) {
+            decode_tokens(0, embd_img_pos);
+            llava_eval_image_embed(ctx, image_embed, params.n_batch, &n_past);
+            decode_tokens(embd_img_pos, -1);
+            embd_img_pos = -1;
+        } else {
+            decode_tokens(0, embd.size());
         }
 
         if (!embd.empty() && !path_session.empty()) {
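The split keeps the position bookkeeping consistent: the first decode_tokens call advances n_past through the text before the image, llava_eval_image_embed then advances it by n_image_pos while writing the image embeddings into the KV cache, and the final call decodes the trailing text at the now-shifted positions.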
@@ -943,6 +994,11 @@
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
     if (ctx_guidance) { llama_free(ctx_guidance); }
+
+    if (image_embed)
+        llava_image_embed_free(image_embed);
+    if (ctx_clip)
+        clip_free(ctx_clip);
     llama_free(ctx);
     llama_free_model(model);
 