 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
     multi_modal_data: Optional[MultiModalDataDict] = None
 
 
-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[SampleRequest]:
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special
+            tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the
+            model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"<s>[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
 
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
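A quick sketch of the new helper's behaviour, following the branch added above; the model names below are only illustrative placeholders, not something this diff mentions.

# Pixtral-style wrapping: the [IMG] placeholder is appended after the question.
print(repr(_get_prompt_for_image_model("What is in this picture?",
                                        model="mistralai/Pixtral-12B-2409")))
# '<s>[INST]What is in this picture?\n[IMG][/INST]'

# Any model whose lowercased name does not contain "pixtral" raises ValueError.
try:
    _get_prompt_for_image_model("hello", model="some-other-model")
except ValueError as err:
    print(err)  # Unsupported model some-other-model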
@@ -52,23 +74,36 @@ def sample_requests(
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
-
     # Shuffle the dataset.
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for i in range(len(dataset)):
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
 
+        # Only keep the first two turns of each conversation.
+        prompt = data["conversations"][0]["value"]
+        completion = data["conversations"][1]["value"]
+
+        multi_modal_data: Optional[MultiModalDataDict] = None
+        if "image" in data:
+            multi_modal_data = multi_modal_data or {}
+            image_path = data["image"]
+            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
+            assert isinstance(image_path,
+                              str), "Only support single image input"
+            try:
+                multi_modal_data["image"] = Image.open(image_path).convert(
+                    "RGB")
+            except FileNotFoundError:
+                # Ignore datapoint where asset is missing
+                continue
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
@@ -82,7 +117,8 @@ def sample_requests(
         filtered_dataset.append(
             SampleRequest(prompt=prompt,
                           prompt_len=prompt_len,
-                          expected_output_len=output_len))
+                          expected_output_len=output_len,
+                          multi_modal_data=multi_modal_data))
 
     return filtered_dataset
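For context, a rough sketch (not part of the diff) of the ShareGPT-style entry the reworked loop now consumes; the "from" keys and the image path are made-up placeholders, since the loop only reads "conversations"[0/1]["value"] and the optional "image" path.

# Hypothetical dataset entry, as it would appear after json.load(f).
entry = {
    "conversations": [
        {"from": "human", "value": "What is shown in this image?"},
        {"from": "gpt", "value": "A cat sitting on a windowsill."},
    ],
    # Single path string only; multiple images are deferred to
    # vllm-project/vllm/issues/9778. Entries whose file is missing are skipped.
    "image": "/data/assets/cat.jpg",
}
# With a Pixtral model, the prompt becomes
# "<s>[INST]What is shown in this image?\n[IMG][/INST]" and the opened PIL
# image is attached as multi_modal_data={"image": <PIL.Image>}.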
@@ -99,7 +135,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -148,7 +186,9 @@ async def run_vllm_async(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
             for _ in range(args.num_prompts)
         ]
     else:
-        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
-                                   args.output_len)
+        requests = sample_requests(tokenizer, args)
 
+    is_multi_modal = any(request.multi_modal_data is not None
+                         for request in requests)
     if args.backend == "vllm":
         if args.async_engine:
             elapsed_time = uvloop.run(
@@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
                                  for request in requests)
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
+    if is_multi_modal:
+        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+              "following metrics are not accurate because image tokens are not"
+              " counted. See vllm-project/vllm/issues/9778 for details.")
+    # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")