40 changes: 39 additions & 1 deletion vllm/entrypoints/llm.py
@@ -1,5 +1,6 @@
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
from typing import (ClassVar, Dict, List, Optional, Sequence, Union, cast,
overload)

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -173,6 +174,35 @@ def set_tokenizer(
self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
tokenizer)

def apply_chat_template(
self,
messages_list: Union[List[List[Dict[str, str]]], List[Dict[str, str]]],
add_generation_prompt: bool = True,
tokenize: bool = False
    ) -> Union[List[List[str]], List[str], List[List[int]], List[int]]:
"""Applies a chat template to the given messages.

Args:
            messages_list: The conversations to process: either a list of
                conversations (each a list of message dicts) or a single
                conversation given as a list of message dicts.
            add_generation_prompt: Whether to add a generation prompt at the
                end of each conversation.
            tokenize: Whether to return token IDs (True) or raw text (False).

Returns:
            A list of token-ID lists if tokenize is True, otherwise a list of
            formatted text strings.
"""
tokenizer = self.get_tokenizer()
ids_or_text = [
tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=tokenize) for messages in messages_list
]
return ids_or_text

@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
@@ -259,6 +289,8 @@ def generate(
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
messages_list: Optional[Union[List[List[Dict[str, str]]],
List[Dict[str, str]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -295,6 +327,12 @@ def generate(
"LLM.generate() is only supported for generation models "
"(XForCausalLM).")

# Use apply_chat_template if prompt_token_ids is not provided
if prompt_token_ids is None and messages_list is not None:
prompt_token_ids = cast(
Union[List[int], List[List[int]]],
self.apply_chat_template(messages_list, tokenize=True))

if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
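
For reference, a minimal usage sketch of the API added in this diff. The model name, sampling settings, and message contents are illustrative assumptions, and it assumes the chosen model's tokenizer defines a chat template:

from vllm import LLM, SamplingParams

# Illustrative model and sampling settings; any chat-tuned model whose
# tokenizer ships a chat template should work.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")
sampling_params = SamplingParams(temperature=0.7, max_tokens=128)

# Two independent conversations, each a list of role/content dicts.
messages_list = [
    [{"role": "user", "content": "What is the capital of France?"}],
    [{"role": "user", "content": "Write a haiku about the sea."}],
]

# The new helper returns the chat-formatted prompts (text by default,
# token IDs with tokenize=True).
formatted = llm.apply_chat_template(messages_list)
print(formatted[0])

# generate() now accepts messages_list directly: when prompt_token_ids is
# not supplied, it applies the chat template and tokenizes each conversation.
outputs = llm.generate(messages_list=messages_list,
                       sampling_params=sampling_params)
for output in outputs:
    print(output.outputs[0].text)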