diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 1b5f4a42a..521d4491e 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -419,7 +419,7 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] yield {"chunk_type": "message_start"} content_started = False - current_tool_calls: dict[str, dict[str, str]] = {} + tool_calls: dict[str, list[Any]] = {} accumulated_text = "" async for chunk in stream_response: @@ -440,24 +440,23 @@ async def stream(self, request: dict[str, Any]) -> AsyncGenerator[dict[str, Any] if hasattr(delta, "tool_calls") and delta.tool_calls: for tool_call in delta.tool_calls: tool_id = tool_call.id + tool_calls.setdefault(tool_id, []).append(tool_call) - if tool_id not in current_tool_calls: - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_call} - current_tool_calls[tool_id] = {"name": tool_call.function.name, "arguments": ""} + if hasattr(choice, "finish_reason") and choice.finish_reason: + if content_started: + yield {"chunk_type": "content_stop", "data_type": "text"} + + for tool_deltas in tool_calls.values(): + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} - if hasattr(tool_call.function, "arguments"): - current_tool_calls[tool_id]["arguments"] += tool_call.function.arguments + for tool_delta in tool_deltas: + if hasattr(tool_delta.function, "arguments"): yield { "chunk_type": "content_delta", "data_type": "tool", - "data": tool_call.function.arguments, + "data": tool_delta.function.arguments, } - if hasattr(choice, "finish_reason") and choice.finish_reason: - if content_started: - yield {"chunk_type": "content_stop", "data_type": "text"} - - for _ in current_tool_calls: yield {"chunk_type": "content_stop", "data_type": "tool"} yield {"chunk_type": "message_stop", "data": choice.finish_reason} diff --git a/tests-integ/test_model_mistral.py b/tests-integ/test_model_mistral.py index 62a20fff7..1803dc5c7 100644 --- a/tests-integ/test_model_mistral.py +++ b/tests-integ/test_model_mistral.py @@ -78,41 +78,32 @@ class Weather(BaseModel): @pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") def test_agent_invoke(agent): - # TODO: https://github.com/strands-agents/sdk-python/issues/374 - # result = streaming_agent("What is the time and weather in New York?") - result = agent("What is the time in New York?") + result = agent("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - # assert all(string in text for string in ["12:00", "sunny"]) - assert all(string in text for string in ["12:00"]) + assert all(string in text for string in ["12:00", "sunny"]) @pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") @pytest.mark.asyncio async def test_agent_invoke_async(agent): - # TODO: https://github.com/strands-agents/sdk-python/issues/374 - # result = await streaming_agent.invoke_async("What is the time and weather in New York?") - result = await agent.invoke_async("What is the time in New York?") + result = await agent.invoke_async("What is the time and weather in New York?") text = result.message["content"][0]["text"].lower() - # assert all(string in text for string in ["12:00", "sunny"]) - assert all(string in text for string in ["12:00"]) + assert all(string in text for string in ["12:00", "sunny"]) @pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing") @pytest.mark.asyncio async def test_agent_stream_async(agent): - # TODO: https://github.com/strands-agents/sdk-python/issues/374 - # stream = streaming_agent.stream_async("What is the time and weather in New York?") - stream = agent.stream_async("What is the time in New York?") + stream = agent.stream_async("What is the time and weather in New York?") async for event in stream: _ = event result = event["result"] text = result.message["content"][0]["text"].lower() - # assert all(string in text for string in ["12:00", "sunny"]) - assert all(string in text for string in ["12:00"]) + assert all(string in text for string in ["12:00", "sunny"]) @pytest.mark.skipif("MISTRAL_API_KEY" not in os.environ, reason="MISTRAL_API_KEY environment variable missing")