@@ -98,32 +98,33 @@ async def replay_events_after(
9898 send_callback : EventCallback ,
9999 ) -> StreamId | None :
100100 """Replay events after the specified ID."""
101- # Find the index of the last event ID
102- start_index = None
103- for i , ( _ , event_id , _ ) in enumerate ( self ._events ) :
101+ # Find the stream ID of the last event
102+ target_stream_id = None
103+ for stream_id , event_id , _ in self ._events :
104104 if event_id == last_event_id :
105- start_index = i + 1
105+ target_stream_id = stream_id
106106 break
107107
108- if start_index is None :
109- # If event ID not found, start from beginning
110- start_index = 0
108+ if target_stream_id is None :
109+ # If event ID not found, return None
110+ return None
111111
112- stream_id = None
113- # Replay events
114- for _ , event_id , message in self ._events [start_index :]:
115- await send_callback (EventMessage (message , event_id ))
116- # Capture the stream ID from the first replayed event
117- if stream_id is None and len (self ._events ) > start_index :
118- stream_id = self ._events [start_index ][0 ]
112+ # Convert last_event_id to int for comparison
113+ last_event_id_int = int (last_event_id )
119114
120- return stream_id
115+ # Replay only events from the same stream with ID > last_event_id
116+ for stream_id , event_id , message in self ._events :
117+ if stream_id == target_stream_id and int (event_id ) > last_event_id_int :
118+ await send_callback (EventMessage (message , event_id ))
119+
120+ return target_stream_id
121121
122122
123123# Test server implementation that follows MCP protocol
124124class ServerTest (Server ):
125125 def __init__ (self ):
126126 super ().__init__ (SERVER_NAME )
127+ self ._lock = None # Will be initialized in async context
127128
128129 @self .read_resource ()
129130 async def handle_read_resource (uri : AnyUrl ) -> str | bytes :
@@ -159,6 +160,16 @@ async def handle_list_tools() -> list[Tool]:
159160 description = "A tool that triggers server-side sampling" ,
160161 inputSchema = {"type" : "object" , "properties" : {}},
161162 ),
163+ Tool (
164+ name = "wait_for_lock_with_notification" ,
165+ description = "A tool that sends a notification and waits for lock" ,
166+ inputSchema = {"type" : "object" , "properties" : {}},
167+ ),
168+ Tool (
169+ name = "release_lock" ,
170+ description = "A tool that releases the lock" ,
171+ inputSchema = {"type" : "object" , "properties" : {}},
172+ ),
162173 ]
163174
164175 @self .call_tool ()
@@ -214,6 +225,39 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
214225 )
215226 ]
216227
228+ elif name == "wait_for_lock_with_notification" :
229+ # Initialize lock if not already done
230+ if self ._lock is None :
231+ self ._lock = anyio .Event ()
232+
233+ # First send a notification
234+ await ctx .session .send_log_message (
235+ level = "info" ,
236+ data = "First notification before lock" ,
237+ logger = "lock_tool" ,
238+ related_request_id = ctx .request_id ,
239+ )
240+
241+ # Now wait for the lock to be released
242+ await self ._lock .wait ()
243+
244+ # Send second notification after lock is released
245+ await ctx .session .send_log_message (
246+ level = "info" ,
247+ data = "Second notification after lock" ,
248+ logger = "lock_tool" ,
249+ related_request_id = ctx .request_id ,
250+ )
251+
252+ return [TextContent (type = "text" , text = "Completed" )]
253+
254+ elif name == "release_lock" :
255+ assert self ._lock is not None , "Lock must be initialized before releasing"
256+
257+ # Release the lock
258+ self ._lock .set ()
259+ return [TextContent (type = "text" , text = "Lock released" )]
260+
217261 return [TextContent (type = "text" , text = f"Called { name } " )]
218262
219263
@@ -825,7 +869,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
825869 """Test client tool invocation."""
826870 # First list tools
827871 tools = await initialized_client_session .list_tools ()
828- assert len (tools .tools ) == 4
872+ assert len (tools .tools ) == 6
829873 assert tools .tools [0 ].name == "test_tool"
830874
831875 # Call the tool
@@ -862,7 +906,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser
862906
863907 # Make multiple requests to verify session persistence
864908 tools = await session .list_tools ()
865- assert len (tools .tools ) == 4
909+ assert len (tools .tools ) == 6
866910
867911 # Read a resource
868912 resource = await session .read_resource (uri = AnyUrl ("foobar://test-persist" ))
@@ -891,7 +935,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se
891935
892936 # Check tool listing
893937 tools = await session .list_tools ()
894- assert len (tools .tools ) == 4
938+ assert len (tools .tools ) == 6
895939
896940 # Call a tool and verify JSON response handling
897941 result = await session .call_tool ("test_tool" , {})
@@ -962,7 +1006,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser
9621006
9631007 # Make a request to confirm session is working
9641008 tools = await session .list_tools ()
965- assert len (tools .tools ) == 4
1009+ assert len (tools .tools ) == 6
9661010
9671011 headers = {}
9681012 if captured_session_id :
@@ -1026,7 +1070,7 @@ async def mock_delete(self, *args, **kwargs):
10261070
10271071 # Make a request to confirm session is working
10281072 tools = await session .list_tools ()
1029- assert len (tools .tools ) == 4
1073+ assert len (tools .tools ) == 6
10301074
10311075 headers = {}
10321076 if captured_session_id :
@@ -1048,32 +1092,32 @@ async def mock_delete(self, *args, **kwargs):
10481092
10491093@pytest .mark .anyio
10501094async def test_streamablehttp_client_resumption (event_server ):
1051- """Test client session to resume a long running tool ."""
1095+ """Test client session resumption using sync primitives for reliable coordination ."""
10521096 _ , server_url = event_server
10531097
10541098 # Variables to track the state
10551099 captured_session_id = None
10561100 captured_resumption_token = None
10571101 captured_notifications = []
1058- tool_started = False
10591102 captured_protocol_version = None
1103+ first_notification_received = False
10601104
10611105 async def message_handler (
10621106 message : RequestResponder [types .ServerRequest , types .ClientResult ] | types .ServerNotification | Exception ,
10631107 ) -> None :
10641108 if isinstance (message , types .ServerNotification ):
10651109 captured_notifications .append (message )
1066- # Look for our special notification that indicates the tool is running
1110+ # Look for our first notification
10671111 if isinstance (message .root , types .LoggingMessageNotification ):
1068- if message .root .params .data == "Tool started " :
1069- nonlocal tool_started
1070- tool_started = True
1112+ if message .root .params .data == "First notification before lock " :
1113+ nonlocal first_notification_received
1114+ first_notification_received = True
10711115
10721116 async def on_resumption_token_update (token : str ) -> None :
10731117 nonlocal captured_resumption_token
10741118 captured_resumption_token = token
10751119
1076- # First, start the client session and begin the long-running tool
1120+ # First, start the client session and begin the tool that waits on lock
10771121 async with streamablehttp_client (f"{ server_url } /mcp" , terminate_on_close = False ) as (
10781122 read_stream ,
10791123 write_stream ,
@@ -1088,7 +1132,7 @@ async def on_resumption_token_update(token: str) -> None:
10881132 # Capture the negotiated protocol version
10891133 captured_protocol_version = result .protocolVersion
10901134
1091- # Start a long-running tool in a task
1135+ # Start the tool that will wait on lock in a task
10921136 async with anyio .create_task_group () as tg :
10931137
10941138 async def run_tool ():
@@ -1099,7 +1143,9 @@ async def run_tool():
10991143 types .ClientRequest (
11001144 types .CallToolRequest (
11011145 method = "tools/call" ,
1102- params = types .CallToolRequestParams (name = "long_running_with_checkpoints" , arguments = {}),
1146+ params = types .CallToolRequestParams (
1147+ name = "wait_for_lock_with_notification" , arguments = {}
1148+ ),
11031149 )
11041150 ),
11051151 types .CallToolResult ,
@@ -1108,15 +1154,19 @@ async def run_tool():
11081154
11091155 tg .start_soon (run_tool )
11101156
1111- # Wait for the tool to start and at least one notification
1112- # and then kill the task group
1113- while not tool_started or not captured_resumption_token :
1157+ # Wait for the first notification and resumption token
1158+ while not first_notification_received or not captured_resumption_token :
11141159 await anyio .sleep (0.1 )
1160+
1161+ # Kill the client session while tool is waiting on lock
11151162 tg .cancel_scope .cancel ()
11161163
1117- # Store pre notifications and clear the captured notifications
1118- # for the post-resumption check
1119- captured_notifications_pre = captured_notifications .copy ()
1164+ # Verify we received exactly one notification
1165+ assert len (captured_notifications ) == 1
1166+ assert isinstance (captured_notifications [0 ].root , types .LoggingMessageNotification )
1167+ assert captured_notifications [0 ].root .params .data == "First notification before lock"
1168+
1169+ # Clear notifications for the second phase
11201170 captured_notifications = []
11211171
11221172 # Now resume the session with the same mcp-session-id and protocol version
@@ -1125,54 +1175,48 @@ async def run_tool():
11251175 headers [MCP_SESSION_ID_HEADER ] = captured_session_id
11261176 if captured_protocol_version :
11271177 headers [MCP_PROTOCOL_VERSION_HEADER ] = captured_protocol_version
1128-
11291178 async with streamablehttp_client (f"{ server_url } /mcp" , headers = headers ) as (
11301179 read_stream ,
11311180 write_stream ,
11321181 _ ,
11331182 ):
11341183 async with ClientSession (read_stream , write_stream , message_handler = message_handler ) as session :
1135- # Don't initialize - just use the existing session
1136-
1137- # Resume the tool with the resumption token
1138- assert captured_resumption_token is not None
1139-
1184+ result = await session .send_request (
1185+ types .ClientRequest (
1186+ types .CallToolRequest (
1187+ method = "tools/call" ,
1188+ params = types .CallToolRequestParams (name = "release_lock" , arguments = {}),
1189+ )
1190+ ),
1191+ types .CallToolResult ,
1192+ )
11401193 metadata = ClientMessageMetadata (
11411194 resumption_token = captured_resumption_token ,
11421195 )
1196+
11431197 result = await session .send_request (
11441198 types .ClientRequest (
11451199 types .CallToolRequest (
11461200 method = "tools/call" ,
1147- params = types .CallToolRequestParams (name = "long_running_with_checkpoints " , arguments = {}),
1201+ params = types .CallToolRequestParams (name = "wait_for_lock_with_notification " , arguments = {}),
11481202 )
11491203 ),
11501204 types .CallToolResult ,
11511205 metadata = metadata ,
11521206 )
1153-
1154- # We should get a complete result
11551207 assert len (result .content ) == 1
11561208 assert result .content [0 ].type == "text"
1157- assert "Completed" in result .content [0 ].text
1209+ assert result .content [0 ].text == "Completed"
11581210
11591211 # We should have received the remaining notifications
1160- assert len (captured_notifications ) > 0
1212+ assert len (captured_notifications ) == 1
11611213
1162- # Should not have the first notification
1163- # Check that "Tool started" notification isn't repeated when resuming
1164- assert not any (
1165- isinstance (n .root , types .LoggingMessageNotification ) and n .root .params .data == "Tool started"
1166- for n in captured_notifications
1167- )
1168- # there is no intersection between pre and post notifications
1169- assert not any (n in captured_notifications_pre for n in captured_notifications )
1214+ assert captured_notifications [0 ].root .params .data == "Second notification after lock"
11701215
11711216
11721217@pytest .mark .anyio
11731218async def test_streamablehttp_server_sampling (basic_server , basic_server_url ):
11741219 """Test server-initiated sampling request through streamable HTTP transport."""
1175- print ("Testing server sampling..." )
11761220 # Variable to track if sampling callback was invoked
11771221 sampling_callback_invoked = False
11781222 captured_message_params = None
0 commit comments