Skip to content

Commit b975992

Browse files
authored
Merge branch 'main' into fix-resolving-message-endpoint-url
2 parents e41a8ee + 84adde1 commit b975992

24 files changed

+643
-93
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
141141
* Constructs a new WebFlux SSE server transport provider instance.
142142
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
143143
* of MCP messages. Must not be null.
144-
* @param baseUrl webflux messag base path
144+
* @param baseUrl webflux message base path
145145
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
146146
* messages. This endpoint will be communicated to clients during SSE connection
147147
* setup. Must not be null.

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 165 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import java.util.List;
99
import java.util.Map;
1010
import java.util.concurrent.ConcurrentHashMap;
11+
import java.util.concurrent.TimeUnit;
1112
import java.util.concurrent.atomic.AtomicReference;
1213
import java.util.function.Function;
1314
import java.util.stream.Collectors;
@@ -18,6 +19,7 @@
1819
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
1920
import io.modelcontextprotocol.server.McpServer;
2021
import io.modelcontextprotocol.server.McpServerFeatures;
22+
import io.modelcontextprotocol.server.TestUtil;
2123
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
2224
import io.modelcontextprotocol.spec.McpError;
2325
import io.modelcontextprotocol.spec.McpSchema;
@@ -35,10 +37,8 @@
3537
import org.junit.jupiter.api.BeforeEach;
3638
import org.junit.jupiter.params.ParameterizedTest;
3739
import org.junit.jupiter.params.provider.ValueSource;
38-
import reactor.core.publisher.Mono;
3940
import reactor.netty.DisposableServer;
4041
import reactor.netty.http.server.HttpServer;
41-
import reactor.test.StepVerifier;
4242

4343
import org.springframework.http.server.reactive.HttpHandler;
4444
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
@@ -47,12 +47,14 @@
4747
import org.springframework.web.reactive.function.server.RouterFunctions;
4848

4949
import static org.assertj.core.api.Assertions.assertThat;
50+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
51+
import static org.assertj.core.api.Assertions.assertWith;
5052
import static org.awaitility.Awaitility.await;
5153
import static org.mockito.Mockito.mock;
5254

53-
public class WebFluxSseIntegrationTests {
55+
class WebFluxSseIntegrationTests {
5456

55-
private static final int PORT = 8182;
57+
private static final int PORT = TestUtil.findAvailablePort();
5658

5759
private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse";
5860

@@ -106,12 +108,9 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
106108
var clientBuilder = clientBuilders.get(clientType);
107109

108110
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
109-
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
110-
111-
exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block();
112-
113-
return Mono.just(mock(CallToolResult.class));
114-
});
111+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema),
112+
(exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class))
113+
.thenReturn(mock(CallToolResult.class)));
115114

116115
var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build();
117116

@@ -133,7 +132,7 @@ void testCreateMessageWithoutSamplingCapabilities(String clientType) {
133132

134133
@ParameterizedTest(name = "{0} : {displayName} ")
135134
@ValueSource(strings = { "httpclient", "webflux" })
136-
void testCreateMessageSuccess(String clientType) throws InterruptedException {
135+
void testCreateMessageSuccess(String clientType) {
137136

138137
var clientBuilder = clientBuilders.get(clientType);
139138

@@ -148,10 +147,12 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
148147
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
149148
null);
150149

150+
AtomicReference<CreateMessageResult> samplingResult = new AtomicReference<>();
151+
151152
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
152153
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
153154

154-
var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
155+
var createMessageRequest = McpSchema.CreateMessageRequest.builder()
155156
.messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
156157
new McpSchema.TextContent("Test message"))))
157158
.modelPreferences(ModelPreferences.builder()
@@ -162,19 +163,89 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
162163
.build())
163164
.build();
164165

165-
StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
166-
assertThat(result).isNotNull();
167-
assertThat(result.role()).isEqualTo(Role.USER);
168-
assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
169-
assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
170-
assertThat(result.model()).isEqualTo("MockModelName");
171-
assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
172-
}).verifyComplete();
166+
return exchange.createMessage(createMessageRequest)
167+
.doOnNext(samplingResult::set)
168+
.thenReturn(callResponse);
169+
});
170+
171+
var mcpServer = McpServer.async(mcpServerTransportProvider)
172+
.serverInfo("test-server", "1.0.0")
173+
.tools(tool)
174+
.build();
175+
176+
try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
177+
.capabilities(ClientCapabilities.builder().sampling().build())
178+
.sampling(samplingHandler)
179+
.build()) {
180+
181+
InitializeResult initResult = mcpClient.initialize();
182+
assertThat(initResult).isNotNull();
183+
184+
CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
185+
186+
assertThat(response).isNotNull();
187+
assertThat(response).isEqualTo(callResponse);
188+
189+
assertWith(samplingResult.get(), result -> {
190+
assertThat(result).isNotNull();
191+
assertThat(result.role()).isEqualTo(Role.USER);
192+
assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
193+
assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
194+
assertThat(result.model()).isEqualTo("MockModelName");
195+
assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
196+
});
197+
}
198+
mcpServer.closeGracefully().block();
199+
}
200+
201+
@ParameterizedTest(name = "{0} : {displayName} ")
202+
@ValueSource(strings = { "httpclient", "webflux" })
203+
void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException {
204+
205+
// Client
206+
var clientBuilder = clientBuilders.get(clientType);
207+
208+
Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
209+
assertThat(request.messages()).hasSize(1);
210+
assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
211+
try {
212+
TimeUnit.SECONDS.sleep(2);
213+
}
214+
catch (InterruptedException e) {
215+
throw new RuntimeException(e);
216+
}
217+
return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
218+
CreateMessageResult.StopReason.STOP_SEQUENCE);
219+
};
220+
221+
// Server
222+
223+
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
224+
null);
225+
226+
AtomicReference<CreateMessageResult> samplingResult = new AtomicReference<>();
227+
228+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
229+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
230+
231+
var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
232+
.messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
233+
new McpSchema.TextContent("Test message"))))
234+
.modelPreferences(ModelPreferences.builder()
235+
.hints(List.of())
236+
.costPriority(1.0)
237+
.speedPriority(1.0)
238+
.intelligencePriority(1.0)
239+
.build())
240+
.build();
173241

174-
return Mono.just(callResponse);
242+
return exchange.createMessage(craeteMessageRequest)
243+
.doOnNext(samplingResult::set)
244+
.thenReturn(callResponse);
175245
});
176246

177247
var mcpServer = McpServer.async(mcpServerTransportProvider)
248+
.requestTimeout(Duration.ofSeconds(4))
178249
.serverInfo("test-server", "1.0.0")
179250
.tools(tool)
180251
.build();
@@ -191,8 +262,77 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
191262

192263
assertThat(response).isNotNull();
193264
assertThat(response).isEqualTo(callResponse);
265+
266+
assertWith(samplingResult.get(), result -> {
267+
assertThat(result).isNotNull();
268+
assertThat(result.role()).isEqualTo(Role.USER);
269+
assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
270+
assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
271+
assertThat(result.model()).isEqualTo("MockModelName");
272+
assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
273+
});
194274
}
195-
mcpServer.close();
275+
276+
mcpServer.closeGracefully().block();
277+
}
278+
279+
@ParameterizedTest(name = "{0} : {displayName} ")
280+
@ValueSource(strings = { "httpclient", "webflux" })
281+
void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException {
282+
283+
// Client
284+
var clientBuilder = clientBuilders.get(clientType);
285+
286+
Function<CreateMessageRequest, CreateMessageResult> samplingHandler = request -> {
287+
assertThat(request.messages()).hasSize(1);
288+
assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
289+
try {
290+
TimeUnit.SECONDS.sleep(2);
291+
}
292+
catch (InterruptedException e) {
293+
throw new RuntimeException(e);
294+
}
295+
return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
296+
CreateMessageResult.StopReason.STOP_SEQUENCE);
297+
};
298+
299+
// Server
300+
301+
CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
302+
null);
303+
304+
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
305+
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
306+
307+
var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
308+
.messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
309+
new McpSchema.TextContent("Test message"))))
310+
.build();
311+
312+
return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse);
313+
});
314+
315+
var mcpServer = McpServer.async(mcpServerTransportProvider)
316+
.requestTimeout(Duration.ofSeconds(1))
317+
.serverInfo("test-server", "1.0.0")
318+
.tools(tool)
319+
.build();
320+
321+
try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
322+
.capabilities(ClientCapabilities.builder().sampling().build())
323+
.sampling(samplingHandler)
324+
.build()) {
325+
326+
InitializeResult initResult = mcpClient.initialize();
327+
assertThat(initResult).isNotNull();
328+
329+
assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
330+
mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
331+
}).withMessageContaining("within 1000ms");
332+
333+
}
334+
335+
mcpServer.closeGracefully().block();
196336
}
197337

