@@ -7,7 +7,7 @@
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 import uvloop
@@ -20,7 +20,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -178,10 +178,13 @@ def run_vllm(
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -242,11 +245,14 @@ async def run_vllm_async(
             " prompt_len and expected_output_len for all requests.")
 
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     lora_requests: list[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
             random.randint(0, vocab_size - 1)
             for _ in range(args.input_len)
         ]
-        # As tokenizer may add additional tokens like BOS, we need to try
-        # different lengths to get the desired input length.
-        for _ in range(5):  # Max attempts to correct
-            candidate_prompt = request_tokenizer.decode(candidate_ids)
-            tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-            if tokenized_len == args.input_len:
-                break
-
-            # Adjust length based on difference
-            diff = args.input_len - tokenized_len
-            if diff > 0:
-                candidate_ids.extend([
-                    random.randint(100, vocab_size - 100)
-                    for _ in range(diff)
-                ])
-            else:
-                candidate_ids = candidate_ids[:diff]
+
+        candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+        if not args.skip_tokenizer_init:
+            # As tokenizer may add additional tokens like BOS, we need
+            # to try different lengths to get the desired input length.
+            for _ in range(5):  # Max attempts to correct
+                candidate_prompt = request_tokenizer.decode(candidate_ids)
+                tokenized_len = len(
+                    request_tokenizer.encode(candidate_prompt))
+
+                if tokenized_len == args.input_len:
+                    break
+
+                # Adjust length based on difference
+                diff = args.input_len - tokenized_len
+                if diff > 0:
+                    candidate_ids.extend([
+                        random.randint(100, vocab_size - 100)
+                        for _ in range(diff)
+                    ])
+                else:
+                    candidate_ids = candidate_ids[:diff]
         requests.append(
             SampleRequest(prompt=candidate_prompt,
                           prompt_len=args.input_len,
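After this hunk, candidate_prompt is either a token-ID dict (when --skip-tokenizer-init is set) or a decoded string, so the "prompt_token_ids" membership checks in run_vllm and run_vllm_async distinguish the two cases. A quick sketch of that behavior, in plain Python with no vLLM required:

# Membership behaves differently for the two prompt shapes built above:
token_prompt = {"prompt_token_ids": [101, 2023, 102]}  # skip-tokenizer-init path
text_prompt = "a decoded candidate prompt"             # default path

assert "prompt_token_ids" in token_prompt      # dict: key lookup
assert "prompt_token_ids" not in text_prompt   # str: substring search

Assuming the engine flags are exposed through EngineArgs.add_cli_args as elsewhere in this script, the new path would be exercised with something like (model name is only a placeholder):

python benchmarks/benchmark_throughput.py \
    --model facebook/opt-125m \
    --input-len 128 --output-len 128 \
    --skip-tokenizer-init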