@@ -1,5 +1,7 @@
 #include "common.h"
 
+#include "../llava/clip.h"
+#include "../llava/llava.h"
 #include "console.h"
 #include "llama.h"
 
@@ -194,6 +196,9 @@ int main(int argc, char ** argv) {
     g_model = &model;
     g_ctx = &ctx;
 
+    clip_ctx * ctx_clip = nullptr;
+    llava_image_embed * image_embed = nullptr;
+
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
@@ -207,6 +212,27 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!params.image.empty() && params.mmproj.empty()) {
+        LOG_TEE("%s: error: image specified without mmproj\n", __func__);
+        return 1;
+    }
+
+    if (!params.mmproj.empty()) {
+        ctx_clip = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
+        if (!ctx_clip) {
+            LOG_TEE("%s: error: failed to load mmproj (CLIP)\n", __func__);
+            return 1;
+        }
+
+        if (!params.image.empty()) {
+            image_embed = llava_image_embed_make_with_filename(ctx_clip, params.n_threads, params.image.c_str());
+            if (!image_embed) {
+                LOG_TEE("%s: error: failed to load image\n", __func__);
+                return 1;
+            }
+        }
+    }
+
     const int n_ctx_train = llama_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
     LOG("n_ctx: %d\n", n_ctx);
@@ -250,13 +276,22 @@ int main(int argc, char ** argv) {
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
+    int embd_img_pos = -1;
 
     if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
         if (params.chatml) {
             params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
         }
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        const auto epos = params.prompt.find("<image>");
+        if (epos + 1 && image_embed) {
+            embd_inp = ::llama_tokenize(ctx, params.prompt.substr(0, epos), true, true);
+            embd_img_pos = embd_inp.size();
+            auto end = ::llama_tokenize(ctx, params.prompt.substr(epos + 7), false, true);
+            embd_inp.insert(embd_inp.end(), end.begin(), end.end());
+        } else {
+            embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
+        }
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -333,8 +368,10 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
+    bool n_keep_full = false;
+    if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
+        n_keep_full = true;
     } else {
         params.n_keep += add_bos; // always keep the BOS token
     }
@@ -454,6 +491,10 @@ int main(int argc, char ** argv) {
     LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
+    // Extend n_keep with embedded image size (there is an edge case with
+    // explicit n_keep that it must include at least 1 token after img)
+    if (embd_img_pos >= 0 && (params.n_keep > embd_img_pos || n_keep_full))
+        params.n_keep += image_embed->n_image_pos;
 
     // group-attention state
     // number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
@@ -659,26 +700,36 @@ int main(int argc, char ** argv) {
                 }
             }
 
-            for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
-                int n_eval = (int) embd.size() - i;
-                if (n_eval > params.n_batch) {
-                    n_eval = params.n_batch;
-                }
+            auto decode_tokens = [&](int start, int count) -> void {
+                if (count == -1)
+                    count = embd.size() - start;
+                for (int i = start; i < count; i += params.n_batch) {
+                    int n_eval = count - i;
+                    if (n_eval > params.n_batch) {
+                        n_eval = params.n_batch;
+                    }
 
-                LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                    LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
-                if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
-                    LOG_TEE("%s : failed to eval\n", __func__);
-                    return 1;
-                }
+                    llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0));
 
-                n_past += n_eval;
+                    n_past += n_eval;
 
-                LOG("n_past = %d\n", n_past);
-                // Display total tokens alongside total time
-                if (params.n_print > 0 && n_past % params.n_print == 0) {
-                    LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                    LOG("n_past = %d\n", n_past);
+                    // Display total tokens alongside total time
+                    if (params.n_print > 0 && n_past % params.n_print == 0) {
+                        LOG_TEE("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
+                    }
                 }
+            };
+
+            if (embd_img_pos >= 0) {
+                decode_tokens(0, embd_img_pos);
+                llava_eval_image_embed(ctx, image_embed, params.n_batch, &n_past);
+                decode_tokens(embd_img_pos, -1);
+                embd_img_pos = -1;
+            } else {
+                decode_tokens(0, embd.size());
             }
 
             if (!embd.empty() && !path_session.empty()) {
@@ -943,6 +994,11 @@ int main(int argc, char ** argv) {
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
     if (ctx_guidance) { llama_free(ctx_guidance); }
+
+    if (image_embed)
+        llava_image_embed_free(image_embed);
+    if (ctx_clip)
+        clip_free(ctx_clip);
     llama_free(ctx);
     llama_free_model(model);
 
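For reference, a hypothetical invocation of the patched example. It assumes the existing `--mmproj` and `--image` options from `common` are what populate `params.mmproj` and `params.image`, and the model/file names below are placeholders; the `<image>` marker in the prompt is where the tokenizer split above splices in the image embedding:

    ./main -m models/ggml-model-q4_k.gguf --mmproj models/mmproj-model-f16.gguf \
        --image picture.jpg -p "USER: <image>\nDescribe the image in detail.\nASSISTANT:"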