diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java new file mode 100644 index 000000000..1b026fc46 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStatelessServerTransport.java @@ -0,0 +1,237 @@ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; + +/** + * Implementation of a WebMVC based {@link McpStatelessServerTransport}. + * + *

+ * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStatelessServerTransport} + * + * @author Christian Tzolov + */ +public class WebMvcStatelessServerTransport implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStatelessServerTransport.class); + + private final ObjectMapper objectMapper; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private WebMvcStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Returns the WebMVC router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines one endpoint handling two HTTP methods: + *

+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + private ServerResponse handleGet(ServerRequest request) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + private ServerResponse handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!(acceptHeaders.contains(MediaType.APPLICATION_JSON) + && acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM))) { + return ServerResponse.badRequest().build(); + } + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.ok().contentType(MediaType.APPLICATION_JSON).body(jsonrpcResponse); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + return ServerResponse.badRequest() + .body(new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebMvcStatelessServerTransport with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStatelessServerTransport} with the + * configured settings. + * @return A new WebMvcStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStatelessServerTransport build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new WebMvcStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java new file mode 100644 index 000000000..b2264ea00 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStatelessIntegrationTests.java @@ -0,0 +1,165 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.AbstractStatelessIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStatelessServerTransport; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.scheduler.Schedulers; + +class WebMvcStatelessIntegrationTests extends AbstractStatelessIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStatelessServerTransport mcpServerTransport; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStatelessServerTransport webMvcStatelessServerTransport() { + + return WebMvcStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(MESSAGE_ENDPOINT) + .build(); + + } + + @Bean + public RouterFunction routerFunction(WebMvcStatelessServerTransport statelessServerTransport) { + return statelessServerTransport.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .build())); + + // Get the transport from Spring context + this.mcpServerTransport = tomcatServer.appContext().getBean(WebMvcStatelessServerTransport.class); + + } + + @Override + protected StatelessAsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransport); + } + + @Override + protected StatelessSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransport); + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (this.mcpServerTransport != null) { + this.mcpServerTransport.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = McpServer.async(this.mcpServerTransport) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build())); + } + +} diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java new file mode 100644 index 000000000..a84d127aa --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractStatelessIntegrationTests.java @@ -0,0 +1,538 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.server.McpServer.StatelessAsyncSpecification; +import io.modelcontextprotocol.server.McpServer.StatelessSyncSpecification; +import io.modelcontextprotocol.server.McpStatelessServerFeatures; +import io.modelcontextprotocol.server.McpStatelessSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import reactor.core.publisher.Mono; + +public abstract class AbstractStatelessIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected StatelessAsyncSpecification prepareAsyncServerBuilder(); + + abstract protected StatelessSyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((ctx, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpStatelessSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((context, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((ctx, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + // Remove a tool + mcpServer.removeTool("tool1"); + + // Add a new tool + McpStatelessServerFeatures.SyncToolSpecification tool2 = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat(response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = McpStatelessServerFeatures.SyncToolSpecification + .builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + var tool = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + var toolSpec = McpStatelessServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java index f154272ef..8be59a779 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -22,6 +22,7 @@ * support. * * @author Dariusz Jędrzejczyk + * @author Christian Tzolov */ public class McpStatelessServerFeatures { @@ -212,6 +213,59 @@ static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boole return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); } + + /** + * Builder for creating AsyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction> callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction> callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the AsyncToolSpecification instance. + * @return a new AsyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public AsyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Call handler function must not be null"); + + return new AsyncToolSpecification(tool, callHandler); + } + + } + + /** + * Creates a new builder instance. + * @return a new Builder instance + */ + public static Builder builder() { + return new Builder(); + } } /** @@ -324,6 +378,55 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet */ public record SyncToolSpecification(McpSchema.Tool tool, BiFunction callHandler) { + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating SyncToolSpecification instances. + */ + public static class Builder { + + private McpSchema.Tool tool; + + private BiFunction callHandler; + + /** + * Sets the tool definition. + * @param tool The tool definition including name, description, and parameter + * schema + * @return this builder instance + */ + public Builder tool(McpSchema.Tool tool) { + this.tool = tool; + return this; + } + + /** + * Sets the call tool handler function. + * @param callHandler The function that implements the tool's logic + * @return this builder instance + */ + public Builder callHandler( + BiFunction callHandler) { + this.callHandler = callHandler; + return this; + } + + /** + * Builds the SyncToolSpecification instance. + * @return a new SyncToolSpecification instance + * @throws IllegalArgumentException if required fields are not set + */ + public SyncToolSpecification build() { + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "CallTool function must not be null"); + + return new SyncToolSpecification(tool, callHandler); + } + + } } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java new file mode 100644 index 000000000..25b003564 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java @@ -0,0 +1,306 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpStatelessServerHandler; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import reactor.core.publisher.Mono; + +/** + * Implementation of an HttpServlet based {@link McpStatelessServerTransport}. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +@WebServlet(asyncSupported = true) +public class HttpServletStatelessServerTransport extends HttpServlet implements McpStatelessServerTransport { + + private static final Logger logger = LoggerFactory.getLogger(HttpServletStatelessServerTransport.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String ACCEPT = "Accept"; + + public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}"; + + private final ObjectMapper objectMapper; + + private final String mcpEndpoint; + + private McpStatelessServerHandler mcpHandler; + + private McpTransportContextExtractor contextExtractor; + + private volatile boolean isClosing = false; + + private HttpServletStatelessServerTransport(ObjectMapper objectMapper, String mcpEndpoint, + McpTransportContextExtractor contextExtractor) { + Assert.notNull(objectMapper, "objectMapper must not be null"); + Assert.notNull(mcpEndpoint, "mcpEndpoint must not be null"); + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.contextExtractor = contextExtractor; + } + + @Override + public void setMcpHandler(McpStatelessServerHandler mcpHandler) { + this.mcpHandler = mcpHandler; + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> this.isClosing = true); + } + + /** + * Handles GET requests - returns 405 METHOD NOT ALLOWED as stateless transport + * doesn't support GET requests. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED); + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The HTTP servlet request containing the JSON-RPC message + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing) { + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + String accept = request.getHeader(ACCEPT); + if (accept == null || !(accept.contains(APPLICATION_JSON) && accept.contains(TEXT_EVENT_STREAM))) { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("Both application/json and text/event-stream required in Accept header")); + return; + } + + try { + BufferedReader reader = request.getReader(); + StringBuilder body = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + try { + McpSchema.JSONRPCResponse jsonrpcResponse = this.mcpHandler + .handleRequest(transportContext, jsonrpcRequest) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(HttpServletResponse.SC_OK); + + String jsonResponseText = objectMapper.writeValueAsString(jsonrpcResponse); + PrintWriter writer = response.getWriter(); + writer.write(jsonResponseText); + writer.flush(); + } + catch (Exception e) { + logger.error("Failed to handle request: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle request: " + e.getMessage())); + } + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + try { + this.mcpHandler.handleNotification(transportContext, jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + catch (Exception e) { + logger.error("Failed to handle notification: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Failed to handle notification: " + e.getMessage())); + } + } + else { + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, + new McpError("The server accepts either requests or notifications")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_BAD_REQUEST, new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Unexpected error handling message: {}", e.getMessage()); + this.responseError(response, HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + new McpError("Unexpected error: " + e.getMessage())); + } + } + + /** + * Sends an error response to the client. + * @param response The HTTP servlet response + * @param httpCode The HTTP status code + * @param mcpError The MCP error to send + * @throws IOException If an I/O error occurs + */ + private void responseError(HttpServletResponse response, int httpCode, McpError mcpError) throws IOException { + response.setContentType(APPLICATION_JSON); + response.setCharacterEncoding(UTF_8); + response.setStatus(httpCode); + String jsonError = objectMapper.writeValueAsString(mcpError); + PrintWriter writer = response.getWriter(); + writer.write(jsonError); + writer.flush(); + } + + /** + * Cleans up resources when the servlet is being destroyed. + *

+ * This method ensures a graceful shutdown before calling the parent's destroy method. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Create a builder for the server. + * @return a fresh {@link Builder} instance. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link HttpServletStatelessServerTransport}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * HttpServletStatelessServerTransport with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + private Builder() { + // used by a static method + } + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "Context extractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link HttpServletStatelessServerTransport} with the + * configured settings. + * @return A new HttpServletStatelessServerTransport instance + * @throws IllegalStateException if required parameters are not set + */ + public HttpServletStatelessServerTransport build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new HttpServletStatelessServerTransport(objectMapper, mcpEndpoint, contextExtractor); + } + + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java new file mode 100644 index 000000000..da8aa4adf --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -0,0 +1,473 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.web.client.RestClient; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.server.transport.HttpServletStatelessServerTransport; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; + +class HttpServletStatelessIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private HttpServletStatelessServerTransport mcpStatelessServerTransport; + + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + private Tomcat tomcat; + + @BeforeEach + public void before() { + this.mcpStatelessServerTransport = HttpServletStatelessServerTransport.builder() + .objectMapper(new ObjectMapper()) + .messageEndpoint(CUSTOM_MESSAGE_ENDPOINT) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpStatelessServerTransport); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + } + + @AfterEach + public void after() { + if (mcpStatelessServerTransport != null) { + mcpStatelessServerTransport.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpStatelessServerFeatures.SyncToolSpecification tool1 = new McpStatelessServerFeatures.SyncToolSpecification( + new Tool("tool1", "tool1 description", emptyJsonSchema), (transportContext, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport).build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (transportContext, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpStatelessServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (transportContext, getPromptRequest) -> null)) + .completions(new McpStatelessServerFeatures.SyncCompletionSpecification( + new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputValidationFailure(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification tool = new McpStatelessServerFeatures.SyncToolSpecification( + calculatorTool, (transportContext, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }); + + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .instructions("bla") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = McpServer.sync(mcpStatelessServerTransport) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpStatelessServerFeatures.SyncToolSpecification toolSpec = new McpStatelessServerFeatures.SyncToolSpecification( + dynamicTool, (transportContext, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +}