Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/tool_use/test_parallel_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
assert tool_call.type == "function"
assert tool_call.function is not None
assert isinstance(tool_call.id, str)
assert len(tool_call.id) > 16
assert len(tool_call.id) >= 9

# make sure the weather tool was called correctly
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
Expand Down Expand Up @@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
if tool_call.id:
tool_call_id_count += 1
assert (isinstance(tool_call.id, str)
and (len(tool_call.id) > 16))
and (len(tool_call.id) >= 9))

# if parts of the function start being streamed
if tool_call.function:
Expand Down
4 changes: 2 additions & 2 deletions tests/tool_use/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
assert tool_calls[0].type == 'function'
assert tool_calls[0].function is not None
assert isinstance(tool_calls[0].id, str)
assert len(tool_calls[0].id) > 16
assert len(tool_calls[0].id) >= 9

# make sure the weather tool was called (classic example) with arguments
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
Expand Down Expand Up @@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):

assert finish_reason_count == 1
assert role_name == 'assistant'
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)

# validate the name and arguments
assert function_name == WEATHER_TOOL["function"]["name"]
Expand Down
20 changes: 18 additions & 2 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import re
from random import choices
from string import ascii_letters, digits
from typing import Dict, List, Sequence, Union

import partial_json_parser
from partial_json_parser.core.options import Allow
from pydantic import Field

from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
Expand All @@ -19,6 +22,19 @@

logger = init_logger(__name__)

ALPHANUMERIC = ascii_letters + digits


class MistralToolCall(ToolCall):
id: str = Field(
default_factory=lambda: MistralToolCall.generate_random_id())

@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))


class MistralToolParser(ToolParser):
"""
Expand Down Expand Up @@ -71,8 +87,8 @@ def extract_tool_calls(self,
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[ToolCall] = [
ToolCall(
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
function=FunctionCall(
name=raw_function_call["name"],
Expand Down
Loading