 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         apply_chat_template,
+                                         parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -87,7 +90,7 @@ class LLM:
         disable_custom_all_reduce: See ParallelConfig
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
-
+
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
@@ -138,8 +141,12 @@ def __init__(
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        removed_vision_keys = ("image_token_id", "image_feature_size",
-                               "image_input_shape", "image_input_type")
+        removed_vision_keys = (
+            "image_token_id",
+            "image_feature_size",
+            "image_input_shape",
+            "image_input_type",
+        )
         if any(k in kwargs for k in removed_vision_keys):
             raise TypeError(
                 "There is no need to pass vision-related arguments anymore.")
@@ -259,11 +266,12 @@ def generate(
     ) -> List[RequestOutput]:
         ...

-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def generate(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -286,17 +294,17 @@ def generate(
         Args:
             inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters.
-                When it is a single value, it is applied to every prompt.
-                When it is a list, the list must have the same length as the
+                None, we use the default sampling parameters.
+                When it is a single value, it is applied to every prompt.
+                When it is a list, the list must have the same length as the
                 prompts and it is paired one by one with the prompt.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            prompt_adapter_request: Prompt Adapter request to use for
+            prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.

         Returns:
-            A list of `RequestOutput` objects containing the
+            A list of ``RequestOutput`` objects containing the
             generated completions in the same order as the input prompts.

         Note:
@@ -339,6 +347,62 @@ def generate(
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)

+    def chat(
+        self,
+        messages: List[ChatCompletionMessageParam],
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        chat_template: Optional[str] = None,
+        add_generation_template: bool = True,
+    ) -> List[RequestOutput]:
+        """
+        Generates responses for chat messages.
+
+        Converts the messages to prompts using the tokenizer and calls
+        the :meth:`generate` method to generate the responses.
+
+        Args:
+            messages: A list of messages to generate responses for. Each
+                message is a list of dictionaries with 'role' and 'content'
+                keys.
+            sampling_params: The sampling parameters for text generation.
+                If None, we use the default sampling parameters. When it
+                is a single value, it is applied to every prompt. When it
+                is a list, the list must have the same length as the
+                prompts and it is paired one by one with the prompt.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            chat_template: The template to use for structuring the chat.
+                If not provided, the model's default chat template will be used.
+            add_generation_template: If True, adds a generation template
+                to each message.
+
+        Returns:
+            A list of ``RequestOutput`` objects containing the generated
+            responses in the same order as the input messages.
+        """
+
+        tokenizer = self.get_tokenizer()
+        model_config = self.llm_engine.get_model_config()
+
+        conversations, _ = parse_chat_messages(messages, model_config,
+                                               tokenizer)
+
+        prompts = apply_chat_template(
+            tokenizer,
+            conversations,
+            chat_template=chat_template,
+            add_generation_template=add_generation_template)
+
+        return self.generate(
+            prompts,
+            sampling_params,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+        )
+
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
         self,
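For reference, a minimal usage sketch of the chat() method added in the hunk above. The model name, messages, and sampling values below are illustrative only and are not part of this diff:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")  # any chat-tuned model
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize what a KV cache is in one sentence."},
]
# chat() applies the model's chat template to the messages,
# then delegates to generate() for sampling.
outputs = llm.chat(messages, SamplingParams(temperature=0.7, max_tokens=64))
print(outputs[0].outputs[0].text)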
@@ -413,11 +477,12 @@ def encode(
     ) -> List[EmbeddingRequestOutput]:
         ...

-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def encode(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -443,7 +508,7 @@ def encode(
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            prompt_adapter_request: Prompt Adapter request to use for
+            prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.

         Returns:
@@ -563,23 +628,24 @@ def _validate_and_add_requests(
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
-                prompt_adapter_request=prompt_adapter_request)
+                prompt_adapter_request=prompt_adapter_request,
+            )

     def _add_request(
-            self,
-            inputs: PromptInputs,
-            params: Union[SamplingParams, PoolingParams],
-            lora_request: Optional[Union[List[LoRARequest],
-                                         LoRARequest]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
             request_id,
             inputs,
             params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )

     def _add_guided_processor(
             self,
@@ -628,8 +694,8 @@ def _run_engine(
                             in_spd = total_in_toks / pbar.format_dict["elapsed"]
                             total_out_toks += sum(
                                 len(stp.token_ids) for stp in output.outputs)
-                            out_spd = total_out_toks / pbar.format_dict[
-                                "elapsed"]
+                            out_spd = (total_out_toks /
+                                       pbar.format_dict["elapsed"])
                             pbar.postfix = (
                                 f"est. speed input: {in_spd:.2f} toks/s, "
                                 f"output: {out_spd:.2f} toks/s")