Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 83 additions & 7 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@
import contextlib
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias
from typing import Any, TypeAlias, overload

import anyio
from pydantic import BaseModel
from typing_extensions import Self
from typing_extensions import Self, deprecated

import mcp
from mcp import types
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.exceptions import McpError
from mcp.shared.session import ProgressFnT

from .session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT


class SseServerParameters(BaseModel):
Expand Down Expand Up @@ -65,6 +69,21 @@ class StreamableHttpParameters(BaseModel):
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters


# Use dataclass instead of pydantic BaseModel
# because pydantic BaseModel cannot handle Protocol fields.
@dataclass
class ClientSessionParameters:
"""Parameters for establishing a client session to an MCP server."""

read_timeout_seconds: timedelta | None = None
sampling_callback: SamplingFnT | None = None
elicitation_callback: ElicitationFnT | None = None
list_roots_callback: ListRootsFnT | None = None
logging_callback: LoggingFnT | None = None
message_handler: MessageHandlerFnT | None = None
client_info: types.Implementation | None = None


class ClientSessionGroup:
"""Client for managing connections to multiple MCP servers.

Expand Down Expand Up @@ -172,11 +191,49 @@ def tools(self) -> dict[str, types.Tool]:
"""Returns the tools as a dictionary of names to tools."""
return self._tools

async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult: ...

@overload
@deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
async def call_tool(
self,
name: str,
*,
args: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult: ...

async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
args: dict[str, Any] | None = None,
) -> types.CallToolResult:
"""Executes a tool given its name and arguments."""
session = self._tool_to_session[name]
session_tool_name = self.tools[name].name
return await session.call_tool(session_tool_name, args)
return await session.call_tool(
session_tool_name,
arguments if args is None else args,
read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
meta=meta,
)

async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
"""Disconnects from a single MCP server."""
Expand Down Expand Up @@ -225,13 +282,16 @@ async def connect_with_session(
async def connect_to_server(
self,
server_params: ServerParameters,
session_params: ClientSessionParameters | None = None,
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
server_info, session = await self._establish_session(server_params)
server_info, session = await self._establish_session(server_params, session_params)
return await self.connect_with_session(server_info, session)

async def _establish_session(
self, server_params: ServerParameters
self,
server_params: ServerParameters,
session_params: ClientSessionParameters | None = None,
) -> tuple[types.Implementation, mcp.ClientSession]:
"""Establish a client session to an MCP server."""

Expand Down Expand Up @@ -259,7 +319,23 @@ async def _establish_session(
)
read, write, _ = await session_stack.enter_async_context(client)

session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
if session_params is None:
session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
else:
session = await session_stack.enter_async_context(
mcp.ClientSession(
read,
write,
read_timeout_seconds=session_params.read_timeout_seconds,
sampling_callback=session_params.sampling_callback,
elicitation_callback=session_params.elicitation_callback,
list_roots_callback=session_params.list_roots_callback,
logging_callback=session_params.logging_callback,
message_handler=session_params.message_handler,
client_info=session_params.client_info,
)
)

result = await session.initialize()

# Session successfully initialized.
Expand Down
5 changes: 4 additions & 1 deletion tests/client/test_session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def hook(name: str, server_info: types.Implementation) -> str:
# --- Test Execution ---
result = await mcp_session_group.call_tool(
name="server1-my_tool",
args={
arguments={
"name": "value1",
"args": {},
},
Expand All @@ -73,6 +73,9 @@ def hook(name: str, server_info: types.Implementation) -> str:
mock_session.call_tool.assert_called_once_with(
"my_tool",
{"name": "value1", "args": {}},
read_timeout_seconds=None,
progress_callback=None,
meta=None,
)

async def test_connect_to_server(self, mock_exit_stack: contextlib.AsyncExitStack):
Expand Down