8
8
import java .util .List ;
9
9
import java .util .Map ;
10
10
import java .util .concurrent .ConcurrentHashMap ;
11
+ import java .util .concurrent .TimeUnit ;
11
12
import java .util .concurrent .atomic .AtomicReference ;
12
13
import java .util .function .Function ;
13
14
import java .util .stream .Collectors ;
18
19
import io .modelcontextprotocol .client .transport .WebFluxSseClientTransport ;
19
20
import io .modelcontextprotocol .server .McpServer ;
20
21
import io .modelcontextprotocol .server .McpServerFeatures ;
22
+ import io .modelcontextprotocol .server .TestUtil ;
21
23
import io .modelcontextprotocol .server .transport .WebFluxSseServerTransportProvider ;
22
24
import io .modelcontextprotocol .spec .McpError ;
23
25
import io .modelcontextprotocol .spec .McpSchema ;
35
37
import org .junit .jupiter .api .BeforeEach ;
36
38
import org .junit .jupiter .params .ParameterizedTest ;
37
39
import org .junit .jupiter .params .provider .ValueSource ;
38
- import reactor .core .publisher .Mono ;
39
40
import reactor .netty .DisposableServer ;
40
41
import reactor .netty .http .server .HttpServer ;
41
- import reactor .test .StepVerifier ;
42
42
43
43
import org .springframework .http .server .reactive .HttpHandler ;
44
44
import org .springframework .http .server .reactive .ReactorHttpHandlerAdapter ;
47
47
import org .springframework .web .reactive .function .server .RouterFunctions ;
48
48
49
49
import static org .assertj .core .api .Assertions .assertThat ;
50
+ import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
51
+ import static org .assertj .core .api .Assertions .assertWith ;
50
52
import static org .awaitility .Awaitility .await ;
51
53
import static org .mockito .Mockito .mock ;
52
54
53
- public class WebFluxSseIntegrationTests {
55
+ class WebFluxSseIntegrationTests {
54
56
55
- private static final int PORT = 8182 ;
57
+ private static final int PORT = TestUtil . findAvailablePort () ;
56
58
57
59
private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse" ;
58
60
@@ -106,12 +108,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
106
108
var clientBuilder = clientBuilders .get (clientType );
107
109
108
110
McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
109
- new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
110
-
111
- exchange .createMessage (mock (McpSchema .CreateMessageRequest .class )).block ();
112
-
113
- return Mono .just (mock (CallToolResult .class ));
114
- });
111
+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ),
112
+ (exchange , request ) -> exchange .createMessage (mock (CreateMessageRequest .class ))
113
+ .thenReturn (mock (CallToolResult .class )));
115
114
116
115
var server = McpServer .async (mcpServerTransportProvider ).serverInfo ("test-server" , "1.0.0" ).tools (tool ).build ();
117
116
@@ -133,7 +132,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
133
132
134
133
@ ParameterizedTest (name = "{0} : {displayName} " )
135
134
@ ValueSource (strings = { "httpclient" , "webflux" })
136
- void testCreateMessageSuccess (String clientType ) throws InterruptedException {
135
+ void testCreateMessageSuccess (String clientType ) {
137
136
138
137
var clientBuilder = clientBuilders .get (clientType );
139
138
@@ -148,10 +147,12 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
148
147
CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
149
148
null );
150
149
150
+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
151
+
151
152
McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
152
153
new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
153
154
154
- var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
155
+ var createMessageRequest = McpSchema .CreateMessageRequest .builder ()
155
156
.messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
156
157
new McpSchema .TextContent ("Test message" ))))
157
158
.modelPreferences (ModelPreferences .builder ()
@@ -162,19 +163,89 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
162
163
.build ())
163
164
.build ();
164
165
165
- StepVerifier .create (exchange .createMessage (craeteMessageRequest )).consumeNextWith (result -> {
166
- assertThat (result ).isNotNull ();
167
- assertThat (result .role ()).isEqualTo (Role .USER );
168
- assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
169
- assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
170
- assertThat (result .model ()).isEqualTo ("MockModelName" );
171
- assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
172
- }).verifyComplete ();
166
+ return exchange .createMessage (createMessageRequest )
167
+ .doOnNext (samplingResult ::set )
168
+ .thenReturn (callResponse );
169
+ });
170
+
171
+ var mcpServer = McpServer .async (mcpServerTransportProvider )
172
+ .serverInfo ("test-server" , "1.0.0" )
173
+ .tools (tool )
174
+ .build ();
175
+
176
+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
177
+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
178
+ .sampling (samplingHandler )
179
+ .build ()) {
180
+
181
+ InitializeResult initResult = mcpClient .initialize ();
182
+ assertThat (initResult ).isNotNull ();
183
+
184
+ CallToolResult response = mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
185
+
186
+ assertThat (response ).isNotNull ();
187
+ assertThat (response ).isEqualTo (callResponse );
188
+
189
+ assertWith (samplingResult .get (), result -> {
190
+ assertThat (result ).isNotNull ();
191
+ assertThat (result .role ()).isEqualTo (Role .USER );
192
+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
193
+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
194
+ assertThat (result .model ()).isEqualTo ("MockModelName" );
195
+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
196
+ });
197
+ }
198
+ mcpServer .closeGracefully ().block ();
199
+ }
200
+
201
+ @ ParameterizedTest (name = "{0} : {displayName} " )
202
+ @ ValueSource (strings = { "httpclient" , "webflux" })
203
+ void testCreateMessageWithRequestTimeoutSuccess (String clientType ) throws InterruptedException {
204
+
205
+ // Client
206
+ var clientBuilder = clientBuilders .get (clientType );
207
+
208
+ Function <CreateMessageRequest , CreateMessageResult > samplingHandler = request -> {
209
+ assertThat (request .messages ()).hasSize (1 );
210
+ assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
211
+ try {
212
+ TimeUnit .SECONDS .sleep (2 );
213
+ }
214
+ catch (InterruptedException e ) {
215
+ throw new RuntimeException (e );
216
+ }
217
+ return new CreateMessageResult (Role .USER , new McpSchema .TextContent ("Test message" ), "MockModelName" ,
218
+ CreateMessageResult .StopReason .STOP_SEQUENCE );
219
+ };
220
+
221
+ // Server
222
+
223
+ CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
224
+ null );
225
+
226
+ AtomicReference <CreateMessageResult > samplingResult = new AtomicReference <>();
227
+
228
+ McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
229
+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
230
+
231
+ var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
232
+ .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
233
+ new McpSchema .TextContent ("Test message" ))))
234
+ .modelPreferences (ModelPreferences .builder ()
235
+ .hints (List .of ())
236
+ .costPriority (1.0 )
237
+ .speedPriority (1.0 )
238
+ .intelligencePriority (1.0 )
239
+ .build ())
240
+ .build ();
173
241
174
- return Mono .just (callResponse );
242
+ return exchange .createMessage (craeteMessageRequest )
243
+ .doOnNext (samplingResult ::set )
244
+ .thenReturn (callResponse );
175
245
});
176
246
177
247
var mcpServer = McpServer .async (mcpServerTransportProvider )
248
+ .requestTimeout (Duration .ofSeconds (4 ))
178
249
.serverInfo ("test-server" , "1.0.0" )
179
250
.tools (tool )
180
251
.build ();
@@ -191,8 +262,77 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
191
262
192
263
assertThat (response ).isNotNull ();
193
264
assertThat (response ).isEqualTo (callResponse );
265
+
266
+ assertWith (samplingResult .get (), result -> {
267
+ assertThat (result ).isNotNull ();
268
+ assertThat (result .role ()).isEqualTo (Role .USER );
269
+ assertThat (result .content ()).isInstanceOf (McpSchema .TextContent .class );
270
+ assertThat (((McpSchema .TextContent ) result .content ()).text ()).isEqualTo ("Test message" );
271
+ assertThat (result .model ()).isEqualTo ("MockModelName" );
272
+ assertThat (result .stopReason ()).isEqualTo (CreateMessageResult .StopReason .STOP_SEQUENCE );
273
+ });
194
274
}
195
- mcpServer .close ();
275
+
276
+ mcpServer .closeGracefully ().block ();
277
+ }
278
+
279
+ @ ParameterizedTest (name = "{0} : {displayName} " )
280
+ @ ValueSource (strings = { "httpclient" , "webflux" })
281
+ void testCreateMessageWithRequestTimeoutFail (String clientType ) throws InterruptedException {
282
+
283
+ // Client
284
+ var clientBuilder = clientBuilders .get (clientType );
285
+
286
+ Function <CreateMessageRequest , CreateMessageResult > samplingHandler = request -> {
287
+ assertThat (request .messages ()).hasSize (1 );
288
+ assertThat (request .messages ().get (0 ).content ()).isInstanceOf (McpSchema .TextContent .class );
289
+ try {
290
+ TimeUnit .SECONDS .sleep (2 );
291
+ }
292
+ catch (InterruptedException e ) {
293
+ throw new RuntimeException (e );
294
+ }
295
+ return new CreateMessageResult (Role .USER , new McpSchema .TextContent ("Test message" ), "MockModelName" ,
296
+ CreateMessageResult .StopReason .STOP_SEQUENCE );
297
+ };
298
+
299
+ // Server
300
+
301
+ CallToolResult callResponse = new McpSchema .CallToolResult (List .of (new McpSchema .TextContent ("CALL RESPONSE" )),
302
+ null );
303
+
304
+ McpServerFeatures .AsyncToolSpecification tool = new McpServerFeatures .AsyncToolSpecification (
305
+ new McpSchema .Tool ("tool1" , "tool1 description" , emptyJsonSchema ), (exchange , request ) -> {
306
+
307
+ var craeteMessageRequest = McpSchema .CreateMessageRequest .builder ()
308
+ .messages (List .of (new McpSchema .SamplingMessage (McpSchema .Role .USER ,
309
+ new McpSchema .TextContent ("Test message" ))))
310
+ .build ();
311
+
312
+ return exchange .createMessage (craeteMessageRequest ).thenReturn (callResponse );
313
+ });
314
+
315
+ var mcpServer = McpServer .async (mcpServerTransportProvider )
316
+ .requestTimeout (Duration .ofSeconds (1 ))
317
+ .serverInfo ("test-server" , "1.0.0" )
318
+ .tools (tool )
319
+ .build ();
320
+
321
+ try (var mcpClient = clientBuilder .clientInfo (new McpSchema .Implementation ("Sample client" , "0.0.0" ))
322
+ .capabilities (ClientCapabilities .builder ().sampling ().build ())
323
+ .sampling (samplingHandler )
324
+ .build ()) {
325
+
326
+ InitializeResult initResult = mcpClient .initialize ();
327
+ assertThat (initResult ).isNotNull ();
328
+
329
+ assertThatExceptionOfType (McpError .class ).isThrownBy (() -> {
330
+ mcpClient .callTool (new McpSchema .CallToolRequest ("tool1" , Map .of ()));
331
+ }).withMessageContaining ("within 1000ms" );
332
+
333
+ }
334
+
335
+ mcpServer .closeGracefully ().block ();
196
336
}
197
337
198
338
// ---------------------------------------
@@ -262,9 +402,8 @@ void testRootsWithoutCapability(String clientType) {
262
402
var mcpServer = McpServer .sync (mcpServerTransportProvider ).rootsChangeHandler ((exchange , rootsUpdate ) -> {
263
403
}).tools (tool ).build ();
264
404
265
- try (
266
- // Create client without roots capability
267
- var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
405
+ // Create client without roots capability
406
+ try (var mcpClient = clientBuilder .capabilities (ClientCapabilities .builder ().build ()).build ()) {
268
407
269
408
assertThat (mcpClient .initialize ()).isNotNull ();
270
409
@@ -282,7 +421,7 @@ void testRootsWithoutCapability(String clientType) {
282
421
283
422
@ ParameterizedTest (name = "{0} : {displayName} " )
284
423
@ ValueSource (strings = { "httpclient" , "webflux" })
285
- void testRootsNotifciationWithEmptyRootsList (String clientType ) {
424
+ void testRootsNotificationWithEmptyRootsList (String clientType ) {
286
425
var clientBuilder = clientBuilders .get (clientType );
287
426
288
427
AtomicReference <List <Root >> rootsRef = new AtomicReference <>();
0 commit comments