11from __future__ import annotations
22
33import base64
4+ import functools
45import json
56from abc import ABC , abstractmethod
67from collections .abc import AsyncIterator , Sequence
1112from typing import Any
1213
1314import anyio
15+ import httpx
1416from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
17+ from mcp .shared .message import SessionMessage
1518from mcp .types import (
1619 BlobResourceContents ,
1720 EmbeddedResource ,
1821 ImageContent ,
19- JSONRPCMessage ,
2022 LoggingLevel ,
2123 TextContent ,
2224 TextResourceContents ,
@@ -56,8 +58,8 @@ class MCPServer(ABC):
5658 """
5759
5860 _client : ClientSession
59- _read_stream : MemoryObjectReceiveStream [JSONRPCMessage | Exception ]
60- _write_stream : MemoryObjectSendStream [JSONRPCMessage ]
61+ _read_stream : MemoryObjectReceiveStream [SessionMessage | Exception ]
62+ _write_stream : MemoryObjectSendStream [SessionMessage ]
6163 _exit_stack : AsyncExitStack
6264
6365 @abstractmethod
@@ -66,8 +68,8 @@ async def client_streams(
6668 self ,
6769 ) -> AsyncIterator [
6870 tuple [
69- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
70- MemoryObjectSendStream [JSONRPCMessage ],
71+ MemoryObjectReceiveStream [SessionMessage | Exception ],
72+ MemoryObjectSendStream [SessionMessage ],
7173 ]
7274 ]:
7375 """Create the streams for the MCP server."""
@@ -266,8 +268,8 @@ async def client_streams(
266268 self ,
267269 ) -> AsyncIterator [
268270 tuple [
269- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
270- MemoryObjectSendStream [JSONRPCMessage ],
271+ MemoryObjectReceiveStream [SessionMessage | Exception ],
272+ MemoryObjectSendStream [SessionMessage ],
271273 ]
272274 ]:
273275 server = StdioServerParameters (command = self .command , args = list (self .args ), env = self .env , cwd = self .cwd )
@@ -326,6 +328,31 @@ async def main():
326328
327329 These headers will be passed directly to the underlying `httpx.AsyncClient`.
328330 Useful for authentication, custom headers, or other HTTP-specific configurations.
331+
332+ !!! note
333+ You can either pass `headers` or `http_client`, but not both.
334+
335+ See [`MCPServerHTTP.http_client`][pydantic_ai.mcp.MCPServerHTTP.http_client] for more information.
336+ """
337+
338+ http_client : httpx .AsyncClient | None = None
339+ """An `httpx.AsyncClient` to use with the SSE endpoint.
340+
341+ This client may be configured to use customized connection parameters like self-signed certificates.
342+
343+ !!! note
344+ You can either pass `headers` or `http_client`, but not both.
345+
346+ If you want to use both, you can pass the headers to the `http_client` instead:
347+
348+ ```python {py="3.10"}
349+ import httpx
350+
351+ from pydantic_ai.mcp import MCPServerHTTP
352+
353+ http_client = httpx.AsyncClient(headers={'Authorization': 'Bearer ...'})
354+ server = MCPServerHTTP('http://localhost:3001/sse', http_client=http_client)
355+ ```
329356 """
330357
331358 timeout : float = 5
@@ -362,18 +389,33 @@ async def main():
362389 async def client_streams (
363390 self ,
364391 ) -> AsyncIterator [
365- tuple [
366- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
367- MemoryObjectSendStream [JSONRPCMessage ],
368- ]
392+ tuple [MemoryObjectReceiveStream [SessionMessage | Exception ], MemoryObjectSendStream [SessionMessage ]]
369393 ]: # pragma: no cover
370- async with sse_client (
394+ if self .http_client and self .headers :
395+ raise ValueError ('`http_client` is mutually exclusive with `headers`.' )
396+
397+ sse_client_partial = functools .partial (
398+ sse_client ,
371399 url = self .url ,
372- headers = self .headers ,
373400 timeout = self .timeout ,
374401 sse_read_timeout = self .sse_read_timeout ,
375- ) as (read_stream , write_stream ):
376- yield read_stream , write_stream
402+ )
403+
404+ if self .http_client is not None :
405+
406+ def httpx_client_factory (
407+ headers : dict [str , str ] | None = None ,
408+ timeout : httpx .Timeout | None = None ,
409+ auth : httpx .Auth | None = None ,
410+ ) -> httpx .AsyncClient :
411+ assert self .http_client is not None
412+ return self .http_client
413+
414+ async with sse_client_partial (httpx_client_factory = httpx_client_factory ) as (read_stream , write_stream ):
415+ yield read_stream , write_stream
416+ else :
417+ async with sse_client_partial (headers = self .headers ) as (read_stream , write_stream ):
418+ yield read_stream , write_stream
377419
378420 def _get_log_level (self ) -> LoggingLevel | None :
379421 return self .log_level
0 commit comments