diff --git a/README.md b/README.md index 5dbc4bd9d..a49ac8b72 100644 --- a/README.md +++ b/README.md @@ -2153,6 +2153,109 @@ if __name__ == "__main__": _Full example: [examples/snippets/clients/streamable_basic.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/streamable_basic.py)_ +### Per-Request HTTP Headers + +When using HTTP transports, you can pass custom headers on a per-request basis. This enables various use cases such as request tracing, authentication context, A/B testing, debugging flags, and more while maintaining a single persistent connection: + + +```python +""" +Example demonstrating per-request headers functionality for MCP client. + +Shows how to use the extra_headers parameter to send different HTTP headers +with each request, enabling use cases like per-user authentication, request +tracing, A/B testing, and multi-tenant applications. +""" + +import asyncio + +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + + +async def main(): + """Demonstrate per-request headers functionality.""" + + # Connection-level headers (static for the entire session) + connection_headers = {"Authorization": "Bearer org-level-token", "X-Org-ID": "org-123"} + + async with streamablehttp_client("https://mcp.example.com/mcp", headers=connection_headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Example 1: Request tracing + tracing_headers = { + "X-Request-ID": "req-12345", + "X-Trace-ID": "trace-abc-456", + } + result = await session.call_tool("process_data", {"type": "analytics"}, extra_headers=tracing_headers) + print(f"Traced request result: {result}") + + # Example 2: User-specific authentication + user_headers = { + "X-User-ID": "alice", + "X-Auth-Token": "user-token-12345", + } + result = await session.call_tool("get_user_data", {"fields": ["profile"]}, extra_headers=user_headers) + print(f"User-specific result: {result}") + + # Example 3: A/B testing + experiment_headers = { + "X-Experiment-ID": "new-ui-test", + "X-Variant": "variant-b", + } + result = await session.call_tool( + "get_recommendations", {"user_id": "user123"}, extra_headers=experiment_headers + ) + print(f"A/B test result: {result}") + + # Example 4: Override connection-level headers + override_headers = { + "Authorization": "Bearer user-specific-token", # Overrides connection-level + "X-Special-Permission": "admin", + } + result = await session.call_tool("admin_operation", {"operation": "reset"}, extra_headers=override_headers) + print(f"Admin operation result: {result}") + + # Example 5: Works with all ClientSession methods + await session.list_resources(extra_headers={"X-Resource-Filter": "public"}) + await session.get_prompt("template", extra_headers={"X-Context": "help"}) + await session.set_logging_level("debug", extra_headers={"X-Debug-Session": "true"}) + + +if __name__ == "__main__": + print("MCP Client Per-Request Headers Example") + print("=" * 50) + + try: + asyncio.run(main()) + except Exception as e: + print(f"Example requires a running MCP server. Error: {e}") + print("\nThis example demonstrates the API usage patterns.") +``` + +_Full example: [examples/snippets/clients/per_request_headers_example.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/clients/per_request_headers_example.py)_ + + +The `extra_headers` parameter is available for all `ClientSession` methods that make server requests: + +- `call_tool()` +- `get_prompt()` +- `read_resource()` +- `list_tools()` +- `list_prompts()` +- `list_resources()` +- `list_resource_templates()` +- `subscribe()` +- `unsubscribe()` +- `set_logging_level()` + +Per-request headers are merged with the transport's default headers, with per-request headers taking precedence for duplicate keys. + ### Client Display Utilities When building MCP clients, the SDK provides utilities to help display human-readable names for tools, resources, and prompts: diff --git a/examples/snippets/clients/per_request_headers_example.py b/examples/snippets/clients/per_request_headers_example.py new file mode 100644 index 000000000..bedaae652 --- /dev/null +++ b/examples/snippets/clients/per_request_headers_example.py @@ -0,0 +1,77 @@ +""" +Example demonstrating per-request headers functionality for MCP client. + +Shows how to use the extra_headers parameter to send different HTTP headers +with each request, enabling use cases like per-user authentication, request +tracing, A/B testing, and multi-tenant applications. +""" + +import asyncio + +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + + +async def main(): + """Demonstrate per-request headers functionality.""" + + # Connection-level headers (static for the entire session) + connection_headers = {"Authorization": "Bearer org-level-token", "X-Org-ID": "org-123"} + + async with streamablehttp_client("https://mcp.example.com/mcp", headers=connection_headers) as ( + read_stream, + write_stream, + _, + ): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + # Example 1: Request tracing + tracing_headers = { + "X-Request-ID": "req-12345", + "X-Trace-ID": "trace-abc-456", + } + result = await session.call_tool("process_data", {"type": "analytics"}, extra_headers=tracing_headers) + print(f"Traced request result: {result}") + + # Example 2: User-specific authentication + user_headers = { + "X-User-ID": "alice", + "X-Auth-Token": "user-token-12345", + } + result = await session.call_tool("get_user_data", {"fields": ["profile"]}, extra_headers=user_headers) + print(f"User-specific result: {result}") + + # Example 3: A/B testing + experiment_headers = { + "X-Experiment-ID": "new-ui-test", + "X-Variant": "variant-b", + } + result = await session.call_tool( + "get_recommendations", {"user_id": "user123"}, extra_headers=experiment_headers + ) + print(f"A/B test result: {result}") + + # Example 4: Override connection-level headers + override_headers = { + "Authorization": "Bearer user-specific-token", # Overrides connection-level + "X-Special-Permission": "admin", + } + result = await session.call_tool("admin_operation", {"operation": "reset"}, extra_headers=override_headers) + print(f"Admin operation result: {result}") + + # Example 5: Works with all ClientSession methods + await session.list_resources(extra_headers={"X-Resource-Filter": "public"}) + await session.get_prompt("template", extra_headers={"X-Context": "help"}) + await session.set_logging_level("debug", extra_headers={"X-Debug-Session": "true"}) + + +if __name__ == "__main__": + print("MCP Client Per-Request Headers Example") + print("=" * 50) + + try: + asyncio.run(main()) + except Exception as e: + print(f"Example requires a running MCP server. Error: {e}") + print("\nThis example demonstrates the API usage patterns.") diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9e9389ac1..d5624d666 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -135,6 +135,14 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} + def _create_metadata_for_extra_headers(self, extra_headers: dict[str, str] | None): + """Create metadata for passing extra headers to the transport layer.""" + if extra_headers: + from mcp.shared.message import ClientMessageMetadata + + return ClientMessageMetadata(extra_headers=extra_headers) + return None + async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None elicitation = ( @@ -202,7 +210,9 @@ async def send_progress_notification( ) ) - async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: + async def set_logging_level( + self, level: types.LoggingLevel, *, extra_headers: dict[str, str] | None = None + ) -> types.EmptyResult: """Send a logging/setLevel request.""" return await self.send_request( types.ClientRequest( @@ -211,6 +221,7 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul ) ), types.EmptyResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) @overload @@ -218,22 +229,26 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ... @overload - async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ... + async def list_resources( + self, *, params: types.PaginatedRequestParams | None, extra_headers: dict[str, str] | None = None + ) -> types.ListResourcesResult: ... @overload - async def list_resources(self) -> types.ListResourcesResult: ... + async def list_resources(self, *, extra_headers: dict[str, str] | None = None) -> types.ListResourcesResult: ... async def list_resources( self, cursor: str | None = None, *, params: types.PaginatedRequestParams | None = None, + extra_headers: dict[str, str] | None = None, ) -> types.ListResourcesResult: """Send a resources/list request. Args: cursor: Simple cursor string for pagination (deprecated, use params instead) params: Full pagination parameters including cursor and any future fields + extra_headers: Additional HTTP headers to include in this specific request. """ if params is not None and cursor is not None: raise ValueError("Cannot specify both cursor and params") @@ -248,6 +263,7 @@ async def list_resources( return await self.send_request( types.ClientRequest(types.ListResourcesRequest(params=request_params)), types.ListResourcesResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) @overload @@ -256,23 +272,27 @@ async def list_resource_templates(self, cursor: str | None) -> types.ListResourc @overload async def list_resource_templates( - self, *, params: types.PaginatedRequestParams | None + self, *, params: types.PaginatedRequestParams | None, extra_headers: dict[str, str] | None = None ) -> types.ListResourceTemplatesResult: ... @overload - async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ... + async def list_resource_templates( + self, *, extra_headers: dict[str, str] | None = None + ) -> types.ListResourceTemplatesResult: ... async def list_resource_templates( self, cursor: str | None = None, *, params: types.PaginatedRequestParams | None = None, + extra_headers: dict[str, str] | None = None, ) -> types.ListResourceTemplatesResult: """Send a resources/templates/list request. Args: cursor: Simple cursor string for pagination (deprecated, use params instead) params: Full pagination parameters including cursor and any future fields + extra_headers: Additional HTTP headers to include in this specific request. """ if params is not None and cursor is not None: raise ValueError("Cannot specify both cursor and params") @@ -287,9 +307,12 @@ async def list_resource_templates( return await self.send_request( types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)), types.ListResourceTemplatesResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + async def read_resource( + self, uri: AnyUrl, *, extra_headers: dict[str, str] | None = None + ) -> types.ReadResourceResult: """Send a resources/read request.""" return await self.send_request( types.ClientRequest( @@ -298,9 +321,12 @@ async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: ) ), types.ReadResourceResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def subscribe_resource( + self, uri: AnyUrl, *, extra_headers: dict[str, str] | None = None + ) -> types.EmptyResult: """Send a resources/subscribe request.""" return await self.send_request( types.ClientRequest( @@ -309,9 +335,12 @@ async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: ) ), types.EmptyResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + async def unsubscribe_resource( + self, uri: AnyUrl, *, extra_headers: dict[str, str] | None = None + ) -> types.EmptyResult: """Send a resources/unsubscribe request.""" return await self.send_request( types.ClientRequest( @@ -320,6 +349,7 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: ) ), types.EmptyResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) async def call_tool( @@ -330,8 +360,21 @@ async def call_tool( progress_callback: ProgressFnT | None = None, *, meta: dict[str, Any] | None = None, + extra_headers: dict[str, str] | None = None, ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + """Send a tools/call request with optional progress callback support. + + Args: + name: The name of the tool to call. + arguments: The arguments to pass to the tool. + read_timeout_seconds: Optional timeout for reading the response. + progress_callback: Optional callback for progress notifications. + meta: Optional meta parameters for the request. + extra_headers: Additional HTTP headers to include in this specific request. + These are merged with connection-level headers, with extra_headers + taking precedence for duplicate keys. Useful for per-request + authentication, tracing, debugging, A/B testing, and more. + """ _meta: types.RequestParams.Meta | None = None if meta is not None: @@ -345,6 +388,7 @@ async def call_tool( ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, + metadata=self._create_metadata_for_extra_headers(extra_headers), progress_callback=progress_callback, ) @@ -380,22 +424,26 @@ async def _validate_tool_result(self, name: str, result: types.CallToolResult) - async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ... @overload - async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ... + async def list_prompts( + self, *, params: types.PaginatedRequestParams | None, extra_headers: dict[str, str] | None = None + ) -> types.ListPromptsResult: ... @overload - async def list_prompts(self) -> types.ListPromptsResult: ... + async def list_prompts(self, *, extra_headers: dict[str, str] | None = None) -> types.ListPromptsResult: ... async def list_prompts( self, cursor: str | None = None, *, params: types.PaginatedRequestParams | None = None, + extra_headers: dict[str, str] | None = None, ) -> types.ListPromptsResult: """Send a prompts/list request. Args: cursor: Simple cursor string for pagination (deprecated, use params instead) params: Full pagination parameters including cursor and any future fields + extra_headers: Additional HTTP headers to include in this specific request. """ if params is not None and cursor is not None: raise ValueError("Cannot specify both cursor and params") @@ -410,9 +458,12 @@ async def list_prompts( return await self.send_request( types.ClientRequest(types.ListPromptsRequest(params=request_params)), types.ListPromptsResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) - async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None, *, extra_headers: dict[str, str] | None = None + ) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( types.ClientRequest( @@ -421,6 +472,7 @@ async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) - ) ), types.GetPromptResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) async def complete( @@ -452,22 +504,26 @@ async def complete( async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ... @overload - async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ... + async def list_tools( + self, *, params: types.PaginatedRequestParams | None, extra_headers: dict[str, str] | None = None + ) -> types.ListToolsResult: ... @overload - async def list_tools(self) -> types.ListToolsResult: ... + async def list_tools(self, *, extra_headers: dict[str, str] | None = None) -> types.ListToolsResult: ... async def list_tools( self, cursor: str | None = None, *, params: types.PaginatedRequestParams | None = None, + extra_headers: dict[str, str] | None = None, ) -> types.ListToolsResult: """Send a tools/list request. Args: cursor: Simple cursor string for pagination (deprecated, use params instead) params: Full pagination parameters including cursor and any future fields + extra_headers: Additional HTTP headers to include in this specific request. """ if params is not None and cursor is not None: raise ValueError("Cannot specify both cursor and params") @@ -482,6 +538,7 @@ async def list_tools( result = await self.send_request( types.ClientRequest(types.ListToolsRequest(params=request_params)), types.ListToolsResult, + metadata=self._create_metadata_for_extra_headers(extra_headers), ) # Cache tool output schemas for future validation diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 57df64705..ca0fb74c1 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -220,6 +220,11 @@ async def handle_get_stream( async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" headers = self._prepare_request_headers(ctx.headers) + + # Merge extra headers from metadata if present + if ctx.metadata and ctx.metadata.extra_headers: + headers.update(ctx.metadata.extra_headers) + if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -254,6 +259,11 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" headers = self._prepare_request_headers(ctx.headers) + + # Merge extra headers from metadata if present + if ctx.metadata and ctx.metadata.extra_headers: + headers.update(ctx.metadata.extra_headers) + message = ctx.session_message.message is_initialization = self._is_initialization_request(message) diff --git a/src/mcp/shared/message.py b/src/mcp/shared/message.py index 4b6df23eb..86dd83a54 100644 --- a/src/mcp/shared/message.py +++ b/src/mcp/shared/message.py @@ -21,6 +21,7 @@ class ClientMessageMetadata: resumption_token: ResumptionToken | None = None on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None + extra_headers: dict[str, str] | None = None @dataclass diff --git a/tests/client/test_extra_headers.py b/tests/client/test_extra_headers.py new file mode 100644 index 000000000..8df79b854 --- /dev/null +++ b/tests/client/test_extra_headers.py @@ -0,0 +1,694 @@ +"""Tests for per-request headers functionality in call_tool.""" + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + CallToolResult, + ClientRequest, + Implementation, + InitializeRequest, + InitializeResult, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + ListToolsResult, + ServerCapabilities, + ServerResult, + TextContent, + Tool, +) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "extra_headers", + [ + None, + {}, + {"X-Auth-Token": "user-123-token", "X-Trace-Id": "trace-456"}, + ], +) +async def test_call_tool_with_extra_headers(extra_headers: dict[str, str] | None): + """Test that call_tool properly handles extra_headers parameter.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + mocked_tool = Tool(name="test_tool", inputSchema={}) + + async def mock_server(): + # Receive initialization request from client + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + # Answer initialization request + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Wait for the client to send a 'tools/call' request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + assert jsonrpc_request.root.method == "tools/call" + + # Verify that extra_headers are passed through metadata + if extra_headers: + # Check if the session message has metadata with extra headers + assert session_message.metadata is not None + assert isinstance(session_message.metadata, ClientMessageMetadata) + assert session_message.metadata.extra_headers == extra_headers + + result = ServerResult( + CallToolResult(content=[TextContent(type="text", text="Called successfully")], isError=False) + ) + + # Send the tools/call result + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Wait for the tools/list request from the client + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + assert jsonrpc_request.root.method == "tools/list" + + result = ListToolsResult(tools=[mocked_tool]) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + # Call tool with extra_headers + result = await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, extra_headers=extra_headers) + + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Called successfully" + + +@pytest.mark.anyio +async def test_call_tool_combined_parameters(): + """Test call_tool with extra_headers combined with other parameters.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + mocked_tool = Tool(name="test_tool", inputSchema={}) + extra_headers = {"X-Custom": "test-value"} + meta = {"test_meta": "meta_value"} + + async def mock_server(): + # Receive initialization request from client + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, InitializeRequest) + + result = ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ) + + # Answer initialization request + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Wait for the client to send a 'tools/call' request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + assert jsonrpc_request.root.method == "tools/call" + + # Verify that meta is in the JSON-RPC params + assert jsonrpc_request.root.params + assert "_meta" in jsonrpc_request.root.params + assert jsonrpc_request.root.params["_meta"] == meta + + # Verify that extra_headers are in the session message metadata + assert session_message.metadata is not None + assert isinstance(session_message.metadata, ClientMessageMetadata) + assert session_message.metadata.extra_headers == extra_headers + + result = ServerResult( + CallToolResult(content=[TextContent(type="text", text="Called successfully")], isError=False) + ) + + # Send the tools/call result + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Wait for the tools/list request from the client + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + assert jsonrpc_request.root.method == "tools/list" + + result = ListToolsResult(tools=[mocked_tool]) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + server_to_client_send.close() + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + await session.initialize() + + # Call tool with both meta and extra_headers + result = await session.call_tool( + name=mocked_tool.name, arguments={"arg1": "value1"}, meta=meta, extra_headers=extra_headers + ) + + assert isinstance(result, CallToolResult) + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Called successfully" + + +def test_client_message_metadata_extra_headers(): + """Test that ClientMessageMetadata properly handles extra_headers.""" + # Test with extra_headers + headers = {"X-Test": "value", "Authorization": "Bearer token"} + metadata = ClientMessageMetadata(extra_headers=headers) + assert metadata.extra_headers == headers + + # Test without extra_headers + metadata = ClientMessageMetadata() + assert metadata.extra_headers is None + + # Test with all fields + metadata = ClientMessageMetadata(resumption_token="token-123", extra_headers=headers) + assert metadata.resumption_token == "token-123" + assert metadata.extra_headers == headers + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "extra_headers", + [ + None, + {}, + {"X-Log-Level": "debug", "X-Trace-Id": "trace-789"}, + ], +) +async def test_set_logging_level_with_extra_headers(extra_headers: dict[str, str] | None): + """Test that set_logging_level properly handles extra_headers parameter.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async def mock_server(): + # Receive initialization request from client + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + # Answer initialization request + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Wait for the client to send a 'logging/setLevel' request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "logging/setLevel" + + # Verify that extra_headers are passed through metadata + if extra_headers: + assert session_message.metadata is not None + assert isinstance(session_message.metadata, ClientMessageMetadata) + assert session_message.metadata.extra_headers == extra_headers + + # Send response + from mcp.types import EmptyResult + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult(EmptyResult()).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + server_to_client_send.close() + + async with ( + ClientSession(server_to_client_receive, client_to_server_send) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Call set_logging_level with extra_headers + result = await session.set_logging_level("debug", extra_headers=extra_headers) + + from mcp.types import EmptyResult + + assert isinstance(result, EmptyResult) + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "extra_headers", + [ + None, + {}, + {"X-Resource-Filter": "public", "X-Trace-Id": "trace-123"}, + ], +) +async def test_list_resources_with_extra_headers(extra_headers: dict[str, str] | None): + """Test that list_resources properly handles extra_headers parameter.""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) + + async def mock_server(): + # Handle initialization + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Wait for the client to send a 'resources/list' request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "resources/list" + + # Verify extra_headers metadata + if extra_headers: + assert session_message.metadata is not None + assert isinstance(session_message.metadata, ClientMessageMetadata) + assert session_message.metadata.extra_headers == extra_headers + + # Send response + from mcp.types import ListResourcesResult + + result = ServerResult(ListResourcesResult(resources=[])) + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + server_to_client_send.close() + + async with ( + ClientSession(server_to_client_receive, client_to_server_send) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Call list_resources with extra_headers + result = await session.list_resources(extra_headers=extra_headers) + + from mcp.types import ListResourcesResult + + assert isinstance(result, ListResourcesResult) + + +@pytest.mark.anyio +async def test_all_methods_without_extra_headers(): + """Test that all extended methods work correctly without extra_headers (no regression).""" + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](10) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) + + request_count = 0 + + async def mock_server(): + nonlocal request_count + + # Handle initialization + session_message = await client_to_server_receive.receive() + request_count += 1 + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Handle each method call + while True: + try: + session_message = await client_to_server_receive.receive() + request_count += 1 + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + # Verify no metadata is passed when extra_headers is None + assert session_message.metadata is None + + method = jsonrpc_request.root.method + + # Send appropriate response based on method + if method == "logging/setLevel": + from mcp.types import EmptyResult + + result = ServerResult(EmptyResult()) + elif method == "resources/list": + from mcp.types import ListResourcesResult + + result = ServerResult(ListResourcesResult(resources=[])) + elif method == "resources/templates/list": + from mcp.types import ListResourceTemplatesResult + + result = ServerResult(ListResourceTemplatesResult(resourceTemplates=[])) + elif method == "resources/read": + from mcp.types import ReadResourceResult + + result = ServerResult(ReadResourceResult(contents=[])) + elif method in ["resources/subscribe", "resources/unsubscribe"]: + from mcp.types import EmptyResult + + result = ServerResult(EmptyResult()) + elif method == "prompts/list": + from mcp.types import ListPromptsResult + + result = ServerResult(ListPromptsResult(prompts=[])) + elif method == "prompts/get": + from mcp.types import GetPromptResult + + result = ServerResult(GetPromptResult(messages=[])) + elif method == "tools/list": + from mcp.types import ListToolsResult + + result = ServerResult(ListToolsResult(tools=[])) + else: + continue + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + except anyio.EndOfStream: + break + + server_to_client_send.close() + + async with ( + ClientSession(server_to_client_receive, client_to_server_send) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Test all methods without extra_headers + await session.set_logging_level("info") + await session.list_resources() + await session.list_resource_templates() + from pydantic import AnyUrl + + test_uri = AnyUrl("file://test.txt") + await session.read_resource(test_uri) + await session.subscribe_resource(test_uri) + await session.unsubscribe_resource(test_uri) + await session.list_prompts() + await session.get_prompt("test_prompt") + await session.list_tools() + + +@pytest.mark.anyio +async def test_per_request_headers_take_precedence_over_connection_headers(): + """Test that per-request headers override connection-level headers when passed to metadata.""" + from mcp.types import EmptyResult + + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) + + # Track captured metadata from the session layer + captured_metadata: list[ClientMessageMetadata] = [] + + async def mock_server(): + # Handle initialization + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult( + InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + serverInfo=Implementation(name="mock-server", version="0.1.0"), + ) + ).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + + # Receive initialized notification + await client_to_server_receive.receive() + + # Handle the test request + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert jsonrpc_request.root.method == "logging/setLevel" + + # Capture the metadata that was passed with the request + if isinstance(session_message.metadata, ClientMessageMetadata): + captured_metadata.append(session_message.metadata) + + # Send response + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=ServerResult(EmptyResult()).model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + ) + server_to_client_send.close() + + async with ( + ClientSession(server_to_client_receive, client_to_server_send) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + await session.initialize() + + # Per-request headers that demonstrate the functionality + per_request_headers = { + "Authorization": "Bearer per-request-token", + "X-Request-ID": "req-456", + "X-Environment": "staging", + } + + # Make request with per-request headers + await session.set_logging_level("debug", extra_headers=per_request_headers) + + # Verify metadata was captured and contains our headers + assert len(captured_metadata) == 1 + metadata = captured_metadata[0] + assert metadata is not None + assert isinstance(metadata, ClientMessageMetadata) + assert metadata.extra_headers == per_request_headers