Skip to content

Commit 4df9263

Browse files
Jacksunweicopybara-github
authored andcommitted
fix: Returns dict as result from McpTool
The `BaseTool` expects the run_async to return a json-serializable object. By model_dump the McpTool result explicitly can allow what ADK runtime sees is identical to what is persisted in the session event list. Before the change, runtime sees CallToolResult instance and Session persists its serialized dict. https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/types.py#L916-L922 PiperOrigin-RevId: 822465432
1 parent d4dc645 commit 4df9263

File tree

2 files changed

+42
-25
lines changed

2 files changed

+42
-25
lines changed

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(
8080
Callable[[ReadonlyContext], Dict[str, str]]
8181
] = None,
8282
):
83-
"""Initializes an MCPTool.
83+
"""Initializes an McpTool.
8484
8585
This tool wraps an MCP Tool interface and uses a session manager to
8686
communicate with the MCP server.
@@ -186,7 +186,7 @@ async def run_async(
186186
@override
187187
async def _run_async_impl(
188188
self, *, args, tool_context: ToolContext, credential: AuthCredential
189-
):
189+
) -> Dict[str, Any]:
190190
"""Runs the tool asynchronously.
191191
192192
Args:
@@ -217,7 +217,7 @@ async def _run_async_impl(
217217
)
218218

219219
response = await session.call_tool(self._mcp_tool.name, arguments=args)
220-
return response
220+
return response.model_dump(exclude_none=True, mode="json")
221221

222222
async def _get_headers(
223223
self, tool_context: ToolContext, credential: AuthCredential
@@ -282,7 +282,7 @@ async def _get_headers(
282282
!= APIKeyIn.header
283283
):
284284
error_msg = (
285-
"MCPTool only supports header-based API key authentication."
285+
"McpTool only supports header-based API key authentication."
286286
" Configured location:"
287287
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
288288
)

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
3737
from google.adk.tools.tool_context import ToolContext
3838
from google.genai.types import FunctionDeclaration
39+
from mcp.types import CallToolResult
40+
from mcp.types import TextContent
3941
except ImportError as e:
4042
if sys.version_info < (3, 10):
4143
# Create dummy classes to prevent NameError during test collection
@@ -47,6 +49,8 @@ class DummyClass:
4749
MCPTool = DummyClass
4850
ToolContext = DummyClass
4951
FunctionDeclaration = DummyClass
52+
CallToolResult = DummyClass
53+
TextContent = DummyClass
5054
else:
5155
raise e
5256

@@ -150,9 +154,11 @@ async def test_run_async_impl_no_auth(self):
150154
mcp_session_manager=self.mock_session_manager,
151155
)
152156

153-
# Mock the session response
154-
expected_response = {"result": "success"}
155-
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
157+
# Mock the session response - must return CallToolResult
158+
mcp_response = CallToolResult(
159+
content=[TextContent(type="text", text="success")]
160+
)
161+
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
156162

157163
tool_context = Mock(spec=ToolContext)
158164
args = {"param1": "test_value"}
@@ -161,7 +167,8 @@ async def test_run_async_impl_no_auth(self):
161167
args=args, tool_context=tool_context, credential=None
162168
)
163169

164-
assert result == expected_response
170+
# Verify the result matches the model_dump output
171+
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
165172
self.mock_session_manager.create_session.assert_called_once_with(
166173
headers=None
167174
)
@@ -184,9 +191,11 @@ async def test_run_async_impl_with_oauth2(self):
184191
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
185192
)
186193

187-
# Mock the session response
188-
expected_response = {"result": "success"}
189-
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
194+
# Mock the session response - must return CallToolResult
195+
mcp_response = CallToolResult(
196+
content=[TextContent(type="text", text="success")]
197+
)
198+
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
190199

191200
tool_context = Mock(spec=ToolContext)
192201
args = {"param1": "test_value"}
@@ -195,7 +204,7 @@ async def test_run_async_impl_with_oauth2(self):
195204
args=args, tool_context=tool_context, credential=credential
196205
)
197206

198-
assert result == expected_response
207+
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
199208
# Check that headers were passed correctly
200209
self.mock_session_manager.create_session.assert_called_once()
201210
call_args = self.mock_session_manager.create_session.call_args
@@ -322,7 +331,7 @@ async def test_get_headers_api_key_with_query_scheme_raises_error(self):
322331

323332
with pytest.raises(
324333
ValueError,
325-
match="MCPTool only supports header-based API key authentication",
334+
match="McpTool only supports header-based API key authentication",
326335
):
327336
await tool._get_headers(tool_context, auth_credential)
328337

@@ -354,7 +363,7 @@ async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
354363

355364
with pytest.raises(
356365
ValueError,
357-
match="MCPTool only supports header-based API key authentication",
366+
match="McpTool only supports header-based API key authentication",
358367
):
359368
await tool._get_headers(tool_context, auth_credential)
360369

@@ -460,9 +469,11 @@ async def test_run_async_impl_with_api_key_header_auth(self):
460469
auth_credential=auth_credential,
461470
)
462471

463-
# Mock the session response
464-
expected_response = {"result": "authenticated_success"}
465-
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
472+
# Mock the session response - must return CallToolResult
473+
mcp_response = CallToolResult(
474+
content=[TextContent(type="text", text="authenticated_success")]
475+
)
476+
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
466477

467478
tool_context = Mock(spec=ToolContext)
468479
args = {"param1": "test_value"}
@@ -471,7 +482,7 @@ async def test_run_async_impl_with_api_key_header_auth(self):
471482
args=args, tool_context=tool_context, credential=auth_credential
472483
)
473484

474-
assert result == expected_response
485+
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
475486
# Check that headers were passed correctly with custom API key header
476487
self.mock_session_manager.create_session.assert_called_once()
477488
call_args = self.mock_session_manager.create_session.call_args
@@ -545,7 +556,7 @@ async def test_get_headers_api_key_error_logging(self):
545556
mock_logger.error.assert_called_once()
546557
logged_message = mock_logger.error.call_args[0][0]
547558
assert (
548-
"MCPTool only supports header-based API key authentication"
559+
"McpTool only supports header-based API key authentication"
549560
in logged_message
550561
)
551562

@@ -652,8 +663,11 @@ async def test_run_async_impl_with_header_provider_no_auth(self):
652663
header_provider=header_provider,
653664
)
654665

655-
expected_response = {"result": "success"}
656-
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
666+
# Mock the session response - must return CallToolResult
667+
mcp_response = CallToolResult(
668+
content=[TextContent(type="text", text="success")]
669+
)
670+
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
657671

658672
tool_context = Mock(spec=ToolContext)
659673
tool_context._invocation_context = Mock()
@@ -663,7 +677,7 @@ async def test_run_async_impl_with_header_provider_no_auth(self):
663677
args=args, tool_context=tool_context, credential=None
664678
)
665679

666-
assert result == expected_response
680+
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
667681
header_provider.assert_called_once()
668682
self.mock_session_manager.create_session.assert_called_once_with(
669683
headers=expected_headers
@@ -688,8 +702,11 @@ async def test_run_async_impl_with_header_provider_and_oauth2(self):
688702
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
689703
)
690704

691-
expected_response = {"result": "success"}
692-
self.mock_session.call_tool = AsyncMock(return_value=expected_response)
705+
# Mock the session response - must return CallToolResult
706+
mcp_response = CallToolResult(
707+
content=[TextContent(type="text", text="success")]
708+
)
709+
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
693710

694711
tool_context = Mock(spec=ToolContext)
695712
tool_context._invocation_context = Mock()
@@ -699,7 +716,7 @@ async def test_run_async_impl_with_header_provider_and_oauth2(self):
699716
args=args, tool_context=tool_context, credential=credential
700717
)
701718

702-
assert result == expected_response
719+
assert result == mcp_response.model_dump(exclude_none=True, mode="json")
703720
header_provider.assert_called_once()
704721
self.mock_session_manager.create_session.assert_called_once()
705722
call_args = self.mock_session_manager.create_session.call_args

0 commit comments

Comments
 (0)