Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/strands/tools/mcp/mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
It allows MCP tools to be seamlessly integrated and used within the agent ecosystem.
"""

import asyncio
import logging
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -87,8 +86,7 @@ async def stream(self, tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolGenerat
"""
logger.debug("tool_name=<%s>, tool_use_id=<%s> | streaming", self.tool_name, tool_use["toolUseId"])

result = await asyncio.to_thread(
self.mcp_client.call_tool_sync,
result = await self.mcp_client.call_tool_async(
tool_use_id=tool_use["toolUseId"],
name=self.tool_name,
arguments=tool_use["input"],
Expand Down
89 changes: 65 additions & 24 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def stop(
async def _set_close_event() -> None:
self._close_event.set()

self._invoke_on_background_thread(_set_close_event())
self._invoke_on_background_thread(_set_close_event()).result()
self._log_debug_with_thread("waiting for background thread to join")
if self._background_thread is not None:
self._background_thread.join()
Expand Down Expand Up @@ -156,7 +156,7 @@ def list_tools_sync(self) -> List[MCPAgentTool]:
async def _list_tools_async() -> ListToolsResult:
return await self._background_thread_session.list_tools()

list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async())
list_tools_response: ListToolsResult = self._invoke_on_background_thread(_list_tools_async()).result()
self._log_debug_with_thread("received %d tools from MCP server", len(list_tools_response.tools))

mcp_tools = [MCPAgentTool(tool, self) for tool in list_tools_response.tools]
Expand Down Expand Up @@ -192,25 +192,68 @@ async def _call_tool_async() -> MCPCallToolResult:
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)

try:
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async())
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))

mapped_content = [
mapped_content
for content in call_tool_result.content
if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None
]

status: ToolResultStatus = "error" if call_tool_result.isError else "success"
self._log_debug_with_thread("tool execution completed with status: %s", status)
return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content)
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(_call_tool_async()).result()
return self._handle_tool_result(tool_use_id, call_tool_result)
except Exception as e:
logger.warning("tool execution failed: %s", str(e), exc_info=True)
return ToolResult(
status="error",
toolUseId=tool_use_id,
content=[{"text": f"Tool execution failed: {str(e)}"}],
)
logger.exception("tool execution failed")
return self._handle_tool_execution_error(tool_use_id, e)

async def call_tool_async(
self,
tool_use_id: str,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
) -> ToolResult:
"""Asynchronously calls a tool on the MCP server.

This method calls the asynchronous call_tool method on the MCP session
and converts the result to the ToolResult format.

Args:
tool_use_id: Unique identifier for this tool use
name: Name of the tool to call
arguments: Optional arguments to pass to the tool
read_timeout_seconds: Optional timeout for the tool call

Returns:
ToolResult: The result of the tool call
"""
self._log_debug_with_thread("calling MCP tool '%s' asynchronously with tool_use_id=%s", name, tool_use_id)
if not self._is_session_active():
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

async def _call_tool_async() -> MCPCallToolResult:
return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)

try:
future = self._invoke_on_background_thread(_call_tool_async())
call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future)
return self._handle_tool_result(tool_use_id, call_tool_result)
except Exception as e:
logger.exception("tool execution failed")
return self._handle_tool_execution_error(tool_use_id, e)

def _handle_tool_execution_error(self, tool_use_id: str, exception: Exception) -> ToolResult:
"""Create error ToolResult with consistent logging."""
return ToolResult(
status="error",
toolUseId=tool_use_id,
content=[{"text": f"Tool execution failed: {str(exception)}"}],
)

def _handle_tool_result(self, tool_use_id: str, call_tool_result: MCPCallToolResult) -> ToolResult:
self._log_debug_with_thread("received tool result with %d content items", len(call_tool_result.content))

mapped_content = [
mapped_content
for content in call_tool_result.content
if (mapped_content := self._map_mcp_content_to_tool_result_content(content)) is not None
]

status: ToolResultStatus = "error" if call_tool_result.isError else "success"
self._log_debug_with_thread("tool execution completed with status: %s", status)
return ToolResult(status=status, toolUseId=tool_use_id, content=mapped_content)

async def _async_background_thread(self) -> None:
"""Asynchronous method that runs in the background thread to manage the MCP connection.
Expand Down Expand Up @@ -296,12 +339,10 @@ def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None:
"[Thread: %s, Session: %s] %s", threading.current_thread().name, self._session_id, formatted_msg, **kwargs
)

def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> T:
def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]:
if self._background_thread_session is None or self._background_thread_event_loop is None:
raise MCPClientInitializationError("the client session was not initialized")

future = asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop)
return future.result()
return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop)

def _is_session_active(self) -> bool:
return self._background_thread is not None and self._background_thread.is_alive()
4 changes: 2 additions & 2 deletions tests/strands/tools/mcp/test_mcp_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}

tru_events = await alist(mcp_agent_tool.stream(tool_use, {}))
exp_events = [mock_mcp_client.call_tool_sync.return_value]
exp_events = [mock_mcp_client.call_tool_async.return_value]

assert tru_events == exp_events
mock_mcp_client.call_tool_sync.assert_called_once_with(
mock_mcp_client.call_tool_async.assert_called_once_with(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
)
149 changes: 149 additions & 0 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,155 @@ def test_call_tool_sync_exception(mock_transport, mock_session):
assert "Test exception" in result["content"][0]["text"]


@pytest.mark.asyncio
@pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status):
"""Test that call_tool_async correctly handles success and error results."""
mock_content = MCPTextContent(type="text", text="Test message")
mock_result = MCPCallToolResult(isError=is_error, content=[mock_content])
mock_session.call_tool.return_value = mock_result

with MCPClient(mock_transport["transport_callable"]) as client:
# Mock asyncio.run_coroutine_threadsafe and asyncio.wrap_future
with (
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
patch("asyncio.wrap_future") as mock_wrap_future,
):
# Create a mock future that returns the mock result
mock_future = MagicMock()
mock_run_coroutine_threadsafe.return_value = mock_future

# Create an async mock that resolves to the mock result
async def mock_awaitable():
return mock_result

mock_wrap_future.return_value = mock_awaitable()

result = await client.call_tool_async(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
)

# Verify the asyncio functions were called correctly
mock_run_coroutine_threadsafe.assert_called_once()
mock_wrap_future.assert_called_once_with(mock_future)

assert result["status"] == expected_status
assert result["toolUseId"] == "test-123"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "Test message"


@pytest.mark.asyncio
async def test_call_tool_async_session_not_active():
"""Test that call_tool_async raises an error when session is not active."""
client = MCPClient(MagicMock())

with pytest.raises(MCPClientInitializationError, match="client.session is not running"):
await client.call_tool_async(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})


@pytest.mark.asyncio
async def test_call_tool_async_exception(mock_transport, mock_session):
"""Test that call_tool_async correctly handles exceptions."""
with MCPClient(mock_transport["transport_callable"]) as client:
# Mock asyncio.run_coroutine_threadsafe to raise an exception
with patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe:
mock_run_coroutine_threadsafe.side_effect = Exception("Test exception")

result = await client.call_tool_async(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
)

assert result["status"] == "error"
assert result["toolUseId"] == "test-123"
assert len(result["content"]) == 1
assert "Test exception" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_call_tool_async_with_timeout(mock_transport, mock_session):
"""Test that call_tool_async correctly passes timeout parameter."""
from datetime import timedelta

mock_content = MCPTextContent(type="text", text="Test message")
mock_result = MCPCallToolResult(isError=False, content=[mock_content])
mock_session.call_tool.return_value = mock_result

with MCPClient(mock_transport["transport_callable"]) as client:
timeout = timedelta(seconds=30)

with (
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
patch("asyncio.wrap_future") as mock_wrap_future,
):
mock_future = MagicMock()
mock_run_coroutine_threadsafe.return_value = mock_future

# Create an async mock that resolves to the mock result
async def mock_awaitable():
return mock_result

mock_wrap_future.return_value = mock_awaitable()

result = await client.call_tool_async(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout
)

# Verify the timeout was passed to the session call_tool method
# We need to check that the coroutine passed to run_coroutine_threadsafe
# would call session.call_tool with the timeout
mock_run_coroutine_threadsafe.assert_called_once()
mock_wrap_future.assert_called_once_with(mock_future)

assert result["status"] == "success"
assert result["toolUseId"] == "test-123"


@pytest.mark.asyncio
async def test_call_tool_async_initialization_not_complete():
"""Test that call_tool_async returns error result when background thread is not initialized."""
client = MCPClient(MagicMock())

# Manually set the client state to simulate a partially initialized state
client._background_thread = MagicMock()
client._background_thread.is_alive.return_value = True
client._background_thread_session = None # Not initialized

result = await client.call_tool_async(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

assert result["status"] == "error"
assert result["toolUseId"] == "test-123"
assert len(result["content"]) == 1
assert "client session was not initialized" in result["content"][0]["text"]


@pytest.mark.asyncio
async def test_call_tool_async_wrap_future_exception(mock_transport, mock_session):
"""Test that call_tool_async correctly handles exceptions from wrap_future."""
with MCPClient(mock_transport["transport_callable"]) as client:
with (
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
patch("asyncio.wrap_future") as mock_wrap_future,
):
mock_future = MagicMock()
mock_run_coroutine_threadsafe.return_value = mock_future

# Create an async mock that raises an exception
async def mock_awaitable():
raise Exception("Wrap future exception")

mock_wrap_future.return_value = mock_awaitable()

result = await client.call_tool_async(
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
)

assert result["status"] == "error"
assert result["toolUseId"] == "test-123"
assert len(result["content"]) == 1
assert "Wrap future exception" in result["content"][0]["text"]


def test_enter_with_initialization_exception(mock_transport):
"""Test that __enter__ handles exceptions during initialization properly."""
# Make the transport callable throw an exception
Expand Down
Loading