
Commit 4e694f5

Linkun Chen (lk-chen) authored and committed
[Feature] Update benchmark_throughput.py to support image input (vllm-project#9851)
Signed-off-by: Linkun Chen <[email protected]>
Co-authored-by: Linkun Chen <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
1 parent: 7f56954 · commit: 4e694f5

2 files changed: +75 −18 lines


benchmarks/README.md

Lines changed: 11 additions & 0 deletions
@@ -6,3 +6,14 @@ You can download the dataset by running:
 ```bash
 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
 ```
+
+## Downloading the ShareGPT4V dataset
+
+The json file refers to several image datasets (coco, llava, etc.). The benchmark scripts
+will ignore a datapoint if the referred image is missing.
+```bash
+wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/resolve/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
+mkdir coco -p
+wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
+unzip coco/train2017.zip -d coco/
+```
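Because the benchmark silently skips any record whose image file cannot be opened, it is worth checking how much of the downloaded dataset will actually be exercised. A minimal sketch, assuming the json and the `coco/` directory produced by the commands above sit in the current working directory:

```python
# Hypothetical sanity check, not part of the benchmark itself: count how many
# records in the downloaded json point at an image that exists locally.
# benchmark_throughput.py will skip records whose image file is missing.
import json
import os

with open("sharegpt4v_instruct_gpt4-vision_cap100k.json") as f:
    data = json.load(f)

with_image = [d for d in data if "image" in d]
present = sum(os.path.exists(d["image"]) for d in with_image)
print(f"{present}/{len(with_image)} image-bearing records resolve locally")
```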

benchmarks/benchmark_throughput.py

Lines changed: 64 additions & 18 deletions
@@ -8,6 +8,7 @@
 
 import torch
 import uvloop
+from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
@@ -38,12 +39,33 @@ class SampleRequest:
     multi_modal_data: Optional[MultiModalDataDict] = None
 
 
-def sample_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
-) -> List[SampleRequest]:
+def _get_prompt_for_image_model(question: str, *, model: str) -> str:
+    """Prepend and append special tokens around the question to form a prompt.
+
+    Args:
+        question: The input question text to wrap with special tokens
+        model: The name of the model being used, to determine which special
+            tokens to add
+
+    Returns:
+        The formatted prompt string with appropriate special tokens for the
+        model
+
+    Raises:
+        ValueError: If an unsupported model name is provided
+    """
+    model = model.lower()
+    if "pixtral" in model:
+        return f"<s>[INST]{question}\n[IMG][/INST]"
+    raise ValueError(f"Unsupported model {model}")
+
+
+def sample_requests(tokenizer: PreTrainedTokenizerBase,
+                    args: argparse.Namespace) -> List[SampleRequest]:
+    dataset_path: str = args.dataset
+    num_requests: int = args.num_prompts
+    fixed_output_len: Optional[int] = args.output_len
+    model: str = args.model
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
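For reference, a quick sketch of what the new helper returns; the model name below is only an example, since the helper accepts any name containing "pixtral" (case-insensitively) and raises for everything else:

```python
# Example only; "mistralai/Pixtral-12B-2409" is an illustrative model name.
prompt = _get_prompt_for_image_model(
    question="What is shown in this image?",
    model="mistralai/Pixtral-12B-2409")
# prompt == "<s>[INST]What is shown in this image?\n[IMG][/INST]"

_get_prompt_for_image_model(question="Hi", model="llava-1.5")
# raises ValueError: Unsupported model llava-1.5
```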
@@ -52,23 +74,36 @@ def sample_requests(
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [(data["conversations"][0]["value"],
-                data["conversations"][1]["value"]) for data in dataset]
-
     # Shuffle the dataset.
     random.shuffle(dataset)
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for i in range(len(dataset)):
+    for data in dataset:
         if len(filtered_dataset) == num_requests:
             break
 
+        # Only keep the first two turns of each conversation.
+        prompt = data["conversations"][0]["value"]
+        completion = data["conversations"][1]["value"]
+
+        multi_modal_data: Optional[MultiModalDataDict] = None
+        if "image" in data:
+            multi_modal_data = multi_modal_data or {}
+            image_path = data["image"]
+            # TODO(vllm-project/vllm/issues/9778): Support multiple images.
+            assert isinstance(image_path,
+                              str), "Only support single image input"
+            try:
+                multi_modal_data["image"] = Image.open(image_path).convert(
+                    "RGB")
+            except FileNotFoundError:
+                # Ignore datapoint where asset is missing
+                continue
+            prompt = _get_prompt_for_image_model(question=prompt, model=model)
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
         prompt_token_ids = tokenizer(prompt).input_ids
-        completion = dataset[i][1]
         completion_token_ids = tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
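For context, a minimal ShareGPT4V-style record that this loop now handles; the path is illustrative, but the optional `image` field (a single path string) and the first two `conversations` turns are the only fields the code reads:

```python
# Illustrative record shape; the image path is an example, not a real asset.
example_record = {
    "image": "coco/train2017/000000000009.jpg",  # optional; skipped if the file is missing
    "conversations": [
        {"from": "human", "value": "Describe the image."},        # becomes the prompt
        {"from": "gpt", "value": "A plate of food on a table."},  # sets expected_output_len
    ],
}
```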
@@ -82,7 +117,8 @@ def sample_requests(
         filtered_dataset.append(
             SampleRequest(prompt=prompt,
                           prompt_len=prompt_len,
-                          expected_output_len=output_len))
+                          expected_output_len=output_len,
+                          multi_modal_data=multi_modal_data))
 
     return filtered_dataset
 
@@ -99,7 +135,9 @@ def run_vllm(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -148,7 +186,9 @@ async def run_vllm_async(
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
-        prompts.append(TextPrompt(prompt=request.prompt))
+        prompts.append(
+            TextPrompt(prompt=request.prompt,
+                       multi_modal_data=request.multi_modal_data))
         sampling_params.append(
             SamplingParams(
                 n=n,
@@ -272,9 +312,10 @@ def main(args: argparse.Namespace):
             for _ in range(args.num_prompts)
         ]
     else:
-        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
-                                   args.output_len)
+        requests = sample_requests(tokenizer, args)
 
+    is_multi_modal = any(request.multi_modal_data is not None
+                         for request in requests)
     if args.backend == "vllm":
         if args.async_engine:
             elapsed_time = uvloop.run(
@@ -300,6 +341,11 @@ def main(args: argparse.Namespace):
                            for request in requests)
     total_output_tokens = sum(request.expected_output_len
                               for request in requests)
+    if is_multi_modal:
+        print("\033[91mWARNING\033[0m: Multi-modal request detected. The "
+              "following metrics are not accurate because image tokens are not"
+              " counted. See vllm-project/vllm/issues/9778 for details.")
+    # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
     print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
           f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
           f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