198338
// ---------------------------------------
@@ -262,9 +402,8 @@ void testRootsWithoutCapability(String clientType) {
262402
var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> {
263403
}).tools(tool).build();
264404

265-
try (
266-
// Create client without roots capability
267-
var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) {
405+
// Create client without roots capability
406+
try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) {
268407

269408
assertThat(mcpClient.initialize()).isNotNull();
270409

@@ -282,7 +421,7 @@ void testRootsWithoutCapability(String clientType) {
282421

283422
@ParameterizedTest(name = "{0} : {displayName} ")
284423
@ValueSource(strings = { "httpclient", "webflux" })
285-
void testRootsNotifciationWithEmptyRootsList(String clientType) {
424+
void testRootsNotificationWithEmptyRootsList(String clientType) {
286425
var clientBuilder = clientBuilders.get(clientType);
287426

288427
AtomicReference<List<Root>> rootsRef = new AtomicReference<>();

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
@Timeout(15) // Giving extra time beyond the client timeout
2424
class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests {
2525

26-
private static final int PORT = 8181;
26+
private static final int PORT = TestUtil.findAvailablePort();
2727

2828
private static final String MESSAGE_ENDPOINT = "/mcp/message";
2929

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
@Timeout(15) // Giving extra time beyond the client timeout
2424
class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests {
2525

26-
private static final int PORT = 8182;
26+
private static final int PORT = TestUtil.findAvailablePort();
2727

2828
private static final String MESSAGE_ENDPOINT = "/mcp/message";
2929

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/TomcatTestUtil.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
*/
44
package io.modelcontextprotocol.server;
55

6+
import java.io.IOException;
7+
import java.net.InetSocketAddress;
8+
import java.net.ServerSocket;
9+
610
import org.apache.catalina.Context;
711
import org.apache.catalina.startup.Tomcat;
812

@@ -14,10 +18,14 @@
1418
*/
1519
public class TomcatTestUtil {
1620

21+
TomcatTestUtil() {
22+
// Prevent instantiation
23+
}
24+
1725
public record TomcatServer(Tomcat tomcat, AnnotationConfigWebApplicationContext appContext) {
1826
}
1927

20-
public TomcatServer createTomcatServer(String contextPath, int port, Class<?> componentClass) {
28+
public static TomcatServer createTomcatServer(String contextPath, int port, Class<?> componentClass) {
2129

2230
// Set up Tomcat first
2331
var tomcat = new Tomcat();

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class WebMvcSseAsyncServerTransportTests extends AbstractMcpAsyncServerTests {
2525

2626
private static final String MESSAGE_ENDPOINT = "/mcp/message";
2727

28-
private static final int PORT = 8181;
28+
private static final int PORT = TestUtil.findAvailablePort();
2929

3030
private Tomcat tomcat;
3131

@@ -73,7 +73,6 @@ protected McpServerTransportProvider createMcpTransportProvider() {
7373

7474
// Create DispatcherServlet with our Spring context
7575
DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
76-
// dispatcherServlet.setThrowExceptionIfNoHandlerFound(true);
7776

7877
// Add servlet to Tomcat and get the wrapper
7978
var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);

mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222

2323
import static org.assertj.core.api.Assertions.assertThat;
2424

25-
public class WebMvcSseCustomContextPathTests {
25+
class WebMvcSseCustomContextPathTests {
2626

2727
private static final String CUSTOM_CONTEXT_PATH = "/app/1";
2828

29-
private static final int PORT = 8183;
29+
private static final int PORT = TestUtil.findAvailablePort();
3030

3131
private static final String MESSAGE_ENDPOINT = "/mcp/message";
3232

@@ -39,11 +39,11 @@ public class WebMvcSseCustomContextPathTests {
3939
@BeforeEach
4040
public void before() {
4141

42-
tomcatServer = new TomcatTestUtil().createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);
42+
tomcatServer = TomcatTestUtil.createTomcatServer(CUSTOM_CONTEXT_PATH, PORT, TestConfig.class);
4343

4444
try {
4545
tomcatServer.tomcat().start();
46-
assertThat(tomcatServer.tomcat().getServer().getState() == LifecycleState.STARTED);
46+
assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
4747
}
4848
catch (Exception e) {
4949
throw new RuntimeException("Failed to start Tomcat", e);

0 commit comments

Comments
 (0)