Skip to content

Commit 21db2a6

Browse files
Implement close_sse_stream and client auto-reconnect (SEP-1699)
Server now supports closing SSE streams mid-operation via close_sse_stream(), which triggers client reconnection. Client automatically reconnects when the stream closes after receiving a priming event. Changes: - Server transport: Implement close_sse_stream() to close SSE writer - Server transport: Create callback and pass via ServerMessageMetadata - Lowlevel server: Thread close_sse_stream callback to RequestContext - FastMCP Context: Wire close_sse_stream() to call the callback - Client: Track priming events and auto-reconnect with Last-Event-ID
1 parent 1dfe97e commit 21db2a6

File tree

4 files changed

+94
-11
lines changed

4 files changed

+94
-11
lines changed

src/mcp/client/streamable_http.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,15 @@ async def _handle_sse_response(
329329
is_initialization: bool = False,
330330
) -> None:
331331
"""Handle SSE response from the server."""
332+
last_event_id: str | None = None
333+
332334
try:
333335
event_source = EventSource(response)
334336
async for sse in event_source.aiter_sse(): # pragma: no branch
337+
# Track last event ID for potential reconnection
338+
if sse.id:
339+
last_event_id = sse.id
340+
335341
is_complete = await self._handle_sse_event(
336342
sse,
337343
ctx.read_stream_writer,
@@ -342,10 +348,63 @@ async def _handle_sse_response(
342348
# break the loop
343349
if is_complete:
344350
await response.aclose()
345-
break
351+
return # Normal completion, no reconnect needed
352+
except Exception as e:
353+
logger.debug(f"SSE stream ended: {e}")
354+
355+
# Stream ended without response - reconnect if we have priming event
356+
if last_event_id is not None:
357+
await self._handle_reconnection(ctx, last_event_id)
358+
359+
async def _handle_reconnection(
360+
self,
361+
ctx: RequestContext,
362+
last_event_id: str,
363+
) -> None:
364+
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
365+
headers = self._prepare_request_headers(ctx.headers)
366+
headers[LAST_EVENT_ID] = last_event_id
367+
368+
# Extract original request ID to map responses
369+
original_request_id = None
370+
if isinstance(ctx.session_message.message.root, JSONRPCRequest):
371+
original_request_id = ctx.session_message.message.root.id
372+
373+
try:
374+
async with aconnect_sse(
375+
ctx.client,
376+
"GET",
377+
self.url,
378+
headers=headers,
379+
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
380+
) as event_source:
381+
event_source.response.raise_for_status()
382+
logger.debug("Reconnection GET SSE connection established")
383+
384+
# Track for potential further reconnection
385+
reconnect_last_event_id: str | None = last_event_id
386+
387+
async for sse in event_source.aiter_sse():
388+
if sse.id:
389+
reconnect_last_event_id = sse.id
390+
391+
is_complete = await self._handle_sse_event(
392+
sse,
393+
ctx.read_stream_writer,
394+
original_request_id,
395+
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
396+
)
397+
if is_complete:
398+
await event_source.response.aclose()
399+
return
400+
401+
# Stream ended again without response - reconnect again
402+
if reconnect_last_event_id is not None:
403+
await self._handle_reconnection(ctx, reconnect_last_event_id)
346404
except Exception as e:
347-
logger.exception("Error reading SSE stream:") # pragma: no cover
348-
await ctx.read_stream_writer.send(e) # pragma: no cover
405+
logger.debug(f"Reconnection failed: {e}")
406+
# Try to reconnect again if we still have an event ID
407+
await self._handle_reconnection(ctx, last_event_id)
349408

350409
async def _handle_unexpected_content_type(
351410
self,

src/mcp/server/fastmcp/server.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,11 +1295,9 @@ async def close_sse_stream(self) -> None:
12951295
Note:
12961296
This is a no-op if not using StreamableHTTP transport with event_store.
12971297
The callback is only available when event_store is configured.
1298-
1299-
Raises:
1300-
NotImplementedError: Feature not yet implemented.
13011298
"""
1302-
raise NotImplementedError("close_sse_stream not yet implemented")
1299+
if self._request_context and self._request_context.close_sse_stream:
1300+
await self._request_context.close_sse_stream()
13031301

13041302
# Convenience methods for common log levels
13051303
async def debug(self, message: str, **extra: Any) -> None:

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,12 +680,14 @@ async def _handle_request(
680680

681681
token = None
682682
try:
683-
# Extract request context from message metadata
683+
# Extract request context and close_sse_stream from message metadata
684684
request_data = None
685+
close_sse_stream_cb = None
685686
if message.message_metadata is not None and isinstance(
686687
message.message_metadata, ServerMessageMetadata
687688
): # pragma: no cover
688689
request_data = message.message_metadata.request_context
690+
close_sse_stream_cb = message.message_metadata.close_sse_stream
689691

690692
# Set our global state that can be retrieved via
691693
# app.get_request_context()
@@ -696,6 +698,7 @@ async def _handle_request(
696698
session,
697699
lifespan_context,
698700
request=request_data,
701+
close_sse_stream=close_sse_stream_cb,
699702
)
700703
)
701704
response = await handler(req)

src/mcp/server/streamable_http.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def __init__(
177177
MemoryObjectReceiveStream[EventMessage],
178178
],
179179
] = {}
180+
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
180181
self._terminated = False
181182

182183
@property
@@ -202,7 +203,26 @@ def close_sse_stream(self, request_id: RequestId) -> None:
202203
Requires event_store to be configured for events to be stored during
203204
the disconnect.
204205
"""
205-
raise NotImplementedError("close_sse_stream not yet implemented")
206+
writer = self._sse_stream_writers.pop(request_id, None)
207+
if writer:
208+
writer.close()
209+
210+
def _create_session_message(
211+
self,
212+
message: JSONRPCMessage,
213+
request: Request,
214+
request_id: RequestId,
215+
) -> SessionMessage:
216+
"""Create a session message with metadata including close_sse_stream callback."""
217+
218+
async def close_stream_callback() -> None:
219+
self.close_sse_stream(request_id)
220+
221+
metadata = ServerMessageMetadata(
222+
request_context=request,
223+
close_sse_stream=close_stream_callback,
224+
)
225+
return SessionMessage(message, metadata=metadata)
206226

207227
def _create_error_response(
208228
self,
@@ -485,6 +505,9 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
485505
# Create SSE stream
486506
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
487507

508+
# Store writer reference so close_sse_stream() can close it
509+
self._sse_stream_writers[request_id] = sse_stream_writer
510+
488511
async def sse_writer():
489512
# Get the request ID from the incoming request message
490513
try:
@@ -516,6 +539,7 @@ async def sse_writer():
516539
logger.exception("Error in SSE writer")
517540
finally:
518541
logger.debug("Closing SSE writer")
542+
self._sse_stream_writers.pop(request_id, None)
519543
await self._clean_up_memory_streams(request_id)
520544

521545
# Create and start EventSourceResponse
@@ -539,8 +563,7 @@ async def sse_writer():
539563
async with anyio.create_task_group() as tg:
540564
tg.start_soon(response, scope, receive, send)
541565
# Then send the message to be processed by the server
542-
metadata = ServerMessageMetadata(request_context=request)
543-
session_message = SessionMessage(message, metadata=metadata)
566+
session_message = self._create_session_message(message, request, request_id)
544567
await writer.send(session_message)
545568
except Exception:
546569
logger.exception("SSE response error")

0 commit comments

Comments
 (0)