Skip to content

Commit 9cef4ff

Browse files
Add tests and API stubs for SSE polling support (SEP-1699)
Sets up test infrastructure and API surface for SEP-1699 SSE polling. Tests will fail until implementation is complete. New APIs (stubbed): - StreamableHTTPReconnectionOptions dataclass - Server: _create_priming_event(), close_sse_stream(), retry_interval - Client: resume_stream(), _get_next_reconnection_delay() - RequestContext.close_sse_stream callback Github-Issue:#1654
1 parent 5983a65 commit 9cef4ff

File tree

12 files changed

+1051
-53
lines changed

12 files changed

+1051
-53
lines changed

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,29 @@
1414
from mcp.server.fastmcp import Context, FastMCP
1515
from mcp.server.fastmcp.prompts.base import UserMessage
1616
from mcp.server.session import ServerSession
17+
from mcp.server.streamable_http import (
18+
EventCallback,
19+
EventId,
20+
EventMessage,
21+
EventStore,
22+
StreamId,
23+
)
1724
from mcp.types import (
1825
AudioContent,
1926
Completion,
2027
CompletionArgument,
2128
CompletionContext,
2229
EmbeddedResource,
2330
ImageContent,
31+
JSONRPCMessage,
2432
PromptReference,
2533
ResourceTemplateReference,
2634
SamplingMessage,
2735
TextContent,
2836
TextResourceContents,
2937
)
3038
from pydantic import AnyUrl, BaseModel, Field
39+
from starlette.requests import Request
3140

3241
logger = logging.getLogger(__name__)
3342

@@ -39,8 +48,47 @@
3948
resource_subscriptions: set[str] = set()
4049
watched_resource_content = "Watched resource content"
4150

51+
52+
# Simple in-memory event store for SSE polling resumability (SEP-1699)
53+
class SimpleEventStore(EventStore):
54+
"""Simple in-memory event store for testing resumability."""
55+
56+
def __init__(self) -> None:
57+
self._events: list[tuple[StreamId, EventId, JSONRPCMessage]] = []
58+
self._event_id_counter = 0
59+
60+
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
61+
"""Store an event and return its ID."""
62+
self._event_id_counter += 1
63+
event_id = str(self._event_id_counter)
64+
self._events.append((stream_id, event_id, message))
65+
return event_id
66+
67+
async def replay_events_after(
68+
self,
69+
last_event_id: EventId,
70+
send_callback: EventCallback,
71+
) -> StreamId | None:
72+
"""Replay events after the specified ID."""
73+
target_stream_id = None
74+
found = False
75+
for stream_id, event_id, message in self._events:
76+
if event_id == last_event_id:
77+
target_stream_id = stream_id
78+
found = True
79+
continue
80+
if found and stream_id == target_stream_id:
81+
await send_callback(EventMessage(message=message, event_id=event_id))
82+
return target_stream_id
83+
84+
85+
# Create event store for resumability
86+
event_store = SimpleEventStore()
87+
4288
mcp = FastMCP(
4389
name="mcp-conformance-test-server",
90+
event_store=event_store,
91+
sse_retry_interval=3000, # 3 seconds
4492
)
4593

4694

@@ -257,6 +305,33 @@ async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> s
257305
return f"Elicitation not supported or error: {str(e)}"
258306

259307

308+
@mcp.tool()
309+
async def test_reconnection(ctx: Context[ServerSession, None]) -> str:
310+
"""Tests SSE polling via server-initiated disconnect (SEP-1699)
311+
312+
This tool closes the SSE stream mid-call, requiring the client to reconnect
313+
with Last-Event-ID to receive the remaining events.
314+
"""
315+
# Send notification before disconnect
316+
await ctx.info("Notification before disconnect")
317+
318+
# Get session_id from request headers
319+
request = ctx.request_context.request
320+
if isinstance(request, Request):
321+
session_id = request.headers.get("mcp-session-id")
322+
if session_id:
323+
# Trigger server-initiated SSE disconnect
324+
await mcp.session_manager.close_sse_stream(session_id, ctx.request_id)
325+
326+
# Wait for client to reconnect
327+
await asyncio.sleep(0.2)
328+
329+
# Send notification after disconnect (will be replayed via event store)
330+
await ctx.info("Notification after disconnect")
331+
332+
return "Reconnection test completed successfully"
333+
334+
260335
@mcp.tool()
261336
def test_error_handling() -> str:
262337
"""Tests error response handling"""
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
SSE Polling Example Client
3+
4+
Demonstrates client-side behavior during server-initiated SSE disconnect.
5+
6+
Key features:
7+
- Automatic reconnection when server closes SSE stream
8+
- Event replay via Last-Event-ID header (handled internally by the transport)
9+
- Progress notifications via logging callback
10+
11+
This client connects to the SSE polling server and calls the `long-task` tool.
12+
The server disconnects at 50% progress, and the client automatically reconnects
13+
to receive remaining progress updates.
14+
15+
Run:
16+
# First start the server:
17+
uv run examples/snippets/servers/sse_polling_server.py
18+
19+
# Then run this client:
20+
uv run examples/snippets/clients/sse_polling_client.py
21+
"""
22+
23+
import asyncio
24+
import logging
25+
26+
from mcp import ClientSession
27+
from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client
28+
from mcp.types import LoggingMessageNotificationParams, TextContent
29+
30+
logging.basicConfig(
31+
level=logging.INFO,
32+
format="%(asctime)s - %(levelname)s - %(message)s",
33+
)
34+
logger = logging.getLogger(__name__)
35+
36+
37+
async def main() -> None:
38+
print("SSE Polling Example Client")
39+
print("=" * 50)
40+
print()
41+
42+
# Track notifications received via the logging callback
43+
notifications_received: list[str] = []
44+
45+
async def logging_callback(params: LoggingMessageNotificationParams) -> None:
46+
"""Called when a log message notification is received from the server."""
47+
data = params.data
48+
if data:
49+
data_str = str(data)
50+
notifications_received.append(data_str)
51+
print(f"[Progress] {data_str}")
52+
53+
# Configure reconnection behavior
54+
reconnection_options = StreamableHTTPReconnectionOptions(
55+
initial_reconnection_delay=1.0, # Start with 1 second
56+
max_reconnection_delay=30.0, # Cap at 30 seconds
57+
reconnection_delay_grow_factor=1.5, # Exponential backoff
58+
max_retries=5, # Try up to 5 times
59+
)
60+
61+
print("[Client] Connecting to server...")
62+
63+
async with streamablehttp_client(
64+
"http://localhost:3001/mcp",
65+
reconnection_options=reconnection_options,
66+
) as (read_stream, write_stream, get_session_id):
67+
# Create session with logging callback to receive progress notifications
68+
async with ClientSession(
69+
read_stream,
70+
write_stream,
71+
logging_callback=logging_callback,
72+
) as session:
73+
# Initialize the session
74+
await session.initialize()
75+
session_id = get_session_id()
76+
print(f"[Client] Connected! Session ID: {session_id}")
77+
78+
# List available tools
79+
tools = await session.list_tools()
80+
tool_names = [t.name for t in tools.tools]
81+
print(f"[Client] Available tools: {tool_names}")
82+
print()
83+
84+
# Call the long-running task
85+
print("[Client] Calling long-task tool...")
86+
print("[Client] The server will disconnect at 50% and we'll auto-reconnect")
87+
print()
88+
89+
# Call the tool
90+
result = await session.call_tool("long-task", {})
91+
92+
print()
93+
print("[Client] Task completed!")
94+
if result.content and isinstance(result.content[0], TextContent):
95+
print(f"[Result] {result.content[0].text}")
96+
else:
97+
print("[Result] No content")
98+
print()
99+
print(f"[Summary] Received {len(notifications_received)} progress notifications")
100+
101+
102+
if __name__ == "__main__":
103+
asyncio.run(main())

0 commit comments

Comments
 (0)