Skip to content

Commit 0e86471

Browse files
committed
test: add coverage for extra_headers in HTTP transport
- Add test_streamablehttp_client_tool_invocation_with_extra_headers for POST requests - Add test_streamablehttp_client_resumption_with_extra_headers for resumption with extra headers - Refactor common resumption setup code into _setup_resumption_test helper - Achieve 100% coverage for streamable_http.py
1 parent 876e13a commit 0e86471

File tree

1 file changed

+118
-15
lines changed

1 file changed

+118
-15
lines changed

tests/shared/test_streamable_http.py

Lines changed: 118 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,19 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session:
893893
assert result.content[0].text == "Called test_tool"
894894

895895

896+
@pytest.mark.anyio
897+
async def test_streamablehttp_client_tool_invocation_with_extra_headers(initialized_client_session: ClientSession):
898+
"""Test HTTP POST request with extra headers."""
899+
result = await initialized_client_session.call_tool(
900+
"test_tool",
901+
{},
902+
extra_headers={"X-Custom-Header": "test-value"},
903+
)
904+
assert len(result.content) == 1
905+
assert result.content[0].type == "text"
906+
assert result.content[0].text == "Called test_tool"
907+
908+
896909
@pytest.mark.anyio
897910
async def test_streamablehttp_client_error_handling(initialized_client_session: ClientSession):
898911
"""Test error handling in client."""
@@ -1106,26 +1119,27 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt
11061119
await session.list_tools()
11071120

11081121

1109-
@pytest.mark.anyio
1110-
async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]):
1111-
"""Test client session resumption using sync primitives for reliable coordination."""
1112-
_, server_url = event_server
1122+
async def _setup_resumption_test(
1123+
server_url: str,
1124+
) -> tuple[str | None, str | None, str | int | None, list[types.ServerNotification]]:
1125+
"""Helper function to set up a resumption test by starting a session and capturing resumption state.
11131126
1114-
# Variables to track the state
1115-
captured_session_id = None
1116-
captured_resumption_token = None
1117-
captured_notifications: list[types.ServerNotification] = []
1118-
captured_protocol_version = None
1119-
first_notification_received = False
1127+
Returns:
1128+
Tuple of (session_id, resumption_token, protocol_version, notifications)
1129+
"""
1130+
captured_session_id = None # pragma: no cover
1131+
captured_resumption_token = None # pragma: no cover
1132+
captured_notifications: list[types.ServerNotification] = [] # pragma: no cover
1133+
captured_protocol_version = None # pragma: no cover
1134+
first_notification_received = False # pragma: no cover
11201135

11211136
async def message_handler( # pragma: no branch
11221137
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1123-
) -> None:
1138+
) -> None: # pragma: no cover
11241139
if isinstance(message, types.ServerNotification): # pragma: no branch
11251140
captured_notifications.append(message)
1126-
# Look for our first notification
11271141
if isinstance(message.root, types.LoggingMessageNotification): # pragma: no branch
1128-
if message.root.params.data == "First notification before lock":
1142+
if message.root.params.data == "First notification before lock": # pragma: no branch
11291143
nonlocal first_notification_received
11301144
first_notification_received = True
11311145

@@ -1181,8 +1195,95 @@ async def run_tool():
11811195
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification) # pragma: no cover
11821196
assert captured_notifications[0].root.params.data == "First notification before lock" # pragma: no cover
11831197

1184-
# Clear notifications for the second phase
1185-
captured_notifications = [] # pragma: no cover
1198+
return (
1199+
captured_session_id,
1200+
captured_resumption_token,
1201+
captured_protocol_version,
1202+
captured_notifications,
1203+
) # pragma: no cover # noqa: E501
1204+
1205+
1206+
@pytest.mark.anyio
1207+
async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]):
1208+
"""Test client session resumption using sync primitives for reliable coordination."""
1209+
_, server_url = event_server
1210+
1211+
# Set up the initial session and capture resumption state
1212+
captured_session_id, captured_resumption_token, captured_protocol_version, _ = await _setup_resumption_test(
1213+
server_url
1214+
)
1215+
1216+
# Track notifications for the resumed session
1217+
captured_notifications: list[types.ServerNotification] = [] # pragma: no cover
1218+
1219+
async def message_handler( # pragma: no branch
1220+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1221+
) -> None: # pragma: no cover
1222+
if isinstance(message, types.ServerNotification): # pragma: no branch
1223+
captured_notifications.append(message)
1224+
1225+
# Now resume the session with the same mcp-session-id and protocol version
1226+
headers: dict[str, Any] = {} # pragma: no cover
1227+
if captured_session_id: # pragma: no cover
1228+
headers[MCP_SESSION_ID_HEADER] = captured_session_id
1229+
if captured_protocol_version: # pragma: no cover
1230+
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version
1231+
async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
1232+
read_stream,
1233+
write_stream,
1234+
_,
1235+
):
1236+
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
1237+
result = await session.send_request(
1238+
types.ClientRequest(
1239+
types.CallToolRequest(
1240+
params=types.CallToolRequestParams(name="release_lock", arguments={}),
1241+
)
1242+
),
1243+
types.CallToolResult,
1244+
)
1245+
metadata = ClientMessageMetadata(
1246+
resumption_token=captured_resumption_token,
1247+
)
1248+
1249+
result = await session.send_request(
1250+
types.ClientRequest(
1251+
types.CallToolRequest(
1252+
params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}),
1253+
)
1254+
),
1255+
types.CallToolResult,
1256+
metadata=metadata,
1257+
)
1258+
assert len(result.content) == 1
1259+
assert result.content[0].type == "text"
1260+
assert result.content[0].text == "Completed"
1261+
1262+
# We should have received the remaining notifications
1263+
assert len(captured_notifications) == 1
1264+
1265+
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
1266+
assert captured_notifications[0].root.params.data == "Second notification after lock"
1267+
1268+
1269+
@pytest.mark.anyio
1270+
async def test_streamablehttp_client_resumption_with_extra_headers(event_server: tuple[SimpleEventStore, str]):
1271+
"""Test client session resumption with extra headers."""
1272+
_, server_url = event_server
1273+
1274+
# Set up the initial session and capture resumption state
1275+
captured_session_id, captured_resumption_token, captured_protocol_version, _ = await _setup_resumption_test(
1276+
server_url
1277+
)
1278+
1279+
# Track notifications for the resumed session
1280+
captured_notifications: list[types.ServerNotification] = [] # pragma: no cover
1281+
1282+
async def message_handler( # pragma: no branch
1283+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
1284+
) -> None: # pragma: no cover
1285+
if isinstance(message, types.ServerNotification): # pragma: no branch
1286+
captured_notifications.append(message)
11861287

11871288
# Now resume the session with the same mcp-session-id and protocol version
11881289
headers: dict[str, Any] = {} # pragma: no cover
@@ -1204,8 +1305,10 @@ async def run_tool():
12041305
),
12051306
types.CallToolResult,
12061307
)
1308+
# Test resumption WITH extra_headers
12071309
metadata = ClientMessageMetadata(
12081310
resumption_token=captured_resumption_token,
1311+
extra_headers={"X-Resumption-Test": "test-value"},
12091312
)
12101313

12111314
result = await session.send_request(

0 commit comments

Comments
 (0)