For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
+import random
+
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
@@ -23,7 +25,9 @@ def run_llava(question: str, modality: str):

    prompt = f"USER: <image>\n{question}\nASSISTANT:"

-    llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)
+    llm = LLM(model="llava-hf/llava-1.5-7b-hf",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str):
    assert modality == "image"

    prompt = f"[INST] <image>\n{question} [/INST]"
-    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str):
    assert modality == "video"

    prompt = f"USER: <video>\n{question} ASSISTANT:"
-    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192)
+    llm = LLM(model="llava-hf/LLaVA-NeXT-Video-7B-hf",
+              max_model_len=8192,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -61,7 +69,8 @@ def run_llava_onevision(question: str, modality: str):
<|im_start|>assistant\n"

    llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf",
-              max_model_len=16384)
+              max_model_len=16384,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -71,7 +80,10 @@ def run_fuyu(question: str, modality: str):
    assert modality == "image"

    prompt = f"{question}\n"
-    llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
+    llm = LLM(model="adept/fuyu-8b",
+              max_model_len=2048,
+              max_num_seqs=2,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -107,6 +119,7 @@ def run_phi3v(question: str, modality: str):
        max_num_seqs=2,
        # Note - mm_processor_kwargs can also be passed to generate/chat calls
        mm_processor_kwargs={"num_crops": 16},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    stop_token_ids = None
    return llm, prompt, stop_token_ids
@@ -118,7 +131,8 @@ def run_paligemma(question: str, modality: str):

    # PaliGemma has special prompt format for VQA
    prompt = "caption en"
-    llm = LLM(model="google/paligemma-3b-mix-224")
+    llm = LLM(model="google/paligemma-3b-mix-224",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -128,7 +142,9 @@ def run_chameleon(question: str, modality: str):
    assert modality == "image"

    prompt = f"{question}<image>"
-    llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)
+    llm = LLM(model="facebook/chameleon-7b",
+              max_model_len=4096,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -154,6 +170,7 @@ def run_minicpmv(question: str, modality: str):
        max_model_len=4096,
        max_num_seqs=2,
        trust_remote_code=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    # NOTE The stop_token_ids are different for various versions of MiniCPM-V
    # 2.0
@@ -186,6 +203,7 @@ def run_h2ovl(question: str, modality: str):
        model=model_name,
        trust_remote_code=True,
        max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -211,6 +229,7 @@ def run_internvl(question: str, modality: str):
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -241,6 +260,7 @@ def run_nvlm_d(question: str, modality: str):
        trust_remote_code=True,
        max_model_len=4096,
        tensor_parallel_size=4,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
@@ -260,7 +280,8 @@ def run_blip2(question: str, modality: str):
    # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
    # See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
    prompt = f"Question: {question} Answer:"
-    llm = LLM(model="Salesforce/blip2-opt-2.7b")
+    llm = LLM(model="Salesforce/blip2-opt-2.7b",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    stop_token_ids = None
    return llm, prompt, stop_token_ids

@@ -274,6 +295,7 @@ def run_qwen_vl(question: str, modality: str):
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=2,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    prompt = f"{question}Picture 1: <img></img>\n"
@@ -296,6 +318,7 @@ def run_qwen2_vl(question: str, modality: str):
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
        },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
@@ -315,6 +338,7 @@ def run_pixtral_hf(question: str, modality: str):
    llm = LLM(
        model=model_name,
        max_model_len=8192,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    prompt = f"<s>[INST]{question}\n[IMG][/INST]"
@@ -338,6 +362,7 @@ def run_mllama(question: str, modality: str):
        max_model_len=4096,
        max_num_seqs=16,
        enforce_eager=True,
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    prompt = f"<|image|><|begin_of_text|>{question}"
@@ -355,6 +380,7 @@ def run_molmo(question, modality):
        model=model_name,
        trust_remote_code=True,
        dtype="bfloat16",
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )

    prompt = question
@@ -371,7 +397,8 @@ def run_glm4v(question: str, modality: str):
              max_model_len=2048,
              max_num_seqs=2,
              trust_remote_code=True,
-              enforce_eager=True)
+              enforce_eager=True,
+              mm_cache_preprocessor=args.mm_cache_preprocessor)
    prompt = question
    stop_token_ids = [151329, 151336, 151338]
    return llm, prompt, stop_token_ids
@@ -394,6 +421,7 @@ def run_idefics3(question: str, modality: str):
                "longest_edge": 3 * 364
            },
        },
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    prompt = (
        f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
@@ -410,7 +438,8 @@ def run_aria(question: str, modality: str):
    llm = LLM(model=model_name,
              tokenizer_mode="slow",
              trust_remote_code=True,
-              dtype="bfloat16")
+              dtype="bfloat16",
+              mm_cache_preprocessor=args.mm_cache_preprocessor)

    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
              "<|im_end|>\n<|im_start|>assistant\n")
@@ -430,6 +459,7 @@ def run_mantis(question: str, modality: str):
        model="TIGER-Lab/Mantis-8B-siglip-llama3",
        max_model_len=4096,
        hf_overrides={"architectures": ["MantisForConditionalGeneration"]},
+        mm_cache_preprocessor=args.mm_cache_preprocessor,
    )
    stop_token_ids = [128009]
    return llm, prompt, stop_token_ids
@@ -494,6 +524,35 @@ def get_multi_modal_input(args):
    raise ValueError(msg)


+def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
+    """Repeats images with provided probability of "image_repeat_prob".
+    Used to simulate hit/miss for the MM preprocessor cache.
+    """
+    assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
+    no_yes = [0, 1]
+    probs = [1.0 - image_repeat_prob, image_repeat_prob]
+
+    inputs = []
+    cur_image = data
+    for i in range(num_prompts):
+        if image_repeat_prob is not None:
+            res = random.choices(no_yes, probs)[0]
+            if res == 0:
+                # No repeat => Modify one pixel
+                cur_image = cur_image.copy()
+                new_val = (i // 256 // 256, i // 256, i % 256)
+                cur_image.putpixel((0, 0), new_val)
+
+        inputs.append({
+            "prompt": prompt,
+            "multi_modal_data": {
+                modality: cur_image
+            }
+        })
+
+    return inputs
+
+
def main(args):
    model = args.model_type
    if model not in model_example_map:
@@ -524,14 +583,29 @@ def main(args):

    else:
        # Batch inference
-        inputs = [{
-            "prompt": prompt,
-            "multi_modal_data": {
-                modality: data
-            },
-        } for _ in range(args.num_prompts)]
+        if args.image_repeat_prob is not None:
+            # Repeat images with specified probability of "image_repeat_prob"
+            inputs = apply_image_repeat(args.image_repeat_prob,
+                                        args.num_prompts, data, prompt,
+                                        modality)
+        else:
+            # Use the same image for all prompts
+            inputs = [{
+                "prompt": prompt,
+                "multi_modal_data": {
+                    modality: data
+                },
+            } for _ in range(args.num_prompts)]
+
+    if args.time_generate:
+        import time
+        start_time = time.time()
+        outputs = llm.generate(inputs, sampling_params=sampling_params)
+        elapsed_time = time.time() - start_time
+        print("-- generate time = {}".format(elapsed_time))

-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    else:
+        outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
@@ -561,5 +635,23 @@ def main(args):
                        type=int,
                        default=16,
                        help='Number of frames to extract from the video.')
+
+    parser.add_argument(
+        '--image-repeat-prob',
+        type=float,
+        default=None,
+        help='Simulates the hit-ratio for multi-modal preprocessor cache'
+        ' (if enabled)')
+
+    parser.add_argument(
+        '--mm-cache-preprocessor',
+        action='store_true',
+        help='If True, enable caching of multi-modal preprocessor/mapper.')
+
+    parser.add_argument(
+        '--time-generate',
+        action='store_true',
+        help='If True, then print the total generate() call time')
+
    args = parser.parse_args()
    main(args)
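
Usage note: the `mm_cache_preprocessor` engine argument added in this diff can also be passed to `LLM` directly, outside the example script. The snippet below is a minimal sketch assuming this branch's flag is available in the installed vLLM build; the image path `demo.jpg`, the prompt text, and the sampling settings are illustrative placeholders, not values taken from the diff.

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Enable caching of the multi-modal preprocessor/mapper (the flag added here).
llm = LLM(model="llava-hf/llava-1.5-7b-hf",
          max_model_len=4096,
          mm_cache_preprocessor=True)

sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

# Reusing the same PIL image across requests should hit the preprocessor cache,
# which is the behavior the new --image-repeat-prob option simulates.
image = Image.open("demo.jpg").convert("RGB")  # placeholder image path
inputs = [{
    "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
    "multi_modal_data": {
        "image": image
    },
} for _ in range(4)]

outputs = llm.generate(inputs, sampling_params=sampling_params)
for o in outputs:
    print(o.outputs[0].text)
```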