diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index ec2eb18fe..26a291779 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -11,13 +11,14 @@ import contextlib import logging from collections.abc import Callable +from dataclasses import dataclass from datetime import timedelta from types import TracebackType -from typing import Any, TypeAlias +from typing import Any, TypeAlias, overload import anyio from pydantic import BaseModel -from typing_extensions import Self +from typing_extensions import Self, deprecated import mcp from mcp import types @@ -25,6 +26,9 @@ from mcp.client.stdio import StdioServerParameters from mcp.client.streamable_http import streamablehttp_client from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT + +from .session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT class SseServerParameters(BaseModel): @@ -65,6 +69,21 @@ class StreamableHttpParameters(BaseModel): ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters +# Use dataclass instead of pydantic BaseModel +# because pydantic BaseModel cannot handle Protocol fields. +@dataclass +class ClientSessionParameters: + """Parameters for establishing a client session to an MCP server.""" + + read_timeout_seconds: timedelta | None = None + sampling_callback: SamplingFnT | None = None + elicitation_callback: ElicitationFnT | None = None + list_roots_callback: ListRootsFnT | None = None + logging_callback: LoggingFnT | None = None + message_handler: MessageHandlerFnT | None = None + client_info: types.Implementation | None = None + + class ClientSessionGroup: """Client for managing connections to multiple MCP servers. @@ -172,11 +191,49 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + @overload + async def call_tool( + self, + name: str, + arguments: dict[str, Any], + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + @overload + @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.") + async def call_tool( + self, + name: str, + *, + args: dict[str, Any], + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + meta: dict[str, Any] | None = None, + ) -> types.CallToolResult: ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + *, + meta: dict[str, Any] | None = None, + args: dict[str, Any] | None = None, + ) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name - return await session.call_tool(session_tool_name, args) + return await session.call_tool( + session_tool_name, + arguments if args is None else args, + read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + meta=meta, + ) async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" @@ -225,13 +282,16 @@ async def connect_with_session( async def connect_to_server( self, server_params: ServerParameters, + session_params: ClientSessionParameters | None = None, ) -> mcp.ClientSession: """Connects to a single MCP server.""" - server_info, session = await self._establish_session(server_params) + server_info, session = await self._establish_session(server_params, session_params) return await self.connect_with_session(server_info, session) async def _establish_session( - self, server_params: ServerParameters + self, + server_params: ServerParameters, + session_params: ClientSessionParameters | None = None, ) -> tuple[types.Implementation, mcp.ClientSession]: """Establish a client session to an MCP server.""" @@ -259,7 +319,23 @@ async def _establish_session( ) read, write, _ = await session_stack.enter_async_context(client) - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) + if session_params is None: + session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) + else: + session = await session_stack.enter_async_context( + mcp.ClientSession( + read, + write, + read_timeout_seconds=session_params.read_timeout_seconds, + sampling_callback=session_params.sampling_callback, + elicitation_callback=session_params.elicitation_callback, + list_roots_callback=session_params.list_roots_callback, + logging_callback=session_params.logging_callback, + message_handler=session_params.message_handler, + client_info=session_params.client_info, + ) + ) + result = await session.initialize() # Session successfully initialized. diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index c38cfeabc..584c9bddf 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -62,7 +62,7 @@ def hook(name: str, server_info: types.Implementation) -> str: # --- Test Execution --- result = await mcp_session_group.call_tool( name="server1-my_tool", - args={ + arguments={ "name": "value1", "args": {}, }, @@ -73,6 +73,9 @@ def hook(name: str, server_info: types.Implementation) -> str: mock_session.call_tool.assert_called_once_with( "my_tool", {"name": "value1", "args": {}}, + read_timeout_seconds=None, + progress_callback=None, + meta=None, ) async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack):