diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst
index 0abc357939e1..4c2c778095ec 100644
--- a/docs/source/getting_started/quickstart.rst
+++ b/docs/source/getting_started/quickstart.rst
@@ -87,6 +87,7 @@ OpenAI-Compatible Server
 ------------------------
 
 vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
+By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_, `create chat completion `_, and `create completion `_ endpoints. We are actively adding support for more endpoints.
 
 Start the server:
 
@@ -95,7 +96,12 @@ Start the server:
 
     $ python -m vllm.entrypoints.openai.api_server \
     $ --model facebook/opt-125m
 
-By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_ and `create completion `_ endpoints. We are actively adding support for more endpoints.
+By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
+
+.. code-block:: console
+
+    $ python -m vllm.entrypoints.openai.api_server \
+    --model facebook/opt-125m \
+    --chat-template ./examples/template_chatml.json
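+
+If the argument is not a readable file, the server treats the argument itself as the template and decodes escaped ``\n`` sequences, so a single-line Jinja string also works (the template below is only an illustration, not one shipped with vLLM):
+
+.. code-block:: console
+
+    $ python -m vllm.entrypoints.openai.api_server \
+    --model facebook/opt-125m \
+    --chat-template "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"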
 
 This server can be queried in the same format as OpenAI API. For example, list the models:
 
 .. code-block:: console
@@ -103,6 +109,9 @@ This server can be queried in the same format as OpenAI API. For example, list t
 
     $ curl http://localhost:8000/v1/models
 
+Using OpenAI Completions API with vLLM
+--------------------------------------
+
 Query the model with input prompts:
 
 .. code-block:: console
@@ -129,3 +138,45 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep
     print("Completion result:", completion)
 
 For a more detailed client example, refer to `examples/openai_completion_client.py `_.
+
+Using OpenAI Chat API with vLLM
+-------------------------------
+
+The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations.
+
+Querying the model using OpenAI Chat API:
+
+You can use the `create chat completion `_ endpoint to communicate with the model in a chat-like interface:
+
+.. code-block:: console
+
+    $ curl http://localhost:8000/v1/chat/completions \
+    $ -H "Content-Type: application/json" \
+    $ -d '{
+    $ "model": "facebook/opt-125m",
+    $ "messages": [
+    $ {"role": "system", "content": "You are a helpful assistant."},
+    $ {"role": "user", "content": "Who won the world series in 2020?"}
+    $ ]
+    $ }'
+
+Python Client Example:
+
+Using the ``openai`` Python package, you can also communicate with the model in a chat-like manner:
+
+.. code-block:: python
+
+    import openai
+    # Set OpenAI's API key and API base to use vLLM's API server.
+    openai.api_key = "EMPTY"
+    openai.api_base = "http://localhost:8000/v1"
+    chat_response = openai.ChatCompletion.create(
+        model="facebook/opt-125m",
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Tell me a joke."},
+        ]
+    )
+    print("Chat response:", chat_response)
+
+For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation.
diff --git a/examples/template_chatml.json b/examples/template_chatml.json
new file mode 100644
index 000000000000..c8516b1225f0
--- /dev/null
+++ b/examples/template_chatml.json
@@ -0,0 +1,3 @@
+{
+    "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+}
\ No newline at end of file
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f336b4656555..e3bb20116e5d 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -3,6 +3,7 @@
 
 import argparse
 import asyncio
+import codecs
 import json
 import time
 from http import HTTPStatus
@@ -14,7 +15,6 @@
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse, Response
-from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -31,20 +31,13 @@
 from vllm.transformers_utils.tokenizer import get_tokenizer
 from vllm.utils import random_uuid
 
-try:
-    import fastchat
-    from fastchat.conversation import Conversation, SeparatorStyle
-    from fastchat.model.model_adapter import get_conversation_template
-    _fastchat_available = True
-except ImportError:
-    _fastchat_available = False
-
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
 logger = init_logger(__name__)
 served_model = None
 app = fastapi.FastAPI()
 engine = None
+chat_template = None
 
 
 def create_error_response(status_code: HTTPStatus,
@@ -70,50 +63,13 @@ async def check_model(request) -> Optional[JSONResponse]:
 
 
 async def get_gen_prompt(request) -> str:
-    if not _fastchat_available:
-        raise ModuleNotFoundError(
-            "fastchat is not installed. Please install fastchat to use "
-            "the chat completion and conversation APIs: `$ pip install fschat`"
-        )
-    if version.parse(fastchat.__version__) < version.parse("0.2.23"):
-        raise ImportError(
-            f"fastchat version is low. Current version: {fastchat.__version__} "
-            "Please upgrade fastchat to use: `$ pip install -U fschat`")
-
-    conv = get_conversation_template(request.model)
-    conv = Conversation(
-        name=conv.name,
-        system_template=conv.system_template,
-        system_message=conv.system_message,
-        roles=conv.roles,
-        messages=list(conv.messages),  # prevent in-place modification
-        offset=conv.offset,
-        sep_style=SeparatorStyle(conv.sep_style),
-        sep=conv.sep,
-        sep2=conv.sep2,
-        stop_str=conv.stop_str,
-        stop_token_ids=conv.stop_token_ids,
-    )
-
-    if isinstance(request.messages, str):
-        prompt = request.messages
-    else:
-        for message in request.messages:
-            msg_role = message["role"]
-            if msg_role == "system":
-                conv.system_message = message["content"]
-            elif msg_role == "user":
-                conv.append_message(conv.roles[0], message["content"])
-            elif msg_role == "assistant":
-                conv.append_message(conv.roles[1], message["content"])
-            else:
-                raise ValueError(f"Unknown role: {msg_role}")
-
-        # Add a blank message for the assistant.
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-    return prompt
+    try:
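+        # Render the OpenAI-style message list into a single prompt string
+        # using the tokenizer's chat template (either the model's built-in
+        # template or one supplied via --chat-template).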
+        return tokenizer.apply_chat_template(
+            conversation=request.messages,
+            tokenize=False,
+            add_generation_prompt=request.add_generation_prompt)
+    except Exception as e:
+        raise RuntimeError(f"Error generating prompt: {str(e)}") from e
 
 
 async def check_length(
@@ -209,7 +165,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
         return create_error_response(HTTPStatus.BAD_REQUEST,
                                      "logit_bias is not currently supported")
 
-    prompt = await get_gen_prompt(request)
+    try:
+        prompt = await get_gen_prompt(request)
+    except RuntimeError as e:
+        logger.error(f"Error in generating prompt from request: {str(e)}")
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
     token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
@@ -217,6 +178,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
    created_time = int(time.monotonic())
+    obj_str = "chat.completion.chunk"
     try:
         spaces_between_special_tokens = request.spaces_between_special_tokens
         sampling_params = SamplingParams(
@@ -241,61 +203,83 @@ async def create_chat_completion(request: ChatCompletionRequest,
     result_generator = engine.generate(prompt, sampling_params, request_id,
                                        token_ids)
 
-    def create_stream_response_json(
-        index: int,
-        text: str,
-        finish_reason: Optional[str] = None,
-    ) -> str:
-        choice_data = ChatCompletionResponseStreamChoice(
-            index=index,
-            delta=DeltaMessage(content=text),
-            finish_reason=finish_reason,
-        )
-        response = ChatCompletionStreamResponse(
-            id=request_id,
-            created=created_time,
-            model=model_name,
-            choices=[choice_data],
-        )
-        response_json = response.json(ensure_ascii=False)
-
-        return response_json
+    def get_role() -> str:
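+        # When add_generation_prompt is requested, the template appends a
+        # fresh assistant turn, so the generated text speaks as the assistant;
+        # otherwise the completion continues the last message and keeps its
+        # role.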
f"data: {data}\n\n" + # Handle echoing of the first input + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + previous_texts = [""] * request.n previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n async for res in result_generator: res: RequestOutput for output in res.outputs: i = output.index - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - response_json = create_stream_response_json( - index=i, - text=delta_text, - ) - yield f"data: {response_json}\n\n" - if output.finish_reason is not None: - response_json = create_stream_response_json( + if output.finish_reason is None and not finish_reason_sent[i]: + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + choice_data = ChatCompletionResponseStreamChoice( index=i, - text="", - finish_reason=output.finish_reason, - ) - yield f"data: {response_json}\n\n" + delta=DeltaMessage(content=delta_text), + finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + if output.finish_reason is not None and not finish_reason_sent[ + i]: + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=[], finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, + exclude_none=True, + ensure_ascii=False) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True yield "data: [DONE]\n\n" # Streaming response @@ -314,14 +298,25 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: final_res = res assert final_res is not None choices = [] + role = get_role() for output in final_res.outputs: choice_data = ChatCompletionResponseChoice( index=output.index, - message=ChatMessage(role="assistant", content=output.text), + message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason, ) choices.append(choice_data) + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get("role") == role: + last_msg_content = request.messages[-1]["content"] + + for choice in choices: + choice.message.content = last_msg_content + choice.message.content + num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) @@ -600,6 +595,12 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: help="The model name used in the API. 
+        last_msg_content = ""
+        if request.messages and isinstance(
+                request.messages, list) and request.messages[-1].get(
+                    "content") and request.messages[-1].get("role") == role:
+            last_msg_content = request.messages[-1]["content"]
+
+        for choice in choices:
+            choice.message.content = last_msg_content + choice.message.content
+
     num_prompt_tokens = len(final_res.prompt_token_ids)
     num_generated_tokens = sum(
         len(output.token_ids) for output in final_res.outputs)
@@ -600,6 +595,12 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
                         help="The model name used in the API. If not "
                         "specified, the model name will be the same as "
                         "the huggingface name.")
+    parser.add_argument("--chat-template",
+                        type=str,
+                        default=None,
+                        help="The file path to the chat template, "
+                        "or the template in single-line form "
+                        "for the specified model")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
     args = parser.parse_args()
@@ -619,6 +620,28 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
     else:
         served_model = args.model
 
+    if args.chat_template is not None:
+        try:
+            with open(args.chat_template, "r") as f:
+                content = f.read()
+            try:
+                json_data = json.loads(content)
+                if "chat_template" in json_data:
+                    chat_template = json_data["chat_template"]
+                else:
+                    chat_template = content
+            except json.JSONDecodeError:
+                # If JSON parsing fails, use the file content as raw text
+                chat_template = content
+        except OSError:
+            try:
+                # If the argument is not a readable file, treat it as the
+                # template itself and decode it so escape sequences such as
+                # \n are interpreted correctly
+                chat_template = codecs.decode(args.chat_template,
+                                              "unicode_escape")
+            except UnicodeDecodeError:
+                logger.error("Unable to set template.")
+
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     engine_model_config = asyncio.run(engine.get_model_config())
@@ -629,6 +652,16 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
         tokenizer_mode=engine_args.tokenizer_mode,
         trust_remote_code=engine_args.trust_remote_code)
 
+    if chat_template is not None:
+        tokenizer.chat_template = chat_template
+
+    tmp_template = tokenizer.chat_template or tokenizer.default_chat_template
+    if tmp_template:
+        logger.info(f"Chat template:\n{tmp_template}")
+    else:
+        logger.warning(
+            "No chat template loaded, the chat endpoint will not work.")
+
     uvicorn.run(app,
                 host=args.host,
                 port=args.port,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7700c5dd483e..6c97ce87aa67 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel):
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     skip_special_tokens: Optional[bool] = True
     spaces_between_special_tokens: Optional[bool] = True
+    add_generation_prompt: Optional[bool] = True
+    echo: Optional[bool] = False
 
 
 class CompletionRequest(BaseModel):
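
Illustration (not part of the patch): a minimal sketch of what the new get_gen_prompt path produces when the ChatML template from examples/template_chatml.json is applied, using the Hugging Face apply_chat_template API that the server now relies on. It assumes a transformers release recent enough to provide chat templates; the model name and messages are placeholders only.

    from transformers import AutoTokenizer

    # Placeholder model; any Hugging Face tokenizer behaves the same way here.
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # The same template string that examples/template_chatml.json carries.
    tokenizer.chat_template = (
        "{% if not add_generation_prompt is defined %}"
        "{% set add_generation_prompt = false %}{% endif %}"
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content']"
        " + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}")

    prompt = tokenizer.apply_chat_template(
        conversation=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Who won the world series in 2020?"},
        ],
        tokenize=False,
        add_generation_prompt=True)
    print(prompt)
    # <|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Who won the world series in 2020?<|im_end|>
    # <|im_start|>assistant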