@@ -893,6 +893,19 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
893893 assert result .content [0 ].text == "Called test_tool"
894894
895895
896+ @pytest .mark .anyio
897+ async def test_streamablehttp_client_tool_invocation_with_extra_headers (initialized_client_session : ClientSession ):
898+ """Test HTTP POST request with extra headers."""
899+ result = await initialized_client_session .call_tool (
900+ "test_tool" ,
901+ {},
902+ extra_headers = {"X-Custom-Header" : "test-value" },
903+ )
904+ assert len (result .content ) == 1
905+ assert result .content [0 ].type == "text"
906+ assert result .content [0 ].text == "Called test_tool"
907+
908+
896909@pytest .mark .anyio
897910async def test_streamablehttp_client_error_handling (initialized_client_session : ClientSession ):
898911 """Test error handling in client."""
@@ -1106,26 +1119,27 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11061119 await session .list_tools ()
11071120
11081121
1109- @ pytest . mark . anyio
1110- async def test_streamablehttp_client_resumption ( event_server : tuple [ SimpleEventStore , str ]):
1111- """Test client session resumption using sync primitives for reliable coordination."""
1112- _ , server_url = event_server
1122+ async def _setup_resumption_test (
1123+ server_url : str ,
1124+ ) -> tuple [ str | None , str | None , str | int | None , list [ types . ServerNotification ]]:
1125+ """Helper function to set up a resumption test by starting a session and capturing resumption state.
11131126
1114- # Variables to track the state
1115- captured_session_id = None
1116- captured_resumption_token = None
1117- captured_notifications : list [types .ServerNotification ] = []
1118- captured_protocol_version = None
1119- first_notification_received = False
1127+ Returns:
1128+ Tuple of (session_id, resumption_token, protocol_version, notifications)
1129+ """
1130+ captured_session_id = None # pragma: no cover
1131+ captured_resumption_token = None # pragma: no cover
1132+ captured_notifications : list [types .ServerNotification ] = [] # pragma: no cover
1133+ captured_protocol_version = None # pragma: no cover
1134+ first_notification_received = False # pragma: no cover
11201135
11211136 async def message_handler ( # pragma: no branch
11221137 message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1123- ) -> None :
1138+ ) -> None : # pragma: no cover
11241139 if isinstance (message , types .ServerNotification ): # pragma: no branch
11251140 captured_notifications .append (message )
1126- # Look for our first notification
11271141 if isinstance (message .root , types .LoggingMessageNotification ): # pragma: no branch
1128- if message .root .params .data == "First notification before lock" :
1142+ if message .root .params .data == "First notification before lock" : # pragma: no branch
11291143 nonlocal first_notification_received
11301144 first_notification_received = True
11311145
@@ -1181,8 +1195,95 @@ async def run_tool():
11811195 assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification ) # pragma: no cover
11821196 assert captured_notifications [0 ].root .params .data == "First notification before lock" # pragma: no cover
11831197
1184- # Clear notifications for the second phase
1185- captured_notifications = [] # pragma: no cover
1198+ return (
1199+ captured_session_id ,
1200+ captured_resumption_token ,
1201+ captured_protocol_version ,
1202+ captured_notifications ,
1203+ ) # pragma: no cover # noqa: E501
1204+
1205+
1206+ @pytest .mark .anyio
1207+ async def test_streamablehttp_client_resumption (event_server : tuple [SimpleEventStore , str ]):
1208+ """Test client session resumption using sync primitives for reliable coordination."""
1209+ _ , server_url = event_server
1210+
1211+ # Set up the initial session and capture resumption state
1212+ captured_session_id , captured_resumption_token , captured_protocol_version , _ = await _setup_resumption_test (
1213+ server_url
1214+ )
1215+
1216+ # Track notifications for the resumed session
1217+ captured_notifications : list [types .ServerNotification ] = [] # pragma: no cover
1218+
1219+ async def message_handler ( # pragma: no branch
1220+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1221+ ) -> None : # pragma: no cover
1222+ if isinstance (message , types .ServerNotification ): # pragma: no branch
1223+ captured_notifications .append (message )
1224+
1225+ # Now resume the session with the same mcp-session-id and protocol version
1226+ headers : dict [str , Any ] = {} # pragma: no cover
1227+ if captured_session_id : # pragma: no cover
1228+ headers [MCP_SESSION_ID_HEADER ] = captured_session_id
1229+ if captured_protocol_version : # pragma: no cover
1230+ headers [MCP_PROTOCOL_VERSION_HEADER ] = captured_protocol_version
1231+ async with streamablehttp_client (f"{ server_url } /mcp" , headers = headers ) as (
1232+ read_stream ,
1233+ write_stream ,
1234+ _ ,
1235+ ):
1236+ async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1237+ result = await session .send_request (
1238+ types .ClientRequest (
1239+ types .CallToolRequest (
1240+ params = types .CallToolRequestParams (name = "release_lock" , arguments = {}),
1241+ )
1242+ ),
1243+ types .CallToolResult ,
1244+ )
1245+ metadata = ClientMessageMetadata (
1246+ resumption_token = captured_resumption_token ,
1247+ )
1248+
1249+ result = await session .send_request (
1250+ types .ClientRequest (
1251+ types .CallToolRequest (
1252+ params = types .CallToolRequestParams (name = "wait_for_lock_with_notification" , arguments = {}),
1253+ )
1254+ ),
1255+ types .CallToolResult ,
1256+ metadata = metadata ,
1257+ )
1258+ assert len (result .content ) == 1
1259+ assert result .content [0 ].type == "text"
1260+ assert result .content [0 ].text == "Completed"
1261+
1262+ # We should have received the remaining notifications
1263+ assert len (captured_notifications ) == 1
1264+
1265+ assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
1266+ assert captured_notifications [0 ].root .params .data == "Second notification after lock"
1267+
1268+
1269+ @pytest .mark .anyio
1270+ async def test_streamablehttp_client_resumption_with_extra_headers (event_server : tuple [SimpleEventStore , str ]):
1271+ """Test client session resumption with extra headers."""
1272+ _ , server_url = event_server
1273+
1274+ # Set up the initial session and capture resumption state
1275+ captured_session_id , captured_resumption_token , captured_protocol_version , _ = await _setup_resumption_test (
1276+ server_url
1277+ )
1278+
1279+ # Track notifications for the resumed session
1280+ captured_notifications : list [types .ServerNotification ] = [] # pragma: no cover
1281+
1282+ async def message_handler ( # pragma: no branch
1283+ message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
1284+ ) -> None : # pragma: no cover
1285+ if isinstance (message , types .ServerNotification ): # pragma: no branch
1286+ captured_notifications .append (message )
11861287
11871288 # Now resume the session with the same mcp-session-id and protocol version
11881289 headers : dict [str , Any ] = {} # pragma: no cover
@@ -1204,8 +1305,10 @@ async def run_tool():
12041305 ),
12051306 types .CallToolResult ,
12061307 )
1308+ # Test resumption WITH extra_headers
12071309 metadata = ClientMessageMetadata (
12081310 resumption_token = captured_resumption_token ,
1311+ extra_headers = {"X-Resumption-Test" : "test-value" },
12091312 )
12101313
12111314 result = await session .send_request (
0 commit comments