From 87a5b46ac6d39772f568105e4180a727ef4efd5f Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Fri, 25 Apr 2025 10:47:36 +0000
Subject: [PATCH] [Bugfix] Fix mistral model tests

Signed-off-by: DarkLight1337
---
 .../decoder_only/language/test_mistral.py     | 64 +++++++++++--------
 .../tool_parsers/mistral_tool_parser.py       |  4 ++
 2 files changed, 40 insertions(+), 28 deletions(-)

diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py
index ec885386dd94..79778072cc8b 100644
--- a/tests/models/decoder_only/language/test_mistral.py
+++ b/tests/models/decoder_only/language/test_mistral.py
@@ -10,8 +10,8 @@
 import jsonschema.exceptions
 import pytest
 
-from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (  # noqa
-    MistralToolParser)
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall, MistralToolParser)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 from ...utils import check_logprobs_close
@@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
     )
 
 
-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
@@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
         assert "�" not in outputs[0].outputs[0].text.strip()
 
 
-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
+@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("model",
-                         MISTRAL_FORMAT_MODELS)  # v1 can't do func calling
 def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
@@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
             parsed_message = tool_parser.extract_tool_calls(model_output, None)
 
         assert parsed_message.tools_called
-        assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
+
+        assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
         assert parsed_message.tool_calls[
             0].function.name == "get_current_weather"
         assert parsed_message.tool_calls[
@@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(vllm_runner, model: str,
-                                 guided_backend: str) -> None:
-    with vllm_runner(model, dtype='bfloat16',
-                     tokenizer_mode="mistral") as vllm_model:
+def test_mistral_guided_decoding(
+    monkeypatch: pytest.MonkeyPatch,
+    vllm_runner,
+    model: str,
+    guided_backend: str,
+) -> None:
+    with monkeypatch.context() as m:
+        # Guided JSON not supported in xgrammar + V1 yet
+        m.setenv("VLLM_USE_V1", "0")
 
-        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
-                                               backend=guided_backend)
-        params = SamplingParams(max_tokens=512,
-                                temperature=0.7,
-                                guided_decoding=guided_decoding)
-
-        messages = [{
-            "role": "system",
-            "content": "you are a helpful assistant"
-        }, {
-            "role":
-            "user",
-            "content":
-            f"Give an example JSON for an employee profile that "
-            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
-        }]
-        outputs = vllm_model.model.chat(messages, sampling_params=params)
+        with vllm_runner(
+                model,
+                dtype='bfloat16',
+                tokenizer_mode="mistral",
+                guided_decoding_backend=guided_backend,
+        ) as vllm_model:
+            guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
+            params = SamplingParams(max_tokens=512,
+                                    temperature=0.7,
+                                    guided_decoding=guided_decoding)
+
+            messages = [{
+                "role": "system",
+                "content": "you are a helpful assistant"
+            }, {
+                "role":
+                "user",
+                "content":
+                f"Give an example JSON for an employee profile that "
+                f"fits this schema: {SAMPLE_JSON_SCHEMA}"
+            }]
+            outputs = vllm_model.model.chat(messages, sampling_params=params)
 
     generated_text = outputs[0].outputs[0].text
     json_response = json.loads(generated_text)

diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index f0000daa0a41..9dbfe85ecc68 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -38,6 +38,10 @@ def generate_random_id():
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))
 
+    @staticmethod
+    def is_valid_id(id: str) -> bool:
+        return id.isalnum() and len(id) == 9
+
 
 @ToolParserManager.register_module("mistral")
 class MistralToolParser(ToolParser):
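
Note on the ID assertion change (a standalone sketch, not part of the patch):
generate_random_id assigns each parsed tool call a fresh random 9-character
alphanumeric ID, so asserting equality against the hardcoded "0UAqFzWsD" was
flaky by construction; the new is_valid_id helper asserts only the format.
The following minimal Python sketch mirrors the two helpers above to show the
invariant the patched test relies on:

    # Sketch assuming the same helpers as MistralToolCall above;
    # names here are local stand-ins, not the vLLM API.
    import string
    from random import choices

    ALPHANUMERIC = string.ascii_letters + string.digits

    def generate_random_id() -> str:
        # Same 9-char alphanumeric scheme as MistralToolCall.generate_random_id
        return "".join(choices(ALPHANUMERIC, k=9))

    def is_valid_id(id: str) -> bool:
        # Same format check the test now applies via MistralToolCall.is_valid_id
        return id.isalnum() and len(id) == 9

    assert is_valid_id(generate_random_id())  # any freshly generated ID passes
    assert is_valid_id("0UAqFzWsD")           # the old hardcoded ID also passes
    assert not is_valid_id("bad-id!!!")       # wrong charset/length fails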