Skip to content

Commit 9059c04

Browse files
Implement SSE polling resumability (SEP-1699)
- Add client auto-reconnection with configurable backoff options - Add close_sse_stream API for server-initiated disconnects - Add priming events and event store support for resumability - Add sse_retry_interval setting to FastMCP - Add test coverage for SSE polling reconnection
1 parent 091afb8 commit 9059c04

File tree

31 files changed

+3020
-82
lines changed

31 files changed

+3020
-82
lines changed

README.md

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,10 +808,21 @@ Request additional information from users. This example shows an Elicitation dur
808808

809809
<!-- snippet-source examples/snippets/servers/elicitation.py -->
810810
```python
811+
"""Elicitation examples demonstrating form and URL mode elicitation.
812+
813+
Form mode elicitation collects structured, non-sensitive data through a schema.
814+
URL mode elicitation directs users to external URLs for sensitive operations
815+
like OAuth flows, credential collection, or payment processing.
816+
"""
817+
818+
import uuid
819+
811820
from pydantic import BaseModel, Field
812821

813822
from mcp.server.fastmcp import Context, FastMCP
814823
from mcp.server.session import ServerSession
824+
from mcp.shared.exceptions import UrlElicitationRequiredError
825+
from mcp.types import ElicitRequestURLParams
815826

816827
mcp = FastMCP(name="Elicitation Example")
817828

@@ -828,7 +839,10 @@ class BookingPreferences(BaseModel):
828839

829840
@mcp.tool()
830841
async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerSession, None]) -> str:
831-
"""Book a table with date availability check."""
842+
"""Book a table with date availability check.
843+
844+
This demonstrates form mode elicitation for collecting non-sensitive user input.
845+
"""
832846
# Check if date is available
833847
if date == "2024-12-25":
834848
# Date unavailable - ask user for alternative
@@ -845,6 +859,54 @@ async def book_table(date: str, time: str, party_size: int, ctx: Context[ServerS
845859

846860
# Date available
847861
return f"[SUCCESS] Booked for {date} at {time}"
862+
863+
864+
@mcp.tool()
865+
async def secure_payment(amount: float, ctx: Context[ServerSession, None]) -> str:
866+
"""Process a secure payment requiring URL confirmation.
867+
868+
This demonstrates URL mode elicitation using ctx.elicit_url() for
869+
operations that require out-of-band user interaction.
870+
"""
871+
elicitation_id = str(uuid.uuid4())
872+
873+
result = await ctx.elicit_url(
874+
message=f"Please confirm payment of ${amount:.2f}",
875+
url=f"https://payments.example.com/confirm?amount={amount}&id={elicitation_id}",
876+
elicitation_id=elicitation_id,
877+
)
878+
879+
if result.action == "accept":
880+
# In a real app, the payment confirmation would happen out-of-band
881+
# and you'd verify the payment status from your backend
882+
return f"Payment of ${amount:.2f} initiated - check your browser to complete"
883+
elif result.action == "decline":
884+
return "Payment declined by user"
885+
return "Payment cancelled"
886+
887+
888+
@mcp.tool()
889+
async def connect_service(service_name: str, ctx: Context[ServerSession, None]) -> str:
890+
"""Connect to a third-party service requiring OAuth authorization.
891+
892+
This demonstrates the "throw error" pattern using UrlElicitationRequiredError.
893+
Use this pattern when the tool cannot proceed without user authorization.
894+
"""
895+
elicitation_id = str(uuid.uuid4())
896+
897+
# Raise UrlElicitationRequiredError to signal that the client must complete
898+
# a URL elicitation before this request can be processed.
899+
# The MCP framework will convert this to a -32042 error response.
900+
raise UrlElicitationRequiredError(
901+
[
902+
ElicitRequestURLParams(
903+
mode="url",
904+
message=f"Authorization required to connect to {service_name}",
905+
url=f"https://{service_name}.example.com/oauth/authorize?elicit={elicitation_id}",
906+
elicitationId=elicitation_id,
907+
)
908+
]
909+
)
848910
```
849911

850912
_Full example: [examples/snippets/servers/elicitation.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/examples/snippets/servers/elicitation.py)_

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,15 @@ def get_state(self):
150150
class SimpleAuthClient:
151151
"""Simple MCP client with auth support."""
152152

153-
def __init__(self, server_url: str, transport_type: str = "streamable-http"):
153+
def __init__(
154+
self,
155+
server_url: str,
156+
transport_type: str = "streamable-http",
157+
client_metadata_url: str | None = None,
158+
):
154159
self.server_url = server_url
155160
self.transport_type = transport_type
161+
self.client_metadata_url = client_metadata_url
156162
self.session: ClientSession | None = None
157163

158164
async def connect(self):
@@ -185,12 +191,14 @@ async def _default_redirect_handler(authorization_url: str) -> None:
185191
webbrowser.open(authorization_url)
186192

187193
# Create OAuth authentication handler using the new interface
194+
# Use client_metadata_url to enable CIMD when the server supports it
188195
oauth_auth = OAuthClientProvider(
189196
server_url=self.server_url,
190197
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
191198
storage=InMemoryTokenStorage(),
192199
redirect_handler=_default_redirect_handler,
193200
callback_handler=callback_handler,
201+
client_metadata_url=self.client_metadata_url,
194202
)
195203

196204
# Create transport with auth handler based on transport type
@@ -334,6 +342,7 @@ async def main():
334342
# Most MCP streamable HTTP servers use /mcp as the endpoint
335343
server_url = os.getenv("MCP_SERVER_PORT", 8000)
336344
transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable-http")
345+
client_metadata_url = os.getenv("MCP_CLIENT_METADATA_URL")
337346
server_url = (
338347
f"http://localhost:{server_url}/mcp"
339348
if transport_type == "streamable-http"
@@ -343,9 +352,11 @@ async def main():
343352
print("🚀 Simple MCP Auth Client")
344353
print(f"Connecting to: {server_url}")
345354
print(f"Transport type: {transport_type}")
355+
if client_metadata_url:
356+
print(f"Client metadata URL: {client_metadata_url}")
346357

347358
# Start connection flow - OAuth will be handled automatically
348-
client = SimpleAuthClient(server_url, transport_type)
359+
client = SimpleAuthClient(server_url, transport_type, client_metadata_url)
349360
await client.connect()
350361

351362

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,21 @@
1414
from mcp.server.fastmcp import Context, FastMCP
1515
from mcp.server.fastmcp.prompts.base import UserMessage
1616
from mcp.server.session import ServerSession
17+
from mcp.server.streamable_http import (
18+
EventCallback,
19+
EventId,
20+
EventMessage,
21+
EventStore,
22+
StreamId,
23+
)
1724
from mcp.types import (
1825
AudioContent,
1926
Completion,
2027
CompletionArgument,
2128
CompletionContext,
2229
EmbeddedResource,
2330
ImageContent,
31+
JSONRPCMessage,
2432
PromptReference,
2533
ResourceTemplateReference,
2634
SamplingMessage,
@@ -39,8 +47,47 @@
3947
resource_subscriptions: set[str] = set()
4048
watched_resource_content = "Watched resource content"
4149

50+
51+
# Simple in-memory event store for SSE polling resumability (SEP-1699)
52+
class SimpleEventStore(EventStore):
53+
"""Simple in-memory event store for testing resumability."""
54+
55+
def __init__(self) -> None:
56+
self._events: list[tuple[StreamId, EventId, JSONRPCMessage]] = []
57+
self._event_id_counter = 0
58+
59+
async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
60+
"""Store an event and return its ID."""
61+
self._event_id_counter += 1
62+
event_id = str(self._event_id_counter)
63+
self._events.append((stream_id, event_id, message))
64+
return event_id
65+
66+
async def replay_events_after(
67+
self,
68+
last_event_id: EventId,
69+
send_callback: EventCallback,
70+
) -> StreamId | None:
71+
"""Replay events after the specified ID."""
72+
target_stream_id = None
73+
found = False
74+
for stream_id, event_id, message in self._events:
75+
if event_id == last_event_id:
76+
target_stream_id = stream_id
77+
found = True
78+
continue
79+
if found and stream_id == target_stream_id:
80+
await send_callback(EventMessage(message=message, event_id=event_id))
81+
return target_stream_id
82+
83+
84+
# Create event store for resumability
85+
event_store = SimpleEventStore()
86+
4287
mcp = FastMCP(
4388
name="mcp-conformance-test-server",
89+
event_store=event_store,
90+
sse_retry_interval=3000, # 3 seconds
4491
)
4592

4693

@@ -257,6 +304,31 @@ async def test_elicitation_sep1330_enums(ctx: Context[ServerSession, None]) -> s
257304
return f"Elicitation not supported or error: {str(e)}"
258305

259306

307+
@mcp.tool()
308+
async def test_reconnection(ctx: Context[ServerSession, None]) -> str:
309+
"""Tests SSE polling via server-initiated disconnect (SEP-1699)
310+
311+
This tool closes the SSE stream mid-call, requiring the client to reconnect
312+
with Last-Event-ID to receive the remaining events.
313+
"""
314+
# Send notification before disconnect
315+
await ctx.info("Notification before disconnect")
316+
317+
# Use the close_sse_stream callback if available
318+
# This is None if not on streamable HTTP transport or no event store configured
319+
if ctx.close_sse_stream:
320+
# Trigger server-initiated SSE disconnect with optional retry interval
321+
await ctx.close_sse_stream(retry_interval=3000) # 3 seconds
322+
323+
# Wait for client to reconnect
324+
await asyncio.sleep(0.2)
325+
326+
# Send notification after disconnect (will be replayed via event store)
327+
await ctx.info("Notification after disconnect")
328+
329+
return "Reconnection test completed successfully"
330+
331+
260332
@mcp.tool()
261333
def test_error_handling() -> str:
262334
"""Tests error response handling"""
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
SSE Polling Example Client
3+
4+
Demonstrates client-side behavior during server-initiated SSE disconnect.
5+
6+
Key features:
7+
- Automatic reconnection when server closes SSE stream
8+
- Event replay via Last-Event-ID header (handled internally by the transport)
9+
- Progress notifications via logging callback
10+
11+
This client connects to the SSE polling server and calls the `long-task` tool.
12+
The server disconnects at 50% progress, and the client automatically reconnects
13+
to receive remaining progress updates.
14+
15+
Run:
16+
# First start the server:
17+
uv run examples/snippets/servers/sse_polling_server.py
18+
19+
# Then run this client:
20+
uv run examples/snippets/clients/sse_polling_client.py
21+
"""
22+
23+
import asyncio
24+
import logging
25+
26+
from mcp import ClientSession
27+
from mcp.client.streamable_http import StreamableHTTPReconnectionOptions, streamablehttp_client
28+
from mcp.types import LoggingMessageNotificationParams, TextContent
29+
30+
logging.basicConfig(
31+
level=logging.INFO,
32+
format="%(asctime)s - %(levelname)s - %(message)s",
33+
)
34+
logger = logging.getLogger(__name__)
35+
36+
37+
async def main() -> None:
38+
print("SSE Polling Example Client")
39+
print("=" * 50)
40+
print()
41+
42+
# Track notifications received via the logging callback
43+
notifications_received: list[str] = []
44+
45+
async def logging_callback(params: LoggingMessageNotificationParams) -> None:
46+
"""Called when a log message notification is received from the server."""
47+
data = params.data
48+
if data:
49+
data_str = str(data)
50+
notifications_received.append(data_str)
51+
print(f"[Progress] {data_str}")
52+
53+
# Configure reconnection behavior
54+
reconnection_options = StreamableHTTPReconnectionOptions(
55+
initial_reconnection_delay=1.0, # Start with 1 second
56+
max_reconnection_delay=30.0, # Cap at 30 seconds
57+
reconnection_delay_grow_factor=1.5, # Exponential backoff
58+
max_retries=5, # Try up to 5 times
59+
)
60+
61+
print("[Client] Connecting to server...")
62+
63+
async with streamablehttp_client(
64+
"http://localhost:3001/mcp",
65+
reconnection_options=reconnection_options,
66+
) as (read_stream, write_stream, get_session_id):
67+
# Create session with logging callback to receive progress notifications
68+
async with ClientSession(
69+
read_stream,
70+
write_stream,
71+
logging_callback=logging_callback,
72+
) as session:
73+
# Initialize the session
74+
await session.initialize()
75+
session_id = get_session_id()
76+
print(f"[Client] Connected! Session ID: {session_id}")
77+
78+
# List available tools
79+
tools = await session.list_tools()
80+
tool_names = [t.name for t in tools.tools]
81+
print(f"[Client] Available tools: {tool_names}")
82+
print()
83+
84+
# Call the long-running task
85+
print("[Client] Calling long-task tool...")
86+
print("[Client] The server will disconnect at 50% and we'll auto-reconnect")
87+
print()
88+
89+
# Call the tool
90+
result = await session.call_tool("long-task", {})
91+
92+
print()
93+
print("[Client] Task completed!")
94+
if result.content and isinstance(result.content[0], TextContent):
95+
print(f"[Result] {result.content[0].text}")
96+
else:
97+
print("[Result] No content")
98+
print()
99+
print(f"[Summary] Received {len(notifications_received)} progress notifications")
100+
101+
102+
if __name__ == "__main__":
103+
asyncio.run(main())

0 commit comments

Comments
 (0)