 import json
 import random
 import time
-from typing import List, Optional
+from functools import cache
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import uvloop
@@ -17,8 +18,11 @@
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
 from vllm.inputs import TextPrompt
+from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
 from vllm.sampling_params import BeamSearchParams
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
@@ -28,15 +32,17 @@ class SampleRequest:
 
     Attributes:
         prompt: The input text prompt for the model.
-        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
-            images).
         prompt_len: The length of the prompt in tokens.
         expected_output_len: The expected length of the output in tokens.
+        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+            images).
+        lora_request: Optional LoRARequest specifying the LoRA to use.
     """
     prompt: str
     prompt_len: int
     expected_output_len: int
     multi_modal_data: Optional[MultiModalDataDict] = None
+    lora_request: Optional[LoRARequest] = None
 
 
 def _get_prompt_for_image_model(question: str, *, model: str) -> str:
@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
     raise ValueError(f"Unsupported model {model}")
 
 
+@cache
+def lora_path_on_disk(lora_path: str) -> str:
+    return get_adapter_absolute_path(lora_path)
+
+
+lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
+
+
+def get_random_lora_request(
+        args: argparse.Namespace
+) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
+    global lora_tokenizer_cache
+    lora_id = random.randint(1, args.max_loras)
+    lora_request = LoRARequest(lora_name=str(lora_id),
+                               lora_int_id=lora_id,
+                               lora_path=lora_path_on_disk(args.lora_path))
+    if lora_id not in lora_tokenizer_cache:
+        lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
+    return lora_request, lora_tokenizer_cache[lora_id]
+
+
 def sample_requests(tokenizer: PreTrainedTokenizerBase,
                     args: argparse.Namespace) -> List[SampleRequest]:
+
     dataset_path: str = args.dataset
     num_requests: int = args.num_prompts
     fixed_output_len: Optional[int] = args.output_len
@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
 
     # Filter out sequences that are too long or too short
     filtered_dataset: List[SampleRequest] = []
-    for data in dataset:
+    for data in tqdm(dataset,
+                     total=len(filtered_dataset),
+                     desc="sampling requests"):
         if len(filtered_dataset) == num_requests:
             break
 
@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
                 continue
             prompt = _get_prompt_for_image_model(question=prompt, model=model)
 
+        request_tokenizer = tokenizer
+        lora_request: Optional[LoRARequest] = None
+        if args.enable_lora:
+            lora_request, lora_tokenizer = get_random_lora_request(args)
+            if lora_tokenizer:
+                request_tokenizer = lora_tokenizer
+
         # Tokenize the prompts and completions.
-        prompt_token_ids = tokenizer(prompt).input_ids
-        completion_token_ids = tokenizer(completion).input_ids
+        prompt_token_ids = request_tokenizer(prompt).input_ids
+        completion_token_ids = request_tokenizer(completion).input_ids
         prompt_len = len(prompt_token_ids)
         output_len = len(completion_token_ids
                          ) if fixed_output_len is None else fixed_output_len
@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
             SampleRequest(prompt=prompt,
                           prompt_len=prompt_len,
                           expected_output_len=output_len,
-                          multi_modal_data=multi_modal_data))
+                          multi_modal_data=multi_modal_data,
+                          lora_request=lora_request))
 
     return filtered_dataset
 
@@ -146,14 +184,21 @@ def run_vllm(
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
             ))
+    lora_requests: Optional[List[LoRARequest]] = None
+    if engine_args.enable_lora:
+        lora_requests = [request.lora_request for request in requests]
 
     use_beam_search = False
 
     if not use_beam_search:
         start = time.perf_counter()
-        llm.generate(prompts, sampling_params, use_tqdm=True)
+        llm.generate(prompts,
+                     sampling_params,
+                     lora_request=lora_requests,
+                     use_tqdm=True)
         end = time.perf_counter()
     else:
+        assert lora_requests is None, "BeamSearch API does not support LoRA"
         prompts = [request.prompt for request in requests]
         # output_len should be the same for all requests.
         output_len = requests[0][2]
@@ -185,6 +230,7 @@ async def run_vllm_async(
     # Add the requests to the engine.
     prompts: List[TextPrompt] = []
     sampling_params: List[SamplingParams] = []
+    lora_requests: List[Optional[LoRARequest]] = []
    for request in requests:
         prompts.append(
             TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
                 ignore_eos=True,
                 max_tokens=request.expected_output_len,
             ))
+        lora_requests.append(request.lora_request)
 
     generators = []
     start = time.perf_counter()
-    for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
-        generator = llm.generate(prompt, sp, request_id=f"test{i}")
+    for i, (prompt, sp,
+            lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
+        generator = llm.generate(prompt,
+                                 sp,
+                                 lora_request=lr,
+                                 request_id=f"test{i}")
         generators.append(generator)
     all_gens = merge_async_iterators(*generators)
     async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
         vocab_size = tokenizer.vocab_size
         requests = []
         for _ in range(args.num_prompts):
+
+            request_tokenizer = tokenizer
+            lora_request: Optional[LoRARequest] = None
+            if args.enable_lora:
+                lora_request, lora_tokenizer = get_random_lora_request(args)
+                if lora_tokenizer:
+                    request_tokenizer = lora_tokenizer
+
             # Synthesize a prompt with the given input length.
             candidate_ids = [
                 random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
             # As tokenizer may add additional tokens like BOS, we need to try
             # different lengths to get the desired input length.
             for _ in range(5):  # Max attempts to correct
-                candidate_prompt = tokenizer.decode(candidate_ids)
-                tokenized_len = len(tokenizer.encode(candidate_prompt))
+                candidate_prompt = request_tokenizer.decode(candidate_ids)
+                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
 
                 if tokenized_len == args.input_len:
                     break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
-                              expected_output_len=args.output_len))
+                              expected_output_len=args.output_len,
+                              lora_request=lora_request))
     else:
         requests = sample_requests(tokenizer, args)
 
@@ -422,6 +482,14 @@ def main(args: argparse.Namespace):
                         action='store_true',
                         default=False,
                         help="Disable decoupled async engine frontend.")
+    # LoRA
+    parser.add_argument(
+        "--lora-path",
+        type=str,
+        default=None,
+        help="Path to the lora adapters to use. This can be an absolute path, "
+        "a relative path, or a Hugging Face model identifier.")
+
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
     if args.tokenizer is None:
@@ -431,6 +499,8 @@ def main(args: argparse.Namespace):
         assert args.output_len is not None
     else:
         assert args.input_len is None
+    if args.enable_lora:
+        assert args.lora_path is not None
 
     if args.backend == "vllm":
         if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@ def main(args: argparse.Namespace):
             raise ValueError("HF max batch size is required for HF backend.")
         if args.quantization is not None:
             raise ValueError("Quantization is only for vLLM backend.")
+        if args.enable_lora is not None:
+            raise ValueError("LoRA benchmarking is only supported for vLLM"
+                             " backend")
     elif args.backend == "mii":
         if args.dtype != "auto":
             raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@ def main(args: argparse.Namespace):
         if args.tokenizer != args.model:
             raise ValueError("Tokenizer must be the same as the model for MII "
                              "backend.")
+        if args.enable_lora is not None:
+            raise ValueError("LoRA benchmarking is only supported for vLLM"
+                             " backend")
     main(args)
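
For reference (not part of the patch itself), a minimal offline sketch of the per-request LoRA flow this benchmark exercises, using the same `LLM.generate(..., lora_request=...)` call the diff adds. The base model id and adapter path below are placeholders, not values taken from this PR.

```python
# Hedged sketch: replace the placeholder model id and adapter path before running.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA on the engine, as the benchmark does via --enable-lora / --max-loras.
llm = LLM(model="<base-model-id>",  # placeholder
          enable_lora=True,
          max_loras=4)

# One adapter, analogous to what get_random_lora_request() builds per request.
lora = LoRARequest(lora_name="1",
                   lora_int_id=1,
                   lora_path="<adapter-path-or-hf-id>")  # placeholder

outputs = llm.generate(["Hello, world"],
                       SamplingParams(max_tokens=32),
                       lora_request=lora)
```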