Commit 9835673

Varun Sundar Rabindranath authored
[misc] benchmark_throughput : Add LoRA (#11267)
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent f26c4ae commit 9835673

benchmarks/benchmark_throughput.py

Lines changed: 89 additions & 13 deletions
@@ -4,7 +4,8 @@
 import json
 import random
 import time
-from typing import List, Optional
+from functools import cache
+from typing import Dict, List, Optional, Tuple

 import torch
 import uvloop
@@ -17,8 +18,11 @@
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
 from vllm.inputs import TextPrompt
+from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import BeamSearchParams
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@@ -28,15 +32,17 @@ class SampleRequest:

     Attributes:
         prompt: The input text prompt for the model.
-        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
-            images).
         prompt_len: The length of the prompt in tokens.
         expected_output_len: The expected length of the output in tokens.
+        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+            images).
+        lora_request: Optional LoRARequest specifying the LoRA to use.
     """
     prompt: str
     prompt_len: int
     expected_output_len: int
     multi_modal_data: Optional[MultiModalDataDict] = None
+    lora_request: Optional[LoRARequest] = None


 def _get_prompt_for_image_model(question: str, *, model: str) -> str:
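
The new lora_request field defaults to None, so existing call sites keep working. As a minimal sketch (not part of the commit) of building one of these requests by hand, where the prompt, lengths, and adapter path are illustrative:

# Hypothetical values, for illustration only.
from vllm.lora.request import LoRARequest

example = SampleRequest(
    prompt="What is the capital of France?",
    prompt_len=8,                # assumed token count of the prompt
    expected_output_len=32,
    lora_request=LoRARequest(lora_name="1",
                             lora_int_id=1,
                             lora_path="/abs/path/to/adapter"))
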
@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
     raise ValueError(f"Unsupported model {model}")


+@cache
+def lora_path_on_disk(lora_path: str) -> str:
+    return get_adapter_absolute_path(lora_path)
+
+
+lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
+
+
+def get_random_lora_request(
+        args: argparse.Namespace
+) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
+    global lora_tokenizer_cache
+    lora_id = random.randint(1, args.max_loras)
+    lora_request = LoRARequest(lora_name=str(lora_id),
+                               lora_int_id=lora_id,
+                               lora_path=lora_path_on_disk(args.lora_path))
+    if lora_id not in lora_tokenizer_cache:
+        lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
+    return lora_request, lora_tokenizer_cache[lora_id]
+
+
 def sample_requests(tokenizer: PreTrainedTokenizerBase,
                     args: argparse.Namespace) -> List[SampleRequest]:
+
     dataset_path: str = args.dataset
     num_requests: int = args.num_prompts
     fixed_output_len: Optional[int] = args.output_len
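
get_random_lora_request draws a uniform adapter ID in [1, args.max_loras], resolves the adapter's on-disk path only once thanks to @cache, and memoizes one tokenizer per adapter ID so repeated draws stay cheap. A minimal sketch of calling it in isolation, assuming a namespace that carries max_loras and a real adapter path (the values below are placeholders):

import argparse
import random

# Hypothetical namespace: --max-loras comes from EngineArgs.add_cli_args,
# and the adapter path is illustrative.
args = argparse.Namespace(max_loras=4, lora_path="/path/to/adapter")

random.seed(0)  # make the adapter draw reproducible for this sketch
lora_request, lora_tokenizer = get_random_lora_request(args)
assert 1 <= lora_request.lora_int_id <= args.max_loras
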
@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,

     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for data in dataset:
+    for data in tqdm(dataset,
+                     total=len(filtered_dataset),
+                     desc="sampling requests"):
         if len(filtered_dataset) == num_requests:
             break

@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                 continue
             prompt = _get_prompt_for_image_model(question=prompt, model=model)

+        request_tokenizer = tokenizer
+        lora_request: Optional[LoRARequest] = None
+        if args.enable_lora:
+            lora_request, lora_tokenizer = get_random_lora_request(args)
+            if lora_tokenizer:
+                request_tokenizer = lora_tokenizer
+
         # Tokenize the prompts and completions.
-        prompt_token_ids = tokenizer(prompt).input_ids
-        completion_token_ids = tokenizer(completion).input_ids
+        prompt_token_ids = request_tokenizer(prompt).input_ids
+        completion_token_ids = request_tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
                          ) if fixed_output_len is None else fixed_output_len
@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
             SampleRequest(prompt=prompt,
                           prompt_len=prompt_len,
                           expected_output_len=output_len,
-                          multi_modal_data=multi_modal_data))
+                          multi_modal_data=multi_modal_data,
+                          lora_request=lora_request))

     return filtered_dataset

@@ -146,14 +184,21 @@ def run_vllm(
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
             ))
+    lora_requests: Optional[List[LoRARequest]] = None
+    if engine_args.enable_lora:
+        lora_requests = [request.lora_request for request in requests]

     use_beam_search = False

     if not use_beam_search:
         start = time.perf_counter()
-        llm.generate(prompts, sampling_params, use_tqdm=True)
+        llm.generate(prompts,
+                     sampling_params,
+                     lora_request=lora_requests,
+                     use_tqdm=True)
         end = time.perf_counter()
     else:
+        assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
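
LLM.generate accepts either a single LoRARequest or, as above, a list with one entry per prompt, which is what lets each benchmark request carry its own adapter. A minimal standalone sketch under that assumption; the model name, adapter path, and parameters are illustrative:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Illustrative setup: the engine must be built with LoRA support enabled.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=4)
prompts = ["Hello", "Bonjour"]
sampling_params = [SamplingParams(max_tokens=16) for _ in prompts]
lora_requests = [
    LoRARequest(lora_name=str(i), lora_int_id=i, lora_path="/path/to/adapter")
    for i in (1, 2)
]
# One LoRARequest per prompt, matching the benchmark's usage above.
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
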
@@ -185,6 +230,7 @@ async def run_vllm_async(
     # Add the requests to the engine.
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
+    lora_requests: List[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
             TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
             ))
+        lora_requests.append(request.lora_request)

     generators = []
     start = time.perf_counter()
-    for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
-        generator = llm.generate(prompt, sp, request_id=f"test{i}")
+    for i, (prompt, sp,
+            lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
+        generator = llm.generate(prompt,
+                                 sp,
+                                 lora_request=lr,
+                                 request_id=f"test{i}")
         generators.append(generator)
     all_gens = merge_async_iterators(*generators)
     async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
         vocab_size = tokenizer.vocab_size
         requests = []
         for _ in range(args.num_prompts):
+
+            request_tokenizer = tokenizer
+            lora_request: Optional[LoRARequest] = None
+            if args.enable_lora:
+                lora_request, lora_tokenizer = get_random_lora_request(args)
+                if lora_tokenizer:
+                    request_tokenizer = lora_tokenizer
+
             # Synthesize a prompt with the given input length.
             candidate_ids = [
                 random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
             # As tokenizer may add additional tokens like BOS, we need to try
             # different lengths to get the desired input length.
             for _ in range(5):  # Max attempts to correct
-                candidate_prompt = tokenizer.decode(candidate_ids)
-                tokenized_len = len(tokenizer.encode(candidate_prompt))
+                candidate_prompt = request_tokenizer.decode(candidate_ids)
+                tokenized_len = len(request_tokenizer.encode(candidate_prompt))

                 if tokenized_len == args.input_len:
                     break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
-                              expected_output_len=args.output_len))
+                              expected_output_len=args.output_len,
+                              lora_request=lora_request))
     else:
         requests = sample_requests(tokenizer, args)

@@ -422,6 +482,14 @@
                         action='store_true',
                         default=False,
                         help="Disable decoupled async engine frontend.")
+    # LoRA
+    parser.add_argument(
+        "--lora-path",
+        type=str,
+        default=None,
+        help="Path to the lora adapters to use. This can be an absolute path, "
+        "a relative path, or a Hugging Face model identifier.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     if args.tokenizer is None:
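
Together with the engine flags --enable-lora and --max-loras that AsyncEngineArgs.add_cli_args contributes, a LoRA throughput run might look like the following; the model, dataset file, and adapter path are illustrative:

python benchmarks/benchmark_throughput.py \
    --model meta-llama/Llama-2-7b-hf \
    --backend vllm \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 100 \
    --enable-lora \
    --max-loras 4 \
    --lora-path /path/to/adapter
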
@@ -431,6 +499,8 @@
         assert args.output_len is not None
     else:
         assert args.input_len is None
+    if args.enable_lora:
+        assert args.lora_path is not None

     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@
             raise ValueError("HF max batch size is required for HF backend.")
         if args.quantization is not None:
             raise ValueError("Quantization is only for vLLM backend.")
+        if args.enable_lora is not None:
+            raise ValueError("LoRA benchmarking is only supported for vLLM"
+                             " backend")
     elif args.backend == "mii":
         if args.dtype != "auto":
             raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@
         if args.tokenizer != args.model:
             raise ValueError("Tokenizer must be the same as the model for MII "
                              "backend.")
+        if args.enable_lora is not None:
+            raise ValueError("LoRA benchmarking is only supported for vLLM"
+                             " backend")
     main(args)
