
Commit 3b19e39

nunjunj and DarkLight1337 authored

Chat method for offline llm (#5049)

Co-authored-by: nunjunj <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: DarkLight1337 <[email protected]>
1 parent 4cd7d47 commit 3b19e39

File tree

4 files changed: +168 -29 lines changed

.buildkite/test-pipeline.yaml
examples/offline_inference_chat.py
tests/entrypoints/llm/test_generate.py
vllm/entrypoints/llm.py

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@ steps:
   - pip install awscli tensorizer # for llava example and tensorizer test
   - python3 offline_inference.py
   - python3 cpu_offload.py
+  - python3 offline_inference_chat.py
   - python3 offline_inference_with_prefix.py
   - python3 llm_engine_example.py
   - python3 offline_inference_vision_language.py

examples/offline_inference_chat.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+from vllm import LLM, SamplingParams
+
+llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+sampling_params = SamplingParams(temperature=0.5)
+
+
+def print_outputs(outputs):
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print("-" * 80)
+
+
+print("=" * 80)
+
+# In this script, we demonstrate how to pass input to the chat method:
+
+conversation = [
+    {
+        "role": "system",
+        "content": "You are a helpful assistant"
+    },
+    {
+        "role": "user",
+        "content": "Hello"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "Write an essay about the importance of higher education.",
+    },
+]
+outputs = llm.chat(conversation,
+                   sampling_params=sampling_params,
+                   use_tqdm=False)
+print_outputs(outputs)
+
+# A chat template can be optionally supplied.
+# If not, the model will use its default chat template.
+
+# with open('template_falcon_180b.jinja', "r") as f:
+#     chat_template = f.read()
+
+# outputs = llm.chat(
+#     conversation,
+#     sampling_params=sampling_params,
+#     use_tqdm=False,
+#     chat_template=chat_template,
+# )
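The example above stops after a single turn. As an illustrative follow-up (not part of this commit), the conversation could be continued by appending the generated reply and a new user message before calling llm.chat again. The follow-up question is made up, and the snippet assumes the llm, conversation, sampling_params and outputs objects from the example are still in scope.

reply = outputs[0].outputs[0].text

# Append the model's reply plus a new user turn, then ask again.
conversation.append({"role": "assistant", "content": reply})
conversation.append({
    "role": "user",
    "content": "Now summarize that essay in one sentence.",
})

outputs = llm.chat(conversation,
                   sampling_params=sampling_params,
                   use_tqdm=False)
print_outputs(outputs)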

tests/entrypoints/llm/test_generate.py

Lines changed: 19 additions & 0 deletions
@@ -140,3 +140,22 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
+
+
+def test_chat():
+
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
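A possible extension of this test (not part of the commit) would also pin down the sampling behaviour. The sketch below assumes SamplingParams is imported from vllm in this module; the temperature and max_tokens values are arbitrary choices meant only to keep the run short and repeatable.

def test_chat_with_sampling_params():
    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
    # Hypothetical settings chosen only to keep the test cheap and deterministic.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

    messages = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Explain the concept of entropy."},
    ]
    outputs = llm.chat(messages,
                       sampling_params=sampling_params,
                       use_tqdm=False)

    assert len(outputs) == 1
    assert len(outputs[0].outputs[0].text) > 0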

vllm/entrypoints/llm.py

Lines changed: 95 additions & 29 deletions
@@ -6,6 +6,9 @@

 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
+from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         apply_chat_template,
+                                         parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -87,7 +90,7 @@ class LLM:
         disable_custom_all_reduce: See ParallelConfig
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
-    
+
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
@@ -138,8 +141,12 @@ def __init__(

         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        removed_vision_keys = ("image_token_id", "image_feature_size",
-                               "image_input_shape", "image_input_type")
+        removed_vision_keys = (
+            "image_token_id",
+            "image_feature_size",
+            "image_input_shape",
+            "image_input_type",
+        )
         if any(k in kwargs for k in removed_vision_keys):
             raise TypeError(
                 "There is no need to pass vision-related arguments anymore.")
@@ -259,11 +266,12 @@ def generate(
     ) -> List[RequestOutput]:
         ...

-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def generate(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -286,17 +294,17 @@ def generate(
         Args:
             inputs: A list of inputs to generate completions for.
             sampling_params: The sampling parameters for text generation. If
-                None, we use the default sampling parameters. 
-                When it is a single value, it is applied to every prompt. 
-                When it is a list, the list must have the same length as the 
+                None, we use the default sampling parameters.
+                When it is a single value, it is applied to every prompt.
+                When it is a list, the list must have the same length as the
                 prompts and it is paired one by one with the prompt.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            prompt_adapter_request: Prompt Adapter request to use for 
+            prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.

         Returns:
-            A list of `RequestOutput` objects containing the
+            A list of ``RequestOutput`` objects containing the
             generated completions in the same order as the input prompts.

         Note:
@@ -339,6 +347,62 @@ def generate(
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)

+    def chat(
+        self,
+        messages: List[ChatCompletionMessageParam],
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        chat_template: Optional[str] = None,
+        add_generation_template: bool = True,
+    ) -> List[RequestOutput]:
+        """
+        Generates responses for chat messages.
+
+        Converts the messages to prompts using the tokenizer and calls
+        the :meth:`generate` method to generate the responses.
+
+        Args:
+            messages: A list of messages to generate responses for. Each
+                message is a list of dictionaries with 'role' and 'content'
+                keys.
+            sampling_params: The sampling parameters for text generation.
+                If None, we use the default sampling parameters. When it
+                is a single value, it is applied to every prompt. When it
+                is a list, the list must have the same length as the
+                prompts and it is paired one by one with the prompt.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            chat_template: The template to use for structuring the chat.
+                If not provided, the model's default chat template will be used.
+            add_generation_template: If True, adds a generation template
+                to each message.
+
+        Returns:
+            A list of ``RequestOutput`` objects containing the generated
+            responses in the same order as the input messages.
+        """
+
+        tokenizer = self.get_tokenizer()
+        model_config = self.llm_engine.get_model_config()
+
+        conversations, _ = parse_chat_messages(messages, model_config,
+                                               tokenizer)
+
+        prompts = apply_chat_template(
+            tokenizer,
+            conversations,
+            chat_template=chat_template,
+            add_generation_template=add_generation_template)
+
+        return self.generate(
+            prompts,
+            sampling_params,
+            use_tqdm=use_tqdm,
+            lora_request=lora_request,
+        )
+
     @overload  # LEGACY: single (prompt + optional token ids)
     def encode(
         self,
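The helpers parse_chat_messages and apply_chat_template imported above live in vllm.entrypoints.chat_utils and are not shown in this diff. As a rough mental model only, not the actual helper code, the conversion that chat() performs is close to rendering the messages with the Hugging Face tokenizer's own apply_chat_template and handing the resulting string to generate():

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer = llm.get_tokenizer()

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
]

# Render the conversation into a single prompt string using the model's own
# chat template, appending the tokens that cue the assistant's next reply
# (roughly what add_generation_template=True arranges inside chat()).
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

# chat() ultimately hands a prompt like this to generate().
outputs = llm.generate(prompt, SamplingParams(temperature=0.5))
print(outputs[0].outputs[0].text)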
@@ -413,11 +477,12 @@ def encode(
     ) -> List[EmbeddingRequestOutput]:
         ...

-    @deprecate_kwargs("prompts",
-                      "prompt_token_ids",
-                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
-                      additional_message="Please use the 'inputs' parameter "
-                      "instead.")
+    @deprecate_kwargs(
+        "prompts",
+        "prompt_token_ids",
+        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
+        additional_message="Please use the 'inputs' parameter instead.",
+    )
     def encode(
         self,
         prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
@@ -443,7 +508,7 @@ def encode(
                 use the default pooling parameters.
             use_tqdm: Whether to use tqdm to display the progress bar.
             lora_request: LoRA request to use for generation, if any.
-            prompt_adapter_request: Prompt Adapter request to use for 
+            prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.

         Returns:
@@ -563,23 +628,24 @@ def _validate_and_add_requests(
                 params[i] if isinstance(params, Sequence) else params,
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
-                prompt_adapter_request=prompt_adapter_request)
+                prompt_adapter_request=prompt_adapter_request,
+            )

     def _add_request(
-            self,
-            inputs: PromptInputs,
-            params: Union[SamplingParams, PoolingParams],
-            lora_request: Optional[Union[List[LoRARequest],
-                                         LoRARequest]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
             request_id,
             inputs,
             params,
             lora_request=lora_request,
-            prompt_adapter_request=prompt_adapter_request)
+            prompt_adapter_request=prompt_adapter_request,
+        )

     def _add_guided_processor(
         self,
@@ -628,8 +694,8 @@ def _run_engine(
                         in_spd = total_in_toks / pbar.format_dict["elapsed"]
                         total_out_toks += sum(
                             len(stp.token_ids) for stp in output.outputs)
-                        out_spd = total_out_toks / pbar.format_dict[
-                            "elapsed"]
+                        out_spd = (total_out_toks /
+                                   pbar.format_dict["elapsed"])
                         pbar.postfix = (
                             f"est. speed input: {in_spd:.2f} toks/s, "
                             f"output: {out_spd:.2f} toks/s")
