From f641c76792ab74bc22fd73976347aba2cea38e77 Mon Sep 17 00:00:00 2001 From: Adam Brusselback Date: Fri, 27 Oct 2023 23:11:21 -0400 Subject: [PATCH 1/4] Add support for HF chat templates to OpenAI chat completions API. Add better documentation as well. --- docs/source/getting_started/quickstart.rst | 53 ++++++++++++- examples/template_chatml.json | 3 + vllm/entrypoints/openai/api_server.py | 89 +++++++++------------- 3 files changed, 92 insertions(+), 53 deletions(-) create mode 100644 examples/template_chatml.json diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 0abc357939e1..4c2c778095ec 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -87,6 +87,7 @@ OpenAI-Compatible Server ------------------------ vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API. +By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_, `create chat completion `_, and `create completion `_ endpoints. We are actively adding support for more endpoints. Start the server: @@ -95,7 +96,12 @@ Start the server: $ python -m vllm.entrypoints.openai.api_server \ $ --model facebook/opt-125m -By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models `_ and `create completion `_ endpoints. We are actively adding support for more endpoints. +By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument: +.. code-block:: console + + $ python -m vllm.entrypoints.openai.api_server \ + --model facebook/opt-125m \ + --chat-template ./examples/template_chatml.json This server can be queried in the same format as OpenAI API. For example, list the models: @@ -103,6 +109,9 @@ This server can be queried in the same format as OpenAI API. For example, list t $ curl http://localhost:8000/v1/models +Using OpenAI Completions API with vLLM +------------------------------- + Query the model with input prompts: .. code-block:: console @@ -129,3 +138,45 @@ Since this server is compatible with OpenAI API, you can use it as a drop-in rep print("Completion result:", completion) For a more detailed client example, refer to `examples/openai_completion_client.py `_. + +Using OpenAI Chat API with vLLM +------------------------------- + +The vLLM server is designed to support the OpenAI Chat API, allowing you to engage in dynamic conversations with the model. The chat interface is a more interactive way to communicate with the model, allowing back-and-forth exchanges that can be stored in the chat history. This is useful for tasks that require context or more detailed explanations. + +Querying the model using OpenAI Chat API: + +You can use the `create chat completion `_ endpoint to communicate with the model in a chat-like interface: + +.. code-block:: console + + $ curl http://localhost:8000/v1/chat/completions \ + $ -H "Content-Type: application/json" \ + $ -d '{ + $ "model": "facebook/opt-125m", + $ "messages": [ + $ {"role": "system", "content": "You are a helpful assistant."}, + $ {"role": "user", "content": "Who won the world series in 2020?"}, + $ ] + $ }' + +Python Client Example: + +Using the `openai` python package, you can also communicate with the model in a chat-like manner: + +.. code-block:: python + + import openai + # Set OpenAI's API key and API base to use vLLM's API server. + openai.api_key = "EMPTY" + openai.api_base = "http://localhost:8000/v1" + chat_response = openai.ChatCompletion.create( + model="facebook/opt-125m", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Tell me a joke."}, + ] + ) + print("Chat response:", chat_response) + +For more in-depth examples and advanced features of the chat API, you can refer to the official OpenAI documentation. diff --git a/examples/template_chatml.json b/examples/template_chatml.json new file mode 100644 index 000000000000..c8516b1225f0 --- /dev/null +++ b/examples/template_chatml.json @@ -0,0 +1,3 @@ +{ + "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +} \ No newline at end of file diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 80d6f271fb39..63a34d348de5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,7 +14,6 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse -from packaging import version from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -31,20 +30,13 @@ from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import random_uuid -try: - import fastchat - from fastchat.conversation import Conversation, SeparatorStyle - from fastchat.model.model_adapter import get_conversation_template - _fastchat_available = True -except ImportError: - _fastchat_available = False - TIMEOUT_KEEP_ALIVE = 5 # seconds logger = init_logger(__name__) served_model = None app = fastapi.FastAPI() engine = None +chat_template = None def create_error_response(status_code: HTTPStatus, @@ -70,50 +62,17 @@ async def check_model(request) -> Optional[JSONResponse]: async def get_gen_prompt(request) -> str: - if not _fastchat_available: - raise ModuleNotFoundError( - "fastchat is not installed. Please install fastchat to use " - "the chat completion and conversation APIs: `$ pip install fschat`" - ) - if version.parse(fastchat.__version__) < version.parse("0.2.23"): - raise ImportError( - f"fastchat version is low. Current version: {fastchat.__version__} " - "Please upgrade fastchat to use: `$ pip install -U fschat`") - - conv = get_conversation_template(request.model) - conv = Conversation( - name=conv.name, - system_template=conv.system_template, - system_message=conv.system_message, - roles=conv.roles, - messages=list(conv.messages), # prevent in-place modification - offset=conv.offset, - sep_style=SeparatorStyle(conv.sep_style), - sep=conv.sep, - sep2=conv.sep2, - stop_str=conv.stop_str, - stop_token_ids=conv.stop_token_ids, - ) - - if isinstance(request.messages, str): - prompt = request.messages + if chat_template is not None: + return tokenizer.apply_chat_template(conversation=request.messages, + chat_template=chat_template, + tokenize=False) + elif tokenizer.chat_template is not None: + return tokenizer.apply_chat_template(conversation=request.messages, + tokenize=False) else: - for message in request.messages: - msg_role = message["role"] - if msg_role == "system": - conv.system_message = message["content"] - elif msg_role == "user": - conv.append_message(conv.roles[0], message["content"]) - elif msg_role == "assistant": - conv.append_message(conv.roles[1], message["content"]) - else: - raise ValueError(f"Unknown role: {msg_role}") - - # Add a blank message for the assistant. - conv.append_message(conv.roles[1], None) - prompt = conv.get_prompt() - - return prompt + raise ValueError("No chat template defined. Please use a tokenizer " + "that includes a chat template, or pass in " + "a jinja template using the --chat-template flag.") async def check_length( @@ -590,6 +549,11 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: help="The model name used in the API. If not " "specified, the model name will be the same as " "the huggingface name.") + parser.add_argument("--chat-template", + type=str, + default=None, + help="The path to the chat template to use " + "with the specified model.") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -609,6 +573,20 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: else: served_model = args.model + if args.chat_template is not None: + with open(args.chat_template, "r") as f: + content = f.read() + try: + # Try to parse as JSON and if chat_template exists, use value + data = json.loads(content) + if "chat_template" in data: + chat_template = data["chat_template"] + else: + chat_template = content + except json.JSONDecodeError: + # If parsing as JSON fails, use the file content as raw text + chat_template = content + engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) engine_model_config = asyncio.run(engine.get_model_config()) @@ -619,6 +597,13 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: tokenizer_mode=engine_args.tokenizer_mode, trust_remote_code=engine_args.trust_remote_code) + if chat_template or tokenizer.chat_template: + logger.info( + f"Chat template:\n{chat_template or tokenizer.chat_template}") + else: + logger.warning( + "No chat template loaded, the chat endpoint will be disabled.") + uvicorn.run(app, host=args.host, port=args.port, From a2f6df12e04706a2dae3841f623992672bb1a553 Mon Sep 17 00:00:00 2001 From: Adam Brusselback Date: Fri, 27 Oct 2023 23:45:01 -0400 Subject: [PATCH 2/4] Change variable name to fix pylint error. --- vllm/entrypoints/openai/api_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 63a34d348de5..b5bd3d2609b1 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -578,9 +578,9 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: content = f.read() try: # Try to parse as JSON and if chat_template exists, use value - data = json.loads(content) - if "chat_template" in data: - chat_template = data["chat_template"] + json_data = json.loads(content) + if "chat_template" in json_data: + chat_template = json_data["chat_template"] else: chat_template = content except json.JSONDecodeError: From 9ca35c12f07e72cbb9ca8bc0cb4636ff31677ad5 Mon Sep 17 00:00:00 2001 From: Adam Brusselback Date: Tue, 14 Nov 2023 05:28:20 -0500 Subject: [PATCH 3/4] Added 'add_generation_prompt' Request Parameter (Default: True): This parameter controls whether the prompt ends with tokens indicating the start of an assistant message. When set to False (and the template/model supports it) the model should complete the last response in the list. By default, it maintains compatibility with OpenAI's API behavior. Fixed Role Determination in Responses: Resolved an issue where the role in responses defaulted to "assistant" regardless of context. This fix ensures the response role aligns with the intended conversational participant, enhancing the API's versatility in various chat scenarios. Introduced 'return_full_response' Request Parameter (Default: False): This parameter, when set to True, negates the need for client-side response merging. It simplifies client integration by providing complete responses even when the client had started the response for the model. --- vllm/entrypoints/openai/api_server.py | 49 ++++++++++++++++++++------- vllm/entrypoints/openai/protocol.py | 2 ++ 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7f17178ed7f2..d34cceaf2093 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,7 +14,6 @@ from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse, Response -from packaging import version from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -64,12 +63,16 @@ async def check_model(request) -> Optional[JSONResponse]: async def get_gen_prompt(request) -> str: if chat_template is not None: - return tokenizer.apply_chat_template(conversation=request.messages, - chat_template=chat_template, - tokenize=False) + return tokenizer.apply_chat_template( + conversation=request.messages, + chat_template=chat_template, + tokenize=False, + add_generation_prompt=request.add_generation_prompt) elif tokenizer.chat_template is not None: - return tokenizer.apply_chat_template(conversation=request.messages, - tokenize=False) + return tokenizer.apply_chat_template( + conversation=request.messages, + tokenize=False, + add_generation_prompt=request.add_generation_prompt) else: raise ValueError("No chat template defined. Please use a tokenizer " "that includes a chat template, or pass in " @@ -201,14 +204,20 @@ async def create_chat_completion(request: ChatCompletionRequest, result_generator = engine.generate(prompt, sampling_params, request_id, token_ids) + def get_role() -> str: + if request.add_generation_prompt: + return "assistant" + else: + return request.messages[-1]["role"] + def create_stream_response_json( - index: int, - text: str, - finish_reason: Optional[str] = None, - ) -> str: + index: int, + text: str, + role: str, + finish_reason: Optional[str] = None) -> str: choice_data = ChatCompletionResponseStreamChoice( index=index, - delta=DeltaMessage(content=text), + delta=DeltaMessage(role=role, content=text), finish_reason=finish_reason, ) response = ChatCompletionStreamResponse( @@ -223,10 +232,11 @@ def create_stream_response_json( async def completion_stream_generator() -> AsyncGenerator[str, None]: # First chunk with role + role = get_role() for i in range(request.n): choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(role="assistant"), + delta=DeltaMessage(role=role), finish_reason=None, ) chunk = ChatCompletionStreamResponse(id=request_id, @@ -246,12 +256,14 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: previous_num_tokens[i] = len(output.token_ids) response_json = create_stream_response_json( index=i, + role=role, text=delta_text, ) yield f"data: {response_json}\n\n" if output.finish_reason is not None: response_json = create_stream_response_json( index=i, + role=role, text="", finish_reason=output.finish_reason, ) @@ -274,14 +286,25 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: final_res = res assert final_res is not None choices = [] + role = get_role() for output in final_res.outputs: choice_data = ChatCompletionResponseChoice( index=output.index, - message=ChatMessage(role="assistant", content=output.text), + message=ChatMessage(role=role, content=output.text), finish_reason=output.finish_reason, ) choices.append(choice_data) + if request.return_full_response: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get("role") == role: + last_msg_content = request.messages[-1]["content"] + + for choice in choices: + choice.message.content = last_msg_content + choice.message.content + num_prompt_tokens = len(final_res.prompt_token_ids) num_generated_tokens = sum( len(output.token_ids) for output in final_res.outputs) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7700c5dd483e..8234d3b45abf 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -73,6 +73,8 @@ class ChatCompletionRequest(BaseModel): stop_token_ids: Optional[List[int]] = Field(default_factory=list) skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True + add_generation_prompt: Optional[bool] = True + return_full_response: Optional[bool] = False class CompletionRequest(BaseModel): From f159a628634501ac7241a5ccaacac1570141163e Mon Sep 17 00:00:00 2001 From: Adam Brusselback Date: Wed, 15 Nov 2023 00:18:12 -0500 Subject: [PATCH 4/4] Renamed `return_full_response` to `echo`, and made it also work with streaming responses after testing how it worked with the regular OpenAI completion API. Fixed inconsistencies with official OpenAI API, and what we were returning for streaming responses for the chat completion API. Added error handling so if there is an issue with applying the template, it is reported to the user through an API error, and logged. --- vllm/entrypoints/openai/api_server.py | 164 +++++++++++++++----------- vllm/entrypoints/openai/protocol.py | 2 +- 2 files changed, 95 insertions(+), 71 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d34cceaf2093..e3bb20116e5d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -3,6 +3,7 @@ import argparse import asyncio +import codecs import json import time from http import HTTPStatus @@ -62,21 +63,13 @@ async def check_model(request) -> Optional[JSONResponse]: async def get_gen_prompt(request) -> str: - if chat_template is not None: - return tokenizer.apply_chat_template( - conversation=request.messages, - chat_template=chat_template, - tokenize=False, - add_generation_prompt=request.add_generation_prompt) - elif tokenizer.chat_template is not None: + try: return tokenizer.apply_chat_template( conversation=request.messages, tokenize=False, add_generation_prompt=request.add_generation_prompt) - else: - raise ValueError("No chat template defined. Please use a tokenizer " - "that includes a chat template, or pass in " - "a jinja template using the --chat-template flag.") + except Exception as e: + raise RuntimeError(f"Error generating prompt: {str(e)}") from e async def check_length( @@ -172,7 +165,12 @@ async def create_chat_completion(request: ChatCompletionRequest, return create_error_response(HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported") - prompt = await get_gen_prompt(request) + try: + prompt = await get_gen_prompt(request) + except RuntimeError as e: + logger.error(f"Error in generating prompt from request: {str(e)}") + return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) + token_ids, error_check_ret = await check_length(request, prompt=prompt) if error_check_ret is not None: return error_check_ret @@ -180,6 +178,7 @@ async def create_chat_completion(request: ChatCompletionRequest, model_name = request.model request_id = f"cmpl-{random_uuid()}" created_time = int(time.monotonic()) + obj_str = "chat.completion.chunk" try: spaces_between_special_tokens = request.spaces_between_special_tokens sampling_params = SamplingParams( @@ -210,64 +209,77 @@ def get_role() -> str: else: return request.messages[-1]["role"] - def create_stream_response_json( - index: int, - text: str, - role: str, - finish_reason: Optional[str] = None) -> str: - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(role=role, content=text), - finish_reason=finish_reason, - ) - response = ChatCompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - ) - response_json = response.json(ensure_ascii=False) - - return response_json - async def completion_stream_generator() -> AsyncGenerator[str, None]: - # First chunk with role role = get_role() for i in range(request.n): choice_data = ChatCompletionResponseStreamChoice( - index=i, - delta=DeltaMessage(role=role), - finish_reason=None, - ) + index=i, delta=DeltaMessage(role=role), finish_reason=None) chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, choices=[choice_data], model=model_name) data = chunk.json(exclude_unset=True, ensure_ascii=False) yield f"data: {data}\n\n" + # Handle echoing of the first input + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + previous_texts = [""] * request.n previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n async for res in result_generator: res: RequestOutput for output in res.outputs: i = output.index - delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - response_json = create_stream_response_json( - index=i, - role=role, - text=delta_text, - ) - yield f"data: {response_json}\n\n" - if output.finish_reason is not None: - response_json = create_stream_response_json( + if output.finish_reason is None and not finish_reason_sent[i]: + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + choice_data = ChatCompletionResponseStreamChoice( index=i, - role=role, - text="", - finish_reason=output.finish_reason, - ) - yield f"data: {response_json}\n\n" + delta=DeltaMessage(content=delta_text), + finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, ensure_ascii=False) + yield f"data: {data}\n\n" + if output.finish_reason is not None and not finish_reason_sent[ + i]: + choice_data = ChatCompletionResponseStreamChoice( + index=i, delta=[], finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse(id=request_id, + object=obj_str, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.json(exclude_unset=True, + exclude_none=True, + ensure_ascii=False) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True yield "data: [DONE]\n\n" # Streaming response @@ -295,7 +307,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: ) choices.append(choice_data) - if request.return_full_response: + if request.echo: last_msg_content = "" if request.messages and isinstance( request.messages, list) and request.messages[-1].get( @@ -586,8 +598,9 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: parser.add_argument("--chat-template", type=str, default=None, - help="The path to the chat template to use " - "with the specified model.") + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model") parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -608,18 +621,26 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: served_model = args.model if args.chat_template is not None: - with open(args.chat_template, "r") as f: - content = f.read() - try: - # Try to parse as JSON and if chat_template exists, use value - json_data = json.loads(content) - if "chat_template" in json_data: - chat_template = json_data["chat_template"] - else: + try: + with open(args.chat_template, "r") as f: + content = f.read() + try: + json_data = json.loads(content) + if "chat_template" in json_data: + chat_template = json_data["chat_template"] + else: + chat_template = content + except json.JSONDecodeError: + # If JSON fails, use the file content as raw text chat_template = content + except OSError as e: + try: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + chat_template = codecs.decode(args.chat_template, + "unicode_escape") except json.JSONDecodeError: - # If parsing as JSON fails, use the file content as raw text - chat_template = content + logger.error("Unable to set template.") engine_args = AsyncEngineArgs.from_cli_args(args) engine = AsyncLLMEngine.from_engine_args(engine_args) @@ -631,12 +652,15 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: tokenizer_mode=engine_args.tokenizer_mode, trust_remote_code=engine_args.trust_remote_code) - if chat_template or tokenizer.chat_template: - logger.info( - f"Chat template:\n{chat_template or tokenizer.chat_template}") + if chat_template is not None: + tokenizer.chat_template = chat_template + + tmp_template = tokenizer.chat_template or tokenizer.default_chat_template + if tmp_template: + logger.info(f"Chat template:\n{tmp_template}") else: logger.warning( - "No chat template loaded, the chat endpoint will be disabled.") + "No chat template loaded, the chat endpoint will be not work.") uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8234d3b45abf..6c97ce87aa67 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -74,7 +74,7 @@ class ChatCompletionRequest(BaseModel): skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True add_generation_prompt: Optional[bool] = True - return_full_response: Optional[bool] = False + echo: Optional[bool] = False class CompletionRequest(BaseModel):