Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion src/google/adk/tools/mcp_tool/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,12 @@ async def _run_async_impl(
headers=final_headers
)

response = await session.call_tool(self._mcp_tool.name, arguments=args)
# Transform arguments to match MCP schema
transformed_args = self._transform_args_to_mcp_format(
args, self._mcp_tool.inputSchema
)

response = await session.call_tool(self._mcp_tool.name, arguments=transformed_args)
return response.model_dump(exclude_none=True, mode="json")

async def _get_headers(
Expand Down Expand Up @@ -305,6 +310,78 @@ async def _get_headers(

return headers

def _transform_args_to_mcp_format(
self, args: Dict[str, Any], mcp_schema: Dict[str, Any]
) -> Dict[str, Any]:
"""Transform arguments to match MCP schema.

Handles cases where model output simplifies array-of-objects to
array-of-primitives for schemas with single-property objects.

Args:
args: Tool arguments from model output.
mcp_schema: MCP tool input schema.

Returns:
Transformed arguments matching schema, or original if no transformation needed.
"""
if not args or not mcp_schema:
return args

properties = mcp_schema.get("properties", {})
if not properties:
return args

transformed = {}
for key, value in args.items():
if key in properties:
transformed[key] = self._transform_value_to_schema(value, properties[key])
else:
transformed[key] = value

return transformed

def _transform_value_to_schema(
self, value: Any, schema: Dict[str, Any]
) -> Any:
"""Transform value to match schema.

Args:
value: Value to transform.
schema: JSON schema for the value.

Returns:
Transformed value or original if no transformation needed.
"""
if value is None or not schema:
return value

schema_type = schema.get("type")

if schema_type == "array" and isinstance(value, list) and value:
items_schema = schema.get("items")
if not items_schema or items_schema.get("type") != "object":
return value

if not isinstance(value[0], dict):
if not all(not isinstance(item, dict) for item in value):
logger.warning(
"Mixed types in array for MCP tool %s", self.name
)
return value

item_properties = items_schema.get("properties", {})
if len(item_properties) == 1:
property_name = next(iter(item_properties))
logger.debug(
"Transforming array for MCP tool %s with property '%s'",
self.name,
property_name,
)
return [{property_name: item} for item in value]
Comment on lines +361 to +381

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for transforming values has a potential issue with mixed-type arrays. It assumes that if the first element in a list is a dictionary, the entire list is correctly formatted. This could lead to errors if the list contains a mix of dictionaries and primitives (e.g., [{'type': 'web'}, 'images']).

The suggested change refactors the logic to be more robust by explicitly checking if the list contains all dictionaries, all primitives, or a mix of both, ensuring that mixed lists are handled gracefully by logging a warning and returning the original value. This prevents potential downstream errors in the MCP tool.

It would also be beneficial to add a new unit test to cover this mixed-list scenario to prevent future regressions.

    if schema_type == "array" and isinstance(value, list) and value:
      items_schema = schema.get("items")
      if not items_schema or items_schema.get("type") != "object":
        return value

      is_list_of_dicts = all(isinstance(item, dict) for item in value)
      if is_list_of_dicts:
        return value

      is_list_of_primitives = all(not isinstance(item, dict) for item in value)
      if not is_list_of_primitives:
        logger.warning(
            "Mixed types in array for MCP tool %s", self.name
        )
        return value

      item_properties = items_schema.get("properties", {})
      if len(item_properties) == 1:
        property_name = next(iter(item_properties))
        logger.debug(
            "Transforming array for MCP tool %s with property '%s'",
            self.name,
            property_name,
        )
        return [{property_name: item} for item in value]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current logic for validating the array hinges on the type of the first element (value[0]). This creates a potential failure point if the array contains mixed types.


return value


class MCPTool(McpTool):
"""Deprecated name, use `McpTool` instead."""
Expand Down
163 changes: 163 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,3 +732,166 @@ async def test_run_async_impl_with_header_provider_and_oauth2(self):
self.mock_session.call_tool.assert_called_once_with(
"test_tool", arguments=args
)

def test_transform_args_simple_types(self):
"""Test that simple types are not transformed."""
schema = {
"properties": {
"name": {"type": "string"},
"count": {"type": "integer"},
}
}
self.mock_mcp_tool.inputSchema = schema
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

args = {"name": "test", "count": 42}
result = tool._transform_args_to_mcp_format(args, schema)

assert result == args

def test_transform_args_array_of_primitives_to_objects(self):
"""Test transformation of array of primitives to array of objects."""
schema = {
"properties": {
"sources": {
"type": "array",
"items": {
"type": "object",
"properties": {"type": {"type": "string"}},
},
}
}
}
self.mock_mcp_tool.inputSchema = schema
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

args = {"sources": ["web", "images"]}
result = tool._transform_args_to_mcp_format(args, schema)

expected = {"sources": [{"type": "web"}, {"type": "images"}]}
assert result == expected

def test_transform_args_already_correct_format(self):
"""Test that already correct format is not modified."""
schema = {
"properties": {
"sources": {
"type": "array",
"items": {
"type": "object",
"properties": {"type": {"type": "string"}},
},
}
}
}
self.mock_mcp_tool.inputSchema = schema
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

args = {"sources": [{"type": "web"}, {"type": "images"}]}
result = tool._transform_args_to_mcp_format(args, schema)

assert result == args

def test_transform_args_empty_array(self):
"""Test transformation with empty array."""
schema = {
"properties": {
"sources": {
"type": "array",
"items": {
"type": "object",
"properties": {"type": {"type": "string"}},
},
}
}
}
self.mock_mcp_tool.inputSchema = schema
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

args = {"sources": []}
result = tool._transform_args_to_mcp_format(args, schema)

assert result == {"sources": []}

def test_transform_args_no_transformation_for_multi_property_objects(self):
"""Test that objects with multiple properties are not transformed."""
schema = {
"properties": {
"filters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"field": {"type": "string"},
"value": {"type": "string"},
},
},
}
}
}
self.mock_mcp_tool.inputSchema = schema
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

args = {"filters": ["field1", "field2"]}
result = tool._transform_args_to_mcp_format(args, schema)

assert result == args

@pytest.mark.asyncio
async def test_run_async_impl_transforms_args(self):
"""Test that _run_async_impl applies argument transformation."""
schema = {
"properties": {
"sources": {
"type": "array",
"items": {
"type": "object",
"properties": {"type": {"type": "string"}},
},
}
}
}
self.mock_mcp_tool.inputSchema = schema
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)

mcp_response = CallToolResult(
content=[TextContent(type="text", text="success")]
)
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)

tool_context = Mock(spec=ToolContext)
credential = None

# Model generates simplified format
args = {"sources": ["web", "images"]}

result = await tool._run_async_impl(
args=args, tool_context=tool_context, credential=credential
)

# Verify transformation was applied before calling MCP tool
self.mock_session.call_tool.assert_called_once()
call_args = self.mock_session.call_tool.call_args
transformed_args = call_args[1]["arguments"]

# Should be transformed to correct format
assert transformed_args == {"sources": [{"type": "web"}, {"type": "images"}]}
assert result == mcp_response.model_dump(exclude_none=True, mode="json")