diff --git a/tests/mistral_tool_use/__init__.py b/tests/mistral_tool_use/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/mistral_tool_use/conftest.py b/tests/mistral_tool_use/conftest.py new file mode 100644 index 000000000000..39ab01c9b874 --- /dev/null +++ b/tests/mistral_tool_use/conftest.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import pytest_asyncio +from huggingface_hub import snapshot_download + +from tests.utils import RemoteOpenAIServer +from vllm.platforms import current_platform + +from .utils import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="session", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + + if current_platform.is_rocm() and not config.get("supports_rocm", True): + pytest.skip("The {} model can't be tested on the ROCm platform".format( + config["model"])) + + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="session") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_wait_seconds=480) as server: + yield server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/mistral_tool_use/test_mistral_tool_calls.py b/tests/mistral_tool_use/test_mistral_tool_calls.py new file mode 100644 index 000000000000..bbb3a07895f6 --- /dev/null +++ b/tests/mistral_tool_use/test_mistral_tool_calls.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai +import pytest + +from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL + + +# test: a tool_choice with mistral-tokenizer results in an ID of length 9 +@pytest.mark.asyncio +async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_completion_tokens=100, + model=model_name, + tools=[WEATHER_TOOL], + tool_choice=WEATHER_TOOL, + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 1 + assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral diff --git a/tests/mistral_tool_use/utils.py b/tests/mistral_tool_use/utils.py new file mode 100644 index 000000000000..971ed55ca3c0 --- /dev/null +++ b/tests/mistral_tool_use/utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List, Optional + +from typing_extensions import TypedDict + + +class ServerConfig(TypedDict, total=False): + model: str + arguments: List[str] + system_prompt: Optional[str] + supports_parallel: Optional[bool] + supports_rocm: Optional[bool] + + +ARGS: List[str] = ["--max-model-len", "1024"] + +CONFIGS: Dict[str, ServerConfig] = { + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tokenizer-mode", "mistral", + "--ignore-patterns=\"consolidated.safetensors\"" + ], + "system_prompt": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." + }, +} diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 107220d548af..934bd2a95063 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -28,12 +28,15 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( + MistralToolCall) from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls +from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, + truncate_tool_call_ids) logger = init_logger(__name__) @@ -150,11 +153,12 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # because of issues with pydantic we need to potentially - # re-serialize the tool_calls field of the request - # for more info: see comment in `maybe_serialize_tool_calls` if isinstance(tokenizer, MistralTokenizer): + # because of issues with pydantic we need to potentially + # re-serialize the tool_calls field of the request + # for more info: see comment in `maybe_serialize_tool_calls` maybe_serialize_tool_calls(request) + truncate_tool_call_ids(request) if (request.tool_choice == "auto" and not (self.enable_auto_tools and tool_parser is not None) @@ -745,11 +749,13 @@ async def chat_completion_full_generator( elif request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_call_class = MistralToolCall if isinstance( + tokenizer, MistralTokenizer) else ToolCall message = ChatMessage( role=role, content="", tool_calls=[ - ToolCall(function=FunctionCall( + tool_call_class(function=FunctionCall( name=request.tool_choice.function.name, arguments=output.text)) ]) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 51354f7c9562..4f0480882992 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -33,7 +33,7 @@ class MistralToolCall(ToolCall): @staticmethod def generate_random_id(): - # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. + # Mistral Tool Call Ids must be alphanumeric with a length of 9. # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 return "".join(choices(ALPHANUMERIC, k=9)) diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 2b64f3fc7056..c12388d9b20b 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -from .mistral import MistralTokenizer, maybe_serialize_tool_calls +from .mistral import (MistralTokenizer, maybe_serialize_tool_calls, + truncate_tool_call_ids) -__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"] +__all__ = [ + "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids" +] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 1550f978ed20..bd78b16a9db8 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -67,6 +67,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"): request.messages[i]["tool_calls"] = validated_tool_calls +def truncate_tool_call_ids(request: "ChatCompletionRequest"): + """Truncates tool call IDs for Mistral's ID requirements.""" + for i, message in enumerate(request.messages): + if message.get("role") == 'assistant': + tool_calls = message.get("tool_calls", []) + for tool_call in tool_calls: + if len(tool_call["id"]) > 9: + logger.warning( + "Truncating tool call ID: %s to %s", + tool_call["id"], + tool_call["id"][-9:], + ) + tool_call["id"] = tool_call["id"][-9:] + + request.messages[i]["tool_calls"] = tool_calls + + elif message.get("role") in {"tool_results", "tool"}: + if "tool_call_id" in message: + tool_call_id = message["tool_call_id"] + + if len(tool_call_id) > 9: + logger.warning( + "Truncating tool_call_id: %s to %s", + tool_call_id, + tool_call_id[-9:], + ) + tool_call_id = tool_call_id[-9:] + request.messages[i]["tool_call_id"] = tool_call_id + + def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]: repo_cache = os.path.join( huggingface_hub.constants.HF_HUB_CACHE,