diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index efbffc7f25..c10ea6203f 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -218,7 +218,12 @@ async def _run_async_impl( headers=final_headers ) - response = await session.call_tool(self._mcp_tool.name, arguments=args) + # Transform arguments to match MCP schema + transformed_args = self._transform_args_to_mcp_format( + args, self._mcp_tool.inputSchema + ) + + response = await session.call_tool(self._mcp_tool.name, arguments=transformed_args) return response.model_dump(exclude_none=True, mode="json") async def _get_headers( @@ -305,6 +310,78 @@ async def _get_headers( return headers + def _transform_args_to_mcp_format( + self, args: Dict[str, Any], mcp_schema: Dict[str, Any] + ) -> Dict[str, Any]: + """Transform arguments to match MCP schema. + + Handles cases where model output simplifies array-of-objects to + array-of-primitives for schemas with single-property objects. + + Args: + args: Tool arguments from model output. + mcp_schema: MCP tool input schema. + + Returns: + Transformed arguments matching schema, or original if no transformation needed. + """ + if not args or not mcp_schema: + return args + + properties = mcp_schema.get("properties", {}) + if not properties: + return args + + transformed = {} + for key, value in args.items(): + if key in properties: + transformed[key] = self._transform_value_to_schema(value, properties[key]) + else: + transformed[key] = value + + return transformed + + def _transform_value_to_schema( + self, value: Any, schema: Dict[str, Any] + ) -> Any: + """Transform value to match schema. + + Args: + value: Value to transform. + schema: JSON schema for the value. + + Returns: + Transformed value or original if no transformation needed. + """ + if value is None or not schema: + return value + + schema_type = schema.get("type") + + if schema_type == "array" and isinstance(value, list) and value: + items_schema = schema.get("items") + if not items_schema or items_schema.get("type") != "object": + return value + + if not isinstance(value[0], dict): + if not all(not isinstance(item, dict) for item in value): + logger.warning( + "Mixed types in array for MCP tool %s", self.name + ) + return value + + item_properties = items_schema.get("properties", {}) + if len(item_properties) == 1: + property_name = next(iter(item_properties)) + logger.debug( + "Transforming array for MCP tool %s with property '%s'", + self.name, + property_name, + ) + return [{property_name: item} for item in value] + + return value + class MCPTool(McpTool): """Deprecated name, use `McpTool` instead.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index c1fbb5bc63..fcfe538880 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -732,3 +732,166 @@ async def test_run_async_impl_with_header_provider_and_oauth2(self): self.mock_session.call_tool.assert_called_once_with( "test_tool", arguments=args ) + + def test_transform_args_simple_types(self): + """Test that simple types are not transformed.""" + schema = { + "properties": { + "name": {"type": "string"}, + "count": {"type": "integer"}, + } + } + self.mock_mcp_tool.inputSchema = schema + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + args = {"name": "test", "count": 42} + result = tool._transform_args_to_mcp_format(args, schema) + + assert result == args + + def test_transform_args_array_of_primitives_to_objects(self): + """Test transformation of array of primitives to array of objects.""" + schema = { + "properties": { + "sources": { + "type": "array", + "items": { + "type": "object", + "properties": {"type": {"type": "string"}}, + }, + } + } + } + self.mock_mcp_tool.inputSchema = schema + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + args = {"sources": ["web", "images"]} + result = tool._transform_args_to_mcp_format(args, schema) + + expected = {"sources": [{"type": "web"}, {"type": "images"}]} + assert result == expected + + def test_transform_args_already_correct_format(self): + """Test that already correct format is not modified.""" + schema = { + "properties": { + "sources": { + "type": "array", + "items": { + "type": "object", + "properties": {"type": {"type": "string"}}, + }, + } + } + } + self.mock_mcp_tool.inputSchema = schema + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + args = {"sources": [{"type": "web"}, {"type": "images"}]} + result = tool._transform_args_to_mcp_format(args, schema) + + assert result == args + + def test_transform_args_empty_array(self): + """Test transformation with empty array.""" + schema = { + "properties": { + "sources": { + "type": "array", + "items": { + "type": "object", + "properties": {"type": {"type": "string"}}, + }, + } + } + } + self.mock_mcp_tool.inputSchema = schema + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + args = {"sources": []} + result = tool._transform_args_to_mcp_format(args, schema) + + assert result == {"sources": []} + + def test_transform_args_no_transformation_for_multi_property_objects(self): + """Test that objects with multiple properties are not transformed.""" + schema = { + "properties": { + "filters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": {"type": "string"}, + "value": {"type": "string"}, + }, + }, + } + } + } + self.mock_mcp_tool.inputSchema = schema + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + args = {"filters": ["field1", "field2"]} + result = tool._transform_args_to_mcp_format(args, schema) + + assert result == args + + @pytest.mark.asyncio + async def test_run_async_impl_transforms_args(self): + """Test that _run_async_impl applies argument transformation.""" + schema = { + "properties": { + "sources": { + "type": "array", + "items": { + "type": "object", + "properties": {"type": {"type": "string"}}, + }, + } + } + } + self.mock_mcp_tool.inputSchema = schema + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + ) + + mcp_response = CallToolResult( + content=[TextContent(type="text", text="success")] + ) + self.mock_session.call_tool = AsyncMock(return_value=mcp_response) + + tool_context = Mock(spec=ToolContext) + credential = None + + # Model generates simplified format + args = {"sources": ["web", "images"]} + + result = await tool._run_async_impl( + args=args, tool_context=tool_context, credential=credential + ) + + # Verify transformation was applied before calling MCP tool + self.mock_session.call_tool.assert_called_once() + call_args = self.mock_session.call_tool.call_args + transformed_args = call_args[1]["arguments"] + + # Should be transformed to correct format + assert transformed_args == {"sources": [{"type": "web"}, {"type": "images"}]} + assert result == mcp_response.model_dump(exclude_none=True, mode="json")