Skip to content

Commit 8d87cb3

Browse files
committed
Convert MCP schemas to strict where possible
## Summary: Towards #404. I made this configurable because it's not clear this is always a good thing to do. I also made it default to False because I'm not sure if this could cause errors. If it works out well, we can switch the default in the future as a small breaking changes ## Test Plan: Unit tests
1 parent 45c25f8 commit 8d87cb3

File tree

3 files changed

+196
-13
lines changed

3 files changed

+196
-13
lines changed

src/agents/agent.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass, field
77
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
88

9-
from typing_extensions import TypeAlias, TypedDict
9+
from typing_extensions import NotRequired, TypeAlias, TypedDict
1010

1111
from .guardrail import InputGuardrail, OutputGuardrail
1212
from .handoffs import Handoff
@@ -53,6 +53,15 @@ class StopAtTools(TypedDict):
5353
"""A list of tool names, any of which will stop the agent from running further."""
5454

5555

56+
class MCPConfig(TypedDict):
57+
"""Configuration for MCP servers."""
58+
59+
convert_schemas_to_strict: NotRequired[bool]
60+
"""If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
61+
best-effort conversion, so some schemas may not be convertible. Defaults to False.
62+
"""
63+
64+
5665
@dataclass
5766
class Agent(Generic[TContext]):
5867
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
@@ -119,6 +128,9 @@ class Agent(Generic[TContext]):
119128
longer needed.
120129
"""
121130

131+
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
132+
"""Configuration for MCP servers."""
133+
122134
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
123135
"""A list of checks that run in parallel to the agent's execution, before generating a
124136
response. Runs only if the agent is the first agent in the chain.
@@ -224,7 +236,8 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
224236

225237
async def get_mcp_tools(self) -> list[Tool]:
226238
"""Fetches the available tools from the MCP servers."""
227-
return await MCPUtil.get_all_function_tools(self.mcp_servers)
239+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
240+
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
228241

229242
async def get_all_tools(self) -> list[Tool]:
230243
"""All agent tools, including MCP tools and function tools."""

src/agents/mcp/util.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import json
33
from typing import TYPE_CHECKING, Any
44

5+
from agents.strict_schema import ensure_strict_json_schema
6+
57
from .. import _debug
68
from ..exceptions import AgentsException, ModelBehaviorError, UserError
79
from ..logger import logger
@@ -19,12 +21,14 @@ class MCPUtil:
1921
"""Set of utilities for interop between MCP and Agents SDK tools."""
2022

2123
@classmethod
22-
async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
24+
async def get_all_function_tools(
25+
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
26+
) -> list[Tool]:
2327
"""Get all function tools from a list of MCP servers."""
2428
tools = []
2529
tool_names: set[str] = set()
2630
for server in servers:
27-
server_tools = await cls.get_function_tools(server)
31+
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
2832
server_tool_names = {tool.name for tool in server_tools}
2933
if len(server_tool_names & tool_names) > 0:
3034
raise UserError(
@@ -37,25 +41,37 @@ async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
3741
return tools
3842

3943
@classmethod
40-
async def get_function_tools(cls, server: "MCPServer") -> list[Tool]:
44+
async def get_function_tools(
45+
cls, server: "MCPServer", convert_schemas_to_strict: bool
46+
) -> list[Tool]:
4147
"""Get all function tools from a single MCP server."""
4248

4349
with mcp_tools_span(server=server.name) as span:
4450
tools = await server.list_tools()
4551
span.span_data.result = [tool.name for tool in tools]
4652

47-
return [cls.to_function_tool(tool, server) for tool in tools]
53+
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
4854

4955
@classmethod
50-
def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool:
56+
def to_function_tool(
57+
cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
58+
) -> FunctionTool:
5159
"""Convert an MCP tool to an Agents SDK function tool."""
5260
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
61+
schema, is_strict = tool.inputSchema, False
62+
if convert_schemas_to_strict:
63+
try:
64+
schema = ensure_strict_json_schema(schema)
65+
is_strict = True
66+
except Exception as e:
67+
logger.info(f"Error converting MCP schema to strict mode: {e}")
68+
5369
return FunctionTool(
5470
name=tool.name,
5571
description=tool.description or "",
56-
params_json_schema=tool.inputSchema,
72+
params_json_schema=schema,
5773
on_invoke_tool=invoke_func,
58-
strict_json_schema=False,
74+
strict_json_schema=is_strict,
5975
)
6076

6177
@classmethod

tests/mcp/test_mcp_util.py

Lines changed: 158 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import json
12
import logging
23
from typing import Any
34

45
import pytest
6+
from inline_snapshot import snapshot
57
from mcp.types import Tool as MCPTool
6-
from pydantic import BaseModel
8+
from pydantic import BaseModel, TypeAdapter
79

8-
from agents import FunctionTool, RunContextWrapper
10+
from agents import Agent, FunctionTool, RunContextWrapper
911
from agents.exceptions import AgentsException, ModelBehaviorError
1012
from agents.mcp import MCPServer, MCPUtil
1113

@@ -18,7 +20,16 @@ class Foo(BaseModel):
1820

1921

2022
class Bar(BaseModel):
21-
qux: str
23+
qux: dict[str, str]
24+
25+
26+
Baz = TypeAdapter(dict[str, str])
27+
28+
29+
def _convertible_schema() -> dict[str, Any]:
30+
schema = Foo.model_json_schema()
31+
schema["additionalProperties"] = False
32+
return schema
2233

2334

2435
@pytest.mark.asyncio
@@ -47,7 +58,7 @@ async def test_get_all_function_tools():
4758
server3.add_tool(names[4], schemas[4])
4859

4960
servers: list[MCPServer] = [server1, server2, server3]
50-
tools = await MCPUtil.get_all_function_tools(servers)
61+
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False)
5162
assert len(tools) == 5
5263
assert all(tool.name in names for tool in tools)
5364

@@ -56,6 +67,11 @@ async def test_get_all_function_tools():
5667
assert tool.params_json_schema == schemas[idx]
5768
assert tool.name == names[idx]
5869

70+
# Also make sure it works with strict schemas
71+
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True)
72+
assert len(tools) == 5
73+
assert all(tool.name in names for tool in tools)
74+
5975

6076
@pytest.mark.asyncio
6177
async def test_invoke_mcp_tool():
@@ -107,3 +123,141 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
107123
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
108124

109125
assert "Error invoking MCP tool test_tool_1" in caplog.text
126+
127+
128+
@pytest.mark.asyncio
129+
async def test_agent_convert_schemas_true():
130+
"""Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.
131+
- 'foo' tool is already strict and remains strict.
132+
- 'bar' tool is non-strict and becomes strict (additionalProperties set to False, etc).
133+
"""
134+
strict_schema = Foo.model_json_schema()
135+
non_strict_schema = Baz.json_schema()
136+
possible_to_convert_schema = _convertible_schema()
137+
138+
server = FakeMCPServer()
139+
server.add_tool("foo", strict_schema)
140+
server.add_tool("bar", non_strict_schema)
141+
server.add_tool("baz", possible_to_convert_schema)
142+
agent = Agent(
143+
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True}
144+
)
145+
tools = await agent.get_mcp_tools()
146+
147+
foo_tool = next(tool for tool in tools if tool.name == "foo")
148+
assert isinstance(foo_tool, FunctionTool)
149+
bar_tool = next(tool for tool in tools if tool.name == "bar")
150+
assert isinstance(bar_tool, FunctionTool)
151+
baz_tool = next(tool for tool in tools if tool.name == "baz")
152+
assert isinstance(baz_tool, FunctionTool)
153+
154+
# Checks that additionalProperties is set to False
155+
assert foo_tool.params_json_schema == snapshot(
156+
{
157+
"properties": {
158+
"bar": {"title": "Bar", "type": "string"},
159+
"baz": {"title": "Baz", "type": "integer"},
160+
},
161+
"required": ["bar", "baz"],
162+
"title": "Foo",
163+
"type": "object",
164+
"additionalProperties": False,
165+
}
166+
)
167+
assert foo_tool.strict_json_schema is True, "foo_tool should be strict"
168+
169+
# Checks that additionalProperties is set to False
170+
assert bar_tool.params_json_schema == snapshot(
171+
{
172+
"type": "object",
173+
"additionalProperties": {"type": "string"},
174+
}
175+
)
176+
assert bar_tool.strict_json_schema is False, "bar_tool should not be strict"
177+
178+
# Checks that additionalProperties is set to False
179+
assert baz_tool.params_json_schema == snapshot(
180+
{
181+
"properties": {
182+
"bar": {"title": "Bar", "type": "string"},
183+
"baz": {"title": "Baz", "type": "integer"},
184+
},
185+
"required": ["bar", "baz"],
186+
"title": "Foo",
187+
"type": "object",
188+
"additionalProperties": False,
189+
}
190+
)
191+
assert baz_tool.strict_json_schema is True, "baz_tool should be strict"
192+
193+
194+
@pytest.mark.asyncio
195+
async def test_agent_convert_schemas_false():
196+
"""Test that setting convert_schemas_to_strict to False leaves tool schemas as non-strict.
197+
- 'foo' tool remains strict.
198+
- 'bar' tool remains non-strict (additionalProperties remains True).
199+
"""
200+
strict_schema = Foo.model_json_schema()
201+
non_strict_schema = Baz.json_schema()
202+
possible_to_convert_schema = _convertible_schema()
203+
204+
server = FakeMCPServer()
205+
server.add_tool("foo", strict_schema)
206+
server.add_tool("bar", non_strict_schema)
207+
server.add_tool("baz", possible_to_convert_schema)
208+
209+
agent = Agent(
210+
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False}
211+
)
212+
tools = await agent.get_mcp_tools()
213+
214+
foo_tool = next(tool for tool in tools if tool.name == "foo")
215+
assert isinstance(foo_tool, FunctionTool)
216+
bar_tool = next(tool for tool in tools if tool.name == "bar")
217+
assert isinstance(bar_tool, FunctionTool)
218+
baz_tool = next(tool for tool in tools if tool.name == "baz")
219+
assert isinstance(baz_tool, FunctionTool)
220+
221+
assert foo_tool.params_json_schema == strict_schema
222+
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
223+
224+
assert bar_tool.params_json_schema == non_strict_schema
225+
assert bar_tool.strict_json_schema is False
226+
227+
assert baz_tool.params_json_schema == possible_to_convert_schema
228+
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
229+
230+
231+
@pytest.mark.asyncio
232+
async def test_agent_convert_schemas_unset():
233+
"""Test that leaving convert_schemas_to_strict unset (defaulting to False) leaves tool schemas
234+
as non-strict.
235+
- 'foo' tool remains strict.
236+
- 'bar' tool remains non-strict.
237+
"""
238+
strict_schema = Foo.model_json_schema()
239+
non_strict_schema = Baz.json_schema()
240+
possible_to_convert_schema = _convertible_schema()
241+
242+
server = FakeMCPServer()
243+
server.add_tool("foo", strict_schema)
244+
server.add_tool("bar", non_strict_schema)
245+
server.add_tool("baz", possible_to_convert_schema)
246+
agent = Agent(name="test_agent", mcp_servers=[server])
247+
tools = await agent.get_mcp_tools()
248+
249+
foo_tool = next(tool for tool in tools if tool.name == "foo")
250+
assert isinstance(foo_tool, FunctionTool)
251+
bar_tool = next(tool for tool in tools if tool.name == "bar")
252+
assert isinstance(bar_tool, FunctionTool)
253+
baz_tool = next(tool for tool in tools if tool.name == "baz")
254+
assert isinstance(baz_tool, FunctionTool)
255+
256+
assert foo_tool.params_json_schema == strict_schema
257+
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
258+
259+
assert bar_tool.params_json_schema == non_strict_schema
260+
assert bar_tool.strict_json_schema is False
261+
262+
assert baz_tool.params_json_schema == possible_to_convert_schema
263+
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"

0 commit comments

Comments
 (0)