1717
1818from ..exceptions import UserError
1919from ..logger import logger
20+ from ..tool import ToolFilter , ToolFilterContext , ToolFilterStatic
2021
2122
2223class MCPServer (abc .ABC ):
@@ -61,8 +62,7 @@ def __init__(
6162 self ,
6263 cache_tools_list : bool ,
6364 client_session_timeout_seconds : float | None ,
64- allowed_tools : list [str ] | None = None ,
65- excluded_tools : list [str ] | None = None ,
65+ tool_filter : ToolFilter = None ,
6666 ):
6767 """
6868 Args:
@@ -74,10 +74,7 @@ def __init__(
7474 (by avoiding a round-trip to the server every time).
7575
7676 client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
77- allowed_tools: Optional list of tool names to allow (whitelist).
78- If set, only these tools will be available.
79- excluded_tools: Optional list of tool names to exclude (blacklist).
80- If set, these tools will be filtered out.
77+ tool_filter: The tool filter to use for filtering tools.
8178 """
8279 self .session : ClientSession | None = None
8380 self .exit_stack : AsyncExitStack = AsyncExitStack ()
@@ -91,8 +88,39 @@ def __init__(
9188 self ._cache_dirty = True
9289 self ._tools_list : list [MCPTool ] | None = None
9390
94- self .allowed_tools = allowed_tools
95- self .excluded_tools = excluded_tools
91+ self .tool_filter = tool_filter
92+
93+ def _apply_tool_filter (self , tools : list [MCPTool ]) -> list [MCPTool ]:
94+ """Apply the tool filter to the list of tools."""
95+ if self .tool_filter is None :
96+ return tools
97+
98+ # Handle static tool filter
99+ if isinstance (self .tool_filter , dict ):
100+ static_filter : ToolFilterStatic = self .tool_filter
101+ filtered_tools = tools
102+
103+ # Apply allowed_tool_names filter (whitelist)
104+ if "allowed_tool_names" in static_filter :
105+ allowed_names = static_filter ["allowed_tool_names" ]
106+ filtered_tools = [t for t in filtered_tools if t .name in allowed_names ]
107+
108+ # Apply blocked_tool_names filter (blacklist)
109+ if "blocked_tool_names" in static_filter :
110+ blocked_names = static_filter ["blocked_tool_names" ]
111+ filtered_tools = [t for t in filtered_tools if t .name not in blocked_names ]
112+
113+ return filtered_tools
114+
115+ # Handle callable tool filter
116+ # For now, we can't support callable filters because we don't have access to
117+ # run context and agent in the current list_tools signature.
118+ # This could be enhanced in the future by modifying the call chain.
119+ else :
120+ raise NotImplementedError (
121+ "Callable tool filters are not yet supported. Please use ToolFilterStatic "
122+ "with 'allowed_tool_names' and/or 'blocked_tool_names' for now."
123+ )
96124
97125 @abc .abstractmethod
98126 def create_streams (
@@ -159,12 +187,10 @@ async def list_tools(self) -> list[MCPTool]:
159187 self ._tools_list = (await self .session .list_tools ()).tools
160188 tools = self ._tools_list
161189
162- # Filter tools based on allowed and excluded tools
190+ # Filter tools based on tool_filter
163191 filtered_tools = tools
164- if self .allowed_tools is not None :
165- filtered_tools = [t for t in filtered_tools if t .name in self .allowed_tools ]
166- if self .excluded_tools is not None :
167- filtered_tools = [t for t in filtered_tools if t .name not in self .excluded_tools ]
192+ if self .tool_filter is not None :
193+ filtered_tools = self ._apply_tool_filter (filtered_tools )
168194 return filtered_tools
169195
170196 async def call_tool (self , tool_name : str , arguments : dict [str , Any ] | None ) -> CallToolResult :
@@ -226,8 +252,7 @@ def __init__(
226252 cache_tools_list : bool = False ,
227253 name : str | None = None ,
228254 client_session_timeout_seconds : float | None = 5 ,
229- allowed_tools : list [str ] | None = None ,
230- excluded_tools : list [str ] | None = None ,
255+ tool_filter : ToolFilter = None ,
231256 ):
232257 """Create a new MCP server based on the stdio transport.
233258
@@ -245,14 +270,12 @@ def __init__(
245270 name: A readable name for the server. If not provided, we'll create one from the
246271 command.
247272 client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
248- allowed_tools: Optional list of tool names to allow (whitelist).
249- excluded_tools: Optional list of tool names to exclude (blacklist).
273+ tool_filter: The tool filter to use for filtering tools.
250274 """
251275 super ().__init__ (
252276 cache_tools_list ,
253277 client_session_timeout_seconds ,
254- allowed_tools ,
255- excluded_tools ,
278+ tool_filter ,
256279 )
257280
258281 self .params = StdioServerParameters (
@@ -312,8 +335,7 @@ def __init__(
312335 cache_tools_list : bool = False ,
313336 name : str | None = None ,
314337 client_session_timeout_seconds : float | None = 5 ,
315- allowed_tools : list [str ] | None = None ,
316- excluded_tools : list [str ] | None = None ,
338+ tool_filter : ToolFilter = None ,
317339 ):
318340 """Create a new MCP server based on the HTTP with SSE transport.
319341
@@ -333,14 +355,12 @@ def __init__(
333355 URL.
334356
335357 client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
336- allowed_tools: Optional list of tool names to allow (whitelist).
337- excluded_tools: Optional list of tool names to exclude (blacklist).
358+ tool_filter: The tool filter to use for filtering tools.
338359 """
339360 super ().__init__ (
340361 cache_tools_list ,
341362 client_session_timeout_seconds ,
342- allowed_tools ,
343- excluded_tools ,
363+ tool_filter ,
344364 )
345365
346366 self .params = params
@@ -400,8 +420,7 @@ def __init__(
400420 cache_tools_list : bool = False ,
401421 name : str | None = None ,
402422 client_session_timeout_seconds : float | None = 5 ,
403- allowed_tools : list [str ] | None = None ,
404- excluded_tools : list [str ] | None = None ,
423+ tool_filter : ToolFilter = None ,
405424 ):
406425 """Create a new MCP server based on the Streamable HTTP transport.
407426
@@ -422,14 +441,12 @@ def __init__(
422441 URL.
423442
424443 client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
425- allowed_tools: Optional list of tool names to allow (whitelist).
426- excluded_tools: Optional list of tool names to exclude (blacklist).
444+ tool_filter: The tool filter to use for filtering tools.
427445 """
428446 super ().__init__ (
429447 cache_tools_list ,
430448 client_session_timeout_seconds ,
431- allowed_tools ,
432- excluded_tools ,
449+ tool_filter ,
433450 )
434451
435452 self .params = params
0 commit comments