diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 53b59cb30..e8cb26144 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -340,7 +340,8 @@ private Flux extractError(ClientResponse response, Str McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, McpSchema.JSONRPCResponse.class); jsonRpcError = jsonRpcResponse.error(); - toPropagate = new McpError(jsonRpcError); + toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) + : new McpError("Can't parse the jsonResponse " + jsonRpcResponse); } catch (IOException ex) { toPropagate = new RuntimeException("Sending request failed", e); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java new file mode 100644 index 000000000..172882007 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -0,0 +1,479 @@ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportContext; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; + +public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class); + + public static final String MESSAGE_EVENT_TYPE = "message"; + + public static final String DEFAULT_BASE_URL = ""; + + private final ObjectMapper objectMapper; + + private final String baseUrl; + + private final String mcpEndpoint; + + private final boolean disallowDelete; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + // TODO: add means to specify this + private Function contextExtractor = req -> new DefaultMcpTransportContext(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, mcpEndpoint, false); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String mcpEndpoint, + boolean disallowDelete) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(baseUrl, "Message base path must not be null"); + Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); + + this.objectMapper = objectMapper; + this.baseUrl = baseUrl; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

+ * The method: + *

    + *
  • Serializes the message to JSON
  • + *
  • Creates a server-sent event with the message data
  • + *
  • Attempts to send the event to all active sessions
  • + *
  • Tracks and reports any delivery failures
  • + *
+ * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + // FIXME: This javadoc makes claims about using isClosing flag but it's not + // actually + // doing that. + /** + * Initiates a graceful shutdown of all the sessions. This method ensures all active + * sessions are properly closed and cleaned up. + * + *

+ * The shutdown process: + *

    + *
  • Marks the transport as closing to prevent new connections
  • + *
  • Closes each active session
  • + *
  • Removes closed sessions from the sessions map
  • + *
  • Times out after 5 seconds if shutdown takes too long
  • + *
+ * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpStreamableServerSession::closeGracefully) + .then(); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines two endpoints: + *

    + *
  • GET {sseEndpoint} - For establishing SSE connections
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleGet(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.apply(request); + + return Mono.defer(() -> { + if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + if (request.headers().asHttpHeaders().containsKey("mcp-last-id")) { + String lastId = request.headers().asHttpHeaders().getFirst("mcp-last-id"); + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(session.replay(lastId), ServerSentEvent.class); + } + + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( + sink); + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + sink.onDispose(listeningStream::close); + }), ServerSentEvent.class); + + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

+ * The handler: + *

    + *
  • Deserializes the incoming JSON-RPC message
  • + *
  • Passes it through the message handler chain
  • + *
  • Returns appropriate HTTP responses based on processing results
  • + *
  • Handles various error conditions with appropriate error responses
  • + *
+ * @param request The incoming server request containing the JSON-RPC message + * @return A Mono emitting the response indicating the message processing result + */ + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.apply(request); + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + sessions.put(init.session().getId(), init.session()); + return init.initResult().map(initializeResult -> { + McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); + try { + return this.objectMapper.writeValueAsString(jsonrpcResponse); + } + catch (IOException e) { + logger.warn("Failed to serialize initResponse", e); + throw Exceptions.propagate(e); + } + }) + .flatMap(initResult -> ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header("mcp-session-id", init.session().getId()) + .bodyValue(initResult)); + } + + if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); + McpStreamableServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); + Mono stream = session.responseStream(jsonrpcRequest, st); + Disposable streamSubscription = stream.doOnError(err -> sink.error(err)) + .contextWrite(sink.contextView()) + .subscribe(); + sink.onCancel(streamSubscription); + }), ServerSentEvent.class); + } + else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }) + .switchIfEmpty(ServerResponse.badRequest().build()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private Mono handleDelete(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.apply(request); + + return Mono.defer(() -> { + if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + // TODO: The user can configure whether deletions are permitted + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + return session.delete().then(ServerResponse.ok().build()); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private class WebFluxStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final FluxSink> sink; + + public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return this.sendMessage(message, null); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .id(messageId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStreamableServerTransportProvider}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String baseUrl = DEFAULT_BASE_URL; + + private String mcpEndpoint = "/mcp"; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the project basePath as endpoint prefix where clients should send their + * JSON-RPC messages + * @param baseUrl the message basePath . Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if basePath is null + */ + public Builder basePath(String baseUrl) { + Assert.notNull(baseUrl, "basePath must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStreamableServerTransportProvider build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new WebFluxStreamableServerTransportProvider(objectMapper, baseUrl, mcpEndpoint, false); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java new file mode 100644 index 000000000..1e11a38d6 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -0,0 +1,1493 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +class WebFluxStreamableIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; + + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpStreamableServerTransportProvider = new WebFluxStreamableServerTransportProvider(new ObjectMapper(), + CUSTOM_MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions + .toHttpHandler(mcpStreamableServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + clientBuilders + .put("webflux", McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) + .thenReturn(mock(CallToolResult.class))) + .build(); + + var server = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + server.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var createMessageRequest = CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var craeteMessageRequest = CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .requestTimeout(Duration.ofSeconds(4)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var craeteMessageRequest = CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) + .build(); + + return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .requestTimeout(Duration.ofSeconds(1)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + } + + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> exchange.createElicitation(mock(ElicitRequest.class)) + .then(Mono.just(mock(CallToolResult.class)))) + .build(); + + var server = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> { + }) + .tools(tool) + .build(); + + // Create client without roots capability + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsNotificationWithEmptyRootsList(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }) + .build(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool2", "tool2 description", emptyJsonSchema)) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider).build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("logging-test", "Test logging notifications", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + mcpServer.close(); + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(new CallToolResult(("Progress test completed"), false)); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testPingSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return new CallToolResult("Async ping test completed", false); + })); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, + (exchange, request) -> { + String expression = (String) request.getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationFailure(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, + (exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, + (exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .instructions("bla") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, + (exchange, request) -> { + int count = (Integer) request.getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index cc33e7b94..a3bdf10b0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -29,8 +29,7 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); @@ -41,6 +40,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java index 2fc104538..3e28e96b8 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpSyncServerTests.java @@ -32,7 +32,11 @@ class WebFluxSseMcpSyncServerTests extends AbstractMcpSyncServerTests { private WebFluxSseServerTransportProvider transportProvider; @Override - protected McpServerTransportProvider createMcpTransportProvider() { + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private McpServerTransportProvider createMcpTransportProvider() { transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java new file mode 100644 index 000000000..0d7429a0c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxStreamableServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java new file mode 100644 index 000000000..e15137a3f --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpSyncServerTests.java @@ -0,0 +1,57 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpSyncServerTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxStreamableServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 6a6ad17e9..bb4c2bf37 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -49,8 +49,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -90,6 +89,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java index 1964703c1..7e49ddf3b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseSyncServerTransportTests.java @@ -49,7 +49,11 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; @Override - protected WebMvcSseServerTransportProvider createMcpTransportProvider() { + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + private WebMvcSseServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index eb08bdcde..68a60a17c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -18,9 +18,12 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -42,7 +45,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -63,28 +66,29 @@ void tearDown() { // Server Lifecycle Tests // --------------------------------------- - @Test - void testConstructorWithInvalidArguments() { + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "sse", "streamable" }) + void testConstructorWithInvalidArguments(String serverType) { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -104,8 +108,7 @@ void testImmediateClose() { @Deprecated void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -119,8 +122,7 @@ void testAddTool() { @Test void testAddToolCall() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -137,8 +139,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -158,8 +159,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -180,8 +180,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! @@ -203,8 +202,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -215,8 +213,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) @@ -235,8 +232,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -248,8 +244,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -264,8 +259,7 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -281,7 +275,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -290,7 +284,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier .create(mcpAsyncServer @@ -302,8 +296,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -319,8 +312,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -335,9 +327,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -353,9 +343,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -369,7 +357,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -378,8 +366,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -392,9 +379,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( @@ -410,9 +395,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -429,8 +412,7 @@ void testRemovePrompt() { prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -442,8 +424,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -466,8 +447,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -486,8 +466,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -500,8 +479,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -513,9 +491,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 4d5f9f772..5ab9ae1ff 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -40,7 +40,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { } @@ -68,28 +68,28 @@ void testConstructorWithInvalidArguments() { .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -111,8 +111,7 @@ void testGetAsyncServer() { @Test @Deprecated void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -126,8 +125,7 @@ void testAddTool() { @Test void testAddToolCall() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -145,8 +143,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -163,8 +160,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); @@ -183,8 +179,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! @@ -206,8 +201,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -218,8 +212,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) @@ -238,8 +231,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -251,8 +243,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -264,7 +255,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -277,7 +268,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -286,7 +277,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) @@ -297,8 +288,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -314,8 +304,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -328,9 +317,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -343,9 +330,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -357,7 +342,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -366,8 +351,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -378,9 +362,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, @@ -393,9 +375,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -408,8 +388,7 @@ void testRemovePrompt() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -421,8 +400,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -442,8 +420,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -461,8 +438,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -474,8 +450,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -486,7 +461,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 7131b10fa..fcd42a433 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -15,6 +15,9 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -86,7 +89,7 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final McpServerTransportProvider mcpTransportProvider; + private final McpServerTransportProviderBase mcpTransportProvider; private final ObjectMapper objectMapper; @@ -139,7 +142,56 @@ public class McpAsyncServer { this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + mcpTransportProvider.setSessionFactory(transport -> new McpServerSession(UUID.randomUUID().toString(), + requestTimeout, transport, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + private Map prepareNotificationHandlers(McpServerFeatures.Async features) { + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + return notificationHandlers; + } + + private Map> prepareRequestHandlers() { + Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods @@ -174,25 +226,7 @@ public class McpAsyncServer { if (this.serverCapabilities.completions() != null) { requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + return requestHandlers; } // --------------------------------------- @@ -258,7 +292,7 @@ public void close() { this.mcpTransportProvider.close(); } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + private McpNotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return (exchange, params) -> exchange.listRoots() .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) @@ -450,7 +484,7 @@ public Mono notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } - private McpServerSession.RequestHandler toolsListRequestHandler() { + private McpRequestHandler toolsListRequestHandler() { return (exchange, params) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); @@ -458,7 +492,7 @@ private McpServerSession.RequestHandler toolsListRequ }; } - private McpServerSession.RequestHandler toolsCallRequestHandler() { + private McpRequestHandler toolsCallRequestHandler() { return (exchange, params) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { @@ -551,7 +585,7 @@ public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification); } - private McpServerSession.RequestHandler resourcesListRequestHandler() { + private McpRequestHandler resourcesListRequestHandler() { return (exchange, params) -> { var resourceList = this.resources.values() .stream() @@ -561,7 +595,7 @@ private McpServerSession.RequestHandler resources }; } - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + private McpRequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); @@ -585,7 +619,7 @@ private List getResourceTemplates() { return list; } - private McpServerSession.RequestHandler resourcesReadRequestHandler() { + private McpRequestHandler resourcesReadRequestHandler() { return (exchange, params) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { @@ -678,7 +712,7 @@ public Mono notifyPromptsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } - private McpServerSession.RequestHandler promptsListRequestHandler() { + private McpRequestHandler promptsListRequestHandler() { return (exchange, params) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, @@ -694,7 +728,7 @@ private McpServerSession.RequestHandler promptsList }; } - private McpServerSession.RequestHandler promptsGetRequestHandler() { + private McpRequestHandler promptsGetRequestHandler() { return (exchange, params) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { @@ -740,7 +774,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN loggingMessageNotification); } - private McpServerSession.RequestHandler setLoggerRequestHandler() { + private McpRequestHandler setLoggerRequestHandler() { return (exchange, params) -> { return Mono.defer(() -> { @@ -759,7 +793,7 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { }; } - private McpServerSession.RequestHandler completionCompleteRequestHandler() { + private McpRequestHandler completionCompleteRequestHandler() { return (exchange, params) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index c0923e10e..6ebbbe23e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -8,11 +8,15 @@ import java.util.Collections; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.DefaultMcpTransportContext; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSession; +import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -25,13 +29,15 @@ */ public class McpAsyncServerExchange { - private final McpServerSession session; + private final String sessionId; + + private final McpLoggableSession session; private final McpSchema.ClientCapabilities clientCapabilities; private final McpSchema.Implementation clientInfo; - private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; + private final McpTransportContext transportContext; private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { }; @@ -51,12 +57,39 @@ public class McpAsyncServerExchange { * @param clientCapabilities The client capabilities that define the supported * features and functionality. * @param clientInfo The client implementation information. + * @deprecated Use + * {@link #McpAsyncServerExchange(String, McpLoggableSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} */ - public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, + @Deprecated + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.sessionId = null; + if (!(session instanceof McpLoggableSession)) { + throw new IllegalArgumentException("Expecting session to be a McpLoggableSession instance"); + } + this.session = (McpLoggableSession) session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.transportContext = McpTransportContext.EMPTY; + } + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param transportContext context associated with the client as extracted from the + * transport + * @param clientInfo The client implementation information. + */ + public McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext) { + this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.transportContext = transportContext; } /** @@ -75,6 +108,14 @@ public McpSchema.Implementation getClientInfo() { return this.clientInfo; } + public McpTransportContext transportContext() { + return this.transportContext; + } + + public String sessionId() { + return this.sessionId; + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM @@ -170,7 +211,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } return Mono.defer(() -> { - if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + if (this.session.isNotificationForLevelAllowed(loggingMessageNotification.level())) { return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); } return Mono.empty(); @@ -205,11 +246,7 @@ public Mono ping() { */ void setMinLoggingLevel(LoggingLevel minLoggingLevel) { Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); - this.minLoggingLevel = minLoggingLevel; - } - - private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { - return loggingLevel.level() >= this.minLoggingLevel.level(); + this.session.setMinLoggingLevel(minLoggingLevel); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java new file mode 100644 index 000000000..609744637 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java @@ -0,0 +1,18 @@ +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Request handler for the initialization request. + */ +public interface McpInitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java new file mode 100644 index 000000000..6b1061c03 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated notifications. + */ +public interface McpNotificationHandler { + + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java new file mode 100644 index 000000000..c9d70ad04 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java @@ -0,0 +1,22 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ +public interface McpRequestHandler { + + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d4b8addf4..502dd8d9c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -21,6 +21,9 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; @@ -131,6 +134,8 @@ */ public interface McpServer { + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + /** * Starts building a synchronous MCP server that provides blocking operations. * Synchronous servers block the current Thread's execution upon each request before @@ -139,8 +144,8 @@ public interface McpServer { * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link SyncSpecification} for configuring the server. */ - static SyncSpecification sync(McpServerTransportProvider transportProvider) { - return new SyncSpecification(transportProvider); + static SingleSessionSyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SingleSessionSyncSpecification(transportProvider); } /** @@ -151,31 +156,113 @@ static SyncSpecification sync(McpServerTransportProvider transportProvider) { * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - static AsyncSpecification async(McpServerTransportProvider transportProvider) { - return new AsyncSpecification(transportProvider); + static AsyncSpecification async(McpServerTransportProvider transportProvider) { + return new SingleSessionAsyncSpecification(transportProvider); } /** - * Asynchronous server specification. + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StreamableSyncSpecification sync(McpStreamableServerTransportProvider transportProvider) { + return new StreamableSyncSpecification(transportProvider); + } + + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - class AsyncSpecification { + static AsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { + return new StreamableServerAsyncSpecification(transportProvider); + } + + static StatelessAsyncSpecification async(McpStatelessServerTransport transport) { + return new StatelessAsyncSpecification(transport); + } + + static StatelessSyncSpecification sync(McpStatelessServerTransport transport) { + return new StatelessSyncSpecification(transport); + } - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); + class SingleSessionAsyncSpecification extends AsyncSpecification { private final McpServerTransportProvider transportProvider; - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + private SingleSessionAsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + class StreamableServerAsyncSpecification extends AsyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider transportProvider) { + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + @Override + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + /** + * Asynchronous server specification. + */ + abstract class AsyncSpecification> { + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - private ObjectMapper objectMapper; + ObjectMapper objectMapper; - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - private McpSchema.ServerCapabilities serverCapabilities; + McpSchema.ServerCapabilities serverCapabilities; - private JsonSchemaValidator jsonSchemaValidator; + JsonSchemaValidator jsonSchemaValidator; - private String instructions; + String instructions; /** * The Model Context Protocol (MCP) allows servers to expose tools that can be @@ -184,7 +271,7 @@ class AsyncSpecification { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -193,9 +280,9 @@ class AsyncSpecification { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + final Map resources = new HashMap<>(); - private final List resourceTemplates = new ArrayList<>(); + final List resourceTemplates = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -204,18 +291,15 @@ class AsyncSpecification { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + final Map prompts = new HashMap<>(); - private final Map completions = new HashMap<>(); + final Map completions = new HashMap<>(); - private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofHours(10); // Default timeout - private AsyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - } + public abstract McpAsyncServer build(); /** * Sets the URI template manager factory to use for creating URI templates. This @@ -224,7 +308,7 @@ private AsyncSpecification(McpServerTransportProvider transportProvider) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; return this; @@ -239,7 +323,7 @@ public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory * @return This builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public AsyncSpecification requestTimeout(Duration requestTimeout) { + public AsyncSpecification requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; return this; @@ -254,7 +338,7 @@ public AsyncSpecification requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -270,7 +354,7 @@ public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public AsyncSpecification serverInfo(String name, String version) { + public AsyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -284,7 +368,7 @@ public AsyncSpecification serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public AsyncSpecification instructions(String instructions) { + public AsyncSpecification instructions(String instructions) { this.instructions = instructions; return this; } @@ -303,7 +387,7 @@ public AsyncSpecification instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; @@ -334,7 +418,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * calls that require a request object. */ @Deprecated - public AsyncSpecification tool(McpSchema.Tool tool, + public AsyncSpecification tool(McpSchema.Tool tool, BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -358,7 +442,7 @@ public AsyncSpecification tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public AsyncSpecification toolCall(McpSchema.Tool tool, + public AsyncSpecification toolCall(McpSchema.Tool tool, BiFunction> callHandler) { Assert.notNull(tool, "Tool must not be null"); @@ -381,7 +465,7 @@ public AsyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.AsyncToolSpecification...) */ - public AsyncSpecification tools(List toolSpecifications) { + public AsyncSpecification tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -408,7 +492,7 @@ public AsyncSpecification tools(List t * @return This builder instance for method chaining * @throws IllegalArgumentException if toolSpecifications is null */ - public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { @@ -434,7 +518,7 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpecification resources( + public AsyncSpecification resources( Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); @@ -450,7 +534,8 @@ public AsyncSpecification resources( * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpecification resources(List resourceSpecifications) { + public AsyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -475,7 +560,7 @@ public AsyncSpecification resources(List resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -500,7 +585,7 @@ public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public AsyncSpecification resourceTemplates(List resourceTemplates) { + public AsyncSpecification resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; @@ -514,7 +599,7 @@ public AsyncSpecification resourceTemplates(List resourceTempl * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); @@ -539,7 +624,7 @@ public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplate * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public AsyncSpecification prompts(Map prompts) { + public AsyncSpecification prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; @@ -553,7 +638,7 @@ public AsyncSpecification prompts(Map prompts) { + public AsyncSpecification prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -577,7 +662,7 @@ public AsyncSpecification prompts(List prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -592,7 +677,7 @@ public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... * @return This builder instance for method chaining * @throws IllegalArgumentException if completions is null */ - public AsyncSpecification completions(List completions) { + public AsyncSpecification completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -607,7 +692,7 @@ public AsyncSpecification completions(List completions(McpServerFeatures.AsyncCompletionSpecification... completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -625,7 +710,7 @@ public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecifica * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public AsyncSpecification rootsChangeHandler( + public AsyncSpecification rootsChangeHandler( BiFunction, Mono> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); @@ -641,7 +726,7 @@ public AsyncSpecification rootsChangeHandler( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandler(BiFunction) */ - public AsyncSpecification rootsChangeHandlers( + public AsyncSpecification rootsChangeHandlers( List, Mono>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); @@ -657,7 +742,7 @@ public AsyncSpecification rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public AsyncSpecification rootsChangeHandlers( + public AsyncSpecification rootsChangeHandlers( @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(Arrays.asList(handlers)); @@ -669,7 +754,7 @@ public AsyncSpecification rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + public AsyncSpecification objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; return this; @@ -683,26 +768,76 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { * @return This builder instance for method chaining * @throws IllegalArgumentException if jsonSchemaValidator is null */ - public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; return this; } + } + + class SingleSessionSyncSpecification extends SyncSpecification { + + private final McpServerTransportProvider transportProvider; + + private SingleSessionSyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + /** - * Builds an asynchronous MCP server that provides non-blocking operations. - * @return A new instance of {@link McpAsyncServer} configured with this builder's + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's * settings. */ - public McpAsyncServer build() { - var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + + return new McpSyncServer(asyncServer, this.immediateExecution); + } + + } + + class StreamableSyncSpecification extends SyncSpecification { + + private final McpStreamableServerTransportProvider transportProvider; + + private StreamableSyncSpecification(McpStreamableServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + @Override + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); + + return new McpSyncServer(asyncServer, this.immediateExecution); } } @@ -710,22 +845,17 @@ public McpAsyncServer build() { /** * Synchronous server specification. */ - class SyncSpecification { + abstract class SyncSpecification> { - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + ObjectMapper objectMapper; - private final McpServerTransportProvider transportProvider; - - private ObjectMapper objectMapper; + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + McpSchema.ServerCapabilities serverCapabilities; - private McpSchema.ServerCapabilities serverCapabilities; - - private String instructions; + String instructions; /** * The Model Context Protocol (MCP) allows servers to expose tools that can be @@ -734,7 +864,7 @@ class SyncSpecification { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -743,11 +873,11 @@ class SyncSpecification { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + final Map resources = new HashMap<>(); - private final List resourceTemplates = new ArrayList<>(); + final List resourceTemplates = new ArrayList<>(); - private JsonSchemaValidator jsonSchemaValidator; + JsonSchemaValidator jsonSchemaValidator; /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -756,20 +886,17 @@ class SyncSpecification { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + final Map prompts = new HashMap<>(); - private final Map completions = new HashMap<>(); + final Map completions = new HashMap<>(); - private final List>> rootsChangeHandlers = new ArrayList<>(); + final List>> rootsChangeHandlers = new ArrayList<>(); - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout - private boolean immediateExecution = false; + boolean immediateExecution = false; - private SyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; - } + public abstract McpSyncServer build(); /** * Sets the URI template manager factory to use for creating URI templates. This @@ -778,7 +905,7 @@ private SyncSpecification(McpServerTransportProvider transportProvider) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; return this; @@ -790,10 +917,10 @@ public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory * resource access, and prompt operations. * @param requestTimeout The duration to wait before timing out requests. Must not * be null. - * @return This builder instance for method chaining + * @return this builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public SyncSpecification requestTimeout(Duration requestTimeout) { + public SyncSpecification requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; return this; @@ -808,7 +935,7 @@ public SyncSpecification requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; return this; @@ -824,7 +951,7 @@ public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public SyncSpecification serverInfo(String name, String version) { + public SyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); @@ -838,7 +965,7 @@ public SyncSpecification serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public SyncSpecification instructions(String instructions) { + public SyncSpecification instructions(String instructions) { this.instructions = instructions; return this; } @@ -857,7 +984,7 @@ public SyncSpecification instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; return this; @@ -887,7 +1014,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * calls that require a request object. */ @Deprecated - public SyncSpecification tool(McpSchema.Tool tool, + public SyncSpecification tool(McpSchema.Tool tool, BiFunction, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -911,7 +1038,7 @@ public SyncSpecification tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public SyncSpecification toolCall(McpSchema.Tool tool, + public SyncSpecification toolCall(McpSchema.Tool tool, BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -932,7 +1059,7 @@ public SyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.SyncToolSpecification...) */ - public SyncSpecification tools(List toolSpecifications) { + public SyncSpecification tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -961,7 +1088,7 @@ public SyncSpecification tools(List too * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(List) */ - public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { @@ -987,7 +1114,7 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpecification resources( + public SyncSpecification resources( Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); @@ -1003,7 +1130,8 @@ public SyncSpecification resources( * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpecification resources(List resourceSpecifications) { + public SyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -1028,7 +1156,7 @@ public SyncSpecification resources(List resources(McpServerFeatures.SyncResourceSpecification... resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -1053,7 +1181,7 @@ public SyncSpecification resources(McpServerFeatures.SyncResourceSpecification.. * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public SyncSpecification resourceTemplates(List resourceTemplates) { + public SyncSpecification resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); return this; @@ -1067,7 +1195,7 @@ public SyncSpecification resourceTemplates(List resourceTempla * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ - public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); @@ -1093,7 +1221,7 @@ public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpecification prompts(Map prompts) { + public SyncSpecification prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; @@ -1107,7 +1235,7 @@ public SyncSpecification prompts(Map prompts) { + public SyncSpecification prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -1131,7 +1259,7 @@ public SyncSpecification prompts(List * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); @@ -1147,7 +1275,7 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr * @throws IllegalArgumentException if completions is null * @see #completions(McpServerFeatures.SyncCompletionSpecification...) */ - public SyncSpecification completions(List completions) { + public SyncSpecification completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.SyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -1162,7 +1290,7 @@ public SyncSpecification completions(List completions(McpServerFeatures.SyncCompletionSpecification... completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.SyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -1180,7 +1308,8 @@ public SyncSpecification completions(McpServerFeatures.SyncCompletionSpecificati * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public SyncSpecification rootsChangeHandler(BiConsumer> handler) { + public SyncSpecification rootsChangeHandler( + BiConsumer> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); return this; @@ -1195,7 +1324,7 @@ public SyncSpecification rootsChangeHandler(BiConsumer rootsChangeHandlers( List>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); @@ -1211,7 +1340,7 @@ public SyncSpecification rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public SyncSpecification rootsChangeHandlers( + public SyncSpecification rootsChangeHandlers( BiConsumer>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(List.of(handlers)); @@ -1223,13 +1352,13 @@ public SyncSpecification rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public SyncSpecification objectMapper(ObjectMapper objectMapper) { + public SyncSpecification objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; return this; } - public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; return this; @@ -1246,30 +1375,966 @@ public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValid * @return This builder instance for method chaining. * */ - public SyncSpecification immediateExecution(boolean immediateExecution) { + public SyncSpecification immediateExecution(boolean immediateExecution) { this.immediateExecution = immediateExecution; return this; } + } + + class StatelessAsyncSpecification { + + private final McpStatelessServerTransport transport; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + ObjectMapper objectMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + /** - * Builds a synchronous MCP server that provides blocking operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings. + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. */ - public McpSyncServer build() { - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, - this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); + final List tools = new ArrayList<>(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); - return new McpSyncServer(asyncServer, this.immediateExecution); + final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessAsyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessAsyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessAsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessAsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessAsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessAsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.AsyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.AsyncToolSpecification...) + */ + public StatelessAsyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessAsyncSpecification tools( + McpStatelessServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessAsyncSpecification resources( + McpStatelessServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

+ * Example usage:

{@code
+		 * .resourceTemplates(
+		 *     new ResourceTemplate("file://{path}", "Access files by path"),
+		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
+		 * )
+		 * }
+ * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public StatelessAsyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessAsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.AsyncPromptSpecification...) + */ + public StatelessAsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts(McpStatelessServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions( + McpStatelessServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public StatelessAsyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + public McpStatelessAsyncServer build() { + var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpStatelessAsyncServer(this.transport, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + + } + + class StatelessSyncSpecification { + + private final McpStatelessServerTransport transport; + + boolean immediateExecution = false; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + ObjectMapper objectMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessSyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessSyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessSyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessSyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessSyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessSyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessSyncSpecification toolCall(McpSchema.Tool tool, + BiFunction callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.SyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.SyncToolSpecification...) + */ + public StatelessSyncSpecification tools( + List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessSyncSpecification tools( + McpStatelessServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.SyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessSyncSpecification resources( + McpStatelessServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

+ * Example usage:

{@code
+		 * .resourceTemplates(
+		 *     new ResourceTemplate("file://{path}", "Access files by path"),
+		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
+		 * )
+		 * }
+ * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public StatelessSyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessSyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.SyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts( + Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.SyncPromptSpecification...) + */ + public StatelessSyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.SyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts(McpStatelessServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions( + McpStatelessServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public StatelessSyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessSyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpStatelessAsyncServer}. Defaults to false, which does blocking code + * offloading to prevent accidental blocking of the non-blocking transport. + *

+ * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public StatelessSyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + public McpStatelessSyncServer build() { + /* + * McpServerFeatures.Sync syncFeatures = new + * McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + * this.tools, this.resources, this.resourceTemplates, this.prompts, + * this.completions, this.rootsChangeHandlers, this.instructions); + * McpServerFeatures.Async asyncFeatures = + * McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + * var mapper = this.objectMapper != null ? this.objectMapper : new + * ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null + * ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); + * + * var asyncServer = new McpAsyncServer(this.transportProvider, mapper, + * asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, + * jsonSchemaValidator); + * + * return new McpSyncServer(asyncServer, this.immediateExecution); + */ + var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java new file mode 100644 index 000000000..fa768d7f8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -0,0 +1,571 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpTransportContext; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +/** + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessAsyncServer.class); + + private final McpStatelessServerTransport mcpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + private final JsonSchemaValidator jsonSchemaValidator; + + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, ObjectMapper objectMapper, + McpStatelessServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransport; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, (ctx, params) -> Mono.just(Map.of())); + + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + mcpTransport.setRequestHandler((context, request) -> requestHandlers.get(request.method()) + .apply(context, request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(t -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, t.getMessage(), + null))))); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private RequestHandler asyncInitializeRequestHandler() { + return (ctx, req) -> Mono.defer(() -> { + McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, + McpSchema.InitializeRequest.class); + + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.callHandler() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + private RequestHandler toolsListRequestHandler() { + return (ctx, params) -> { + List tools = this.tools.stream() + .map(McpStatelessServerFeatures.AsyncToolSpecification::tool) + .toList(); + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private RequestHandler toolsCallRequestHandler() { + return (ctx, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.callHandler().apply(ctx, callToolRequest)) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + private RequestHandler resourcesListRequestHandler() { + return (ctx, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private RequestHandler resourceTemplateListRequestHandler() { + return (ctx, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new ResourceTemplate(resource.uri(), resource.name(), resource.title(), + resource.description(), resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; + } + + private RequestHandler resourcesReadRequestHandler() { + return (ctx, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + + McpStatelessServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(ctx, resourceRequest); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error( + new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpStatelessServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + private RequestHandler promptsListRequestHandler() { + return (ctx, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private RequestHandler promptsGetRequestHandler() { + return (ctx, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(ctx, promptRequest); + }; + } + + private RequestHandler completionCompleteRequestHandler() { + return (ctx, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); + } + + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts + .get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + } + if (promptSpec.prompt().arguments().stream().noneMatch(arg -> arg.name().equals(argumentName))) { + + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } + + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources + .get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); + } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + + } + + McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } + + return specification.completionHandler().apply(ctx, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + @FunctionalInterface + interface RequestHandler extends BiFunction> { + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java new file mode 100644 index 000000000..c64ec7cf7 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -0,0 +1,380 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * MCP stateless server features specification that a particular server can choose to + * support. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); + this.instructions = instructions; + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec, boolean immediateExecution) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); + }); + + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); + }); + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, + syncSpec.resourceTemplates(), prompts, completions, syncSpec.instructions()); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); + this.instructions = instructions; + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning the result. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction> callHandler) { + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { + + // FIXME: This is temporary, proper validation should be implemented + if (syncToolSpec == null) { + return null; + } + + BiFunction> callHandler = (ctx, + req) -> { + var toolResult = Mono.fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *

    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), (ctx, req) -> { + var promptResult = Mono.fromCallable(() -> prompt.promptHandler().apply(ctx, req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *
    + *
  • Customizable response generation logic + *
  • Parameter-driven template expansion + *
  • Dynamic interaction with connected clients + *
+ * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The function's argument is a + * {@link McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), (ctx, req) -> { + var completionResult = Mono.fromCallable(() -> completion.completionHandler().apply(ctx, req)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning results. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction callHandler) { + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The argument is a {@link McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java new file mode 100644 index 000000000..2f9715776 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -0,0 +1,156 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpTransportContext; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +/** + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessSyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessSyncServer.class); + + private final McpStatelessAsyncServer asyncServer; + + private final boolean immediateExecution; + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer) { + this(asyncServer, false); + } + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer, boolean immediateExecution) { + this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.asyncServer.getServerCapabilities(); + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.asyncServer.getServerInfo(); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.asyncServer.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.asyncServer.close(); + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecification) { + this.asyncServer + .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, + this.immediateExecution)) + .block(); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public void removeTool(String toolName) { + this.asyncServer.removeTool(toolName).block(); + } + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public void addResource(McpStatelessServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public void removeResource(String resourceUri) { + this.asyncServer.removeResource(resourceUri).block(); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer + .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, + this.immediateExecution)) + .block(); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public void removePrompt(String promptName) { + this.asyncServer.removePrompt(promptName).block(); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.asyncServer.setProtocolVersions(protocolVersions); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index dad1e4c19..d5fc317fe 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -6,6 +6,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpTransportContext; /** * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The @@ -27,6 +28,10 @@ public McpSyncServerExchange(McpAsyncServerExchange exchange) { this.exchange = exchange; } + public String sessionId() { + return this.exchange.sessionId(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities @@ -43,6 +48,10 @@ public McpSchema.Implementation getClientInfo() { return this.exchange.getClientInfo(); } + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java new file mode 100644 index 000000000..b58be7863 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -0,0 +1,40 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; + +public class DefaultMcpStreamableServerSessionFactory implements McpStreamableServerSession.Factory { + + Duration requestTimeout; + + McpStreamableServerSession.InitRequestHandler initRequestHandler; + + Map> requestHandlers; + + Map notificationHandlers; + + public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, + McpStreamableServerSession.InitRequestHandler initRequestHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initRequestHandler = initRequestHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession( + McpSchema.InitializeRequest initializeRequest) { + return new McpStreamableServerSession.McpStreamableServerSessionInit( + new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), + initializeRequest.clientInfo(), requestTimeout, requestHandlers, notificationHandlers), + this.initRequestHandler.handle(initializeRequest)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java new file mode 100644 index 000000000..ba5f3ed29 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java @@ -0,0 +1,32 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class DefaultMcpTransportContext implements McpTransportContext { + + private final Map storage; + + public DefaultMcpTransportContext() { + this.storage = new ConcurrentHashMap<>(); + } + + DefaultMcpTransportContext(Map storage) { + this.storage = storage; + } + + @Override + public Object get(String key) { + return this.storage.get(key); + } + + @Override + public void put(String key, Object value) { + this.storage.put(key, value); + } + + public McpTransportContext copy() { + return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java new file mode 100644 index 000000000..ac6d01d91 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java @@ -0,0 +1,14 @@ +package io.modelcontextprotocol.spec; + +public interface McpLoggableSession extends McpSession { + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel); + + boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e9c23db6a..a3812dbc2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -44,7 +44,7 @@ public final class McpSchema { private McpSchema() { } - public static final String LATEST_PROTOCOL_VERSION = "2024-11-05"; + public static final String LATEST_PROTOCOL_VERSION = "2025-03-26"; public static final String JSONRPC_VERSION = "2.0"; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..fc04bd6b2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -9,6 +9,10 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -19,7 +23,7 @@ * Represents a Model Control Protocol (MCP) session on the server side. It manages * bidirectional JSON-RPC communication with the client. */ -public class McpServerSession implements McpSession { +public class McpServerSession implements McpLoggableSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); @@ -32,13 +36,11 @@ public class McpServerSession implements McpSession { private final AtomicLong requestCounter = new AtomicLong(0); - private final InitRequestHandler initRequestHandler; + private final McpInitRequestHandler initRequestHandler; - private final InitNotificationHandler initNotificationHandler; + private final Map> requestHandlers; - private final Map> requestHandlers; - - private final Map notificationHandlers; + private final Map notificationHandlers; private final McpServerTransport transport; @@ -56,6 +58,29 @@ public class McpServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + /** + * Creates a new server session with the given parameters and the transport to use. + * @param id session id + * @param transport the transport to use + * @param initHandler called when a + * {@link io.modelcontextprotocol.spec.McpSchema.InitializeRequest} is received by the + * server + * @param requestHandlers map of request handlers to use + * @param notificationHandlers map of notification handlers to use + */ + public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, + McpInitRequestHandler initHandler, Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.requestTimeout = requestTimeout; + this.transport = transport; + this.initRequestHandler = initHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id @@ -68,15 +93,18 @@ public class McpServerSession implements McpSession { * received. * @param requestHandlers map of request handlers to use * @param notificationHandlers map of notification handlers to use + * @deprecated Use + * {@link #McpServerSession(String, Duration, McpServerTransport, McpInitRequestHandler, Map, Map)} */ + @Deprecated public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, + Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; this.initRequestHandler = initHandler; - this.initNotificationHandler = initNotificationHandler; this.requestHandlers = requestHandlers; this.notificationHandlers = notificationHandlers; } @@ -108,6 +136,17 @@ private String generateRequestId() { return this.id + "-" + this.requestCounter.getAndIncrement(); } + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); @@ -242,8 +281,10 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); - return this.initNotificationHandler.handle(); + // FIXME: The session ID passed here is not the same as the one in the + // legacy SSE transport. + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), + clientInfo.get(), McpTransportContext.EMPTY)); } var handler = notificationHandlers.get(notification.method()); @@ -264,17 +305,22 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { + // TODO: clear pendingResponses and emit errors? return this.transport.closeGracefully(); } @Override public void close() { + // TODO: clear pendingResponses and emit errors? this.transport.close(); } /** * Request handler for the initialization request. + * + * @deprecated Use {@link McpInitRequestHandler} */ + @Deprecated public interface InitRequestHandler { /** @@ -301,7 +347,10 @@ public interface InitNotificationHandler { /** * A handler for client-initiated notifications. + * + * @deprecated Use {@link McpNotificationHandler} */ + @Deprecated public interface NotificationHandler { /** @@ -320,7 +369,9 @@ public interface NotificationHandler { * * @param the type of the response that is expected as a result of handling the * request. + * @deprecated Use {@link McpRequestHandler} */ + @Deprecated public interface RequestHandler { /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 5fdbd7ab6..c04a4283d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -1,35 +1,6 @@ package io.modelcontextprotocol.spec; -import java.util.Map; - -import reactor.core.publisher.Mono; - -/** - * The core building block providing the server-side MCP transport. Implement this - * interface to bridge between a particular server-side technology and the MCP server - * transport layer. - * - *

- * The lifecycle of the provider dictates that it be created first, upon application - * startup, and then passed into either - * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or - * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As - * a result of the MCP server creation, the provider will be notified of a - * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication - * between a newly connected client and the server. The provider's responsibility is to - * create instances of {@link McpServerTransport} that the session will utilise during the - * session lifetime. - * - *

- * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or - * {@link #closeGracefully()} are called as part of the normal application shutdown event. - * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where - * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} - * closes the provided transport. - * - * @author Dariusz Jędrzejczyk - */ -public interface McpServerTransportProvider { +public interface McpServerTransportProvider extends McpServerTransportProviderBase { /** * Sets the session factory that will be used to create sessions for new clients. An @@ -39,28 +10,4 @@ public interface McpServerTransportProvider { */ void setSessionFactory(McpServerSession.Factory sessionFactory); - /** - * Sends a notification to all connected clients. - * @param method the name of the notification method to be called on the clients - * @param params parameters to be sent with the notification - * @return a Mono that completes when the notification has been broadcast - * @see McpSession#sendNotification(String, Map) - */ - Mono notifyClients(String method, Object params); - - /** - * Immediately closes all the transports with connected clients and releases any - * associated resources. - */ - default void close() { - this.closeGracefully().subscribe(); - } - - /** - * Gracefully closes all the transports with connected clients and releases any - * associated resources asynchronously. - * @return a {@link Mono} that completes when the connections have been closed. - */ - Mono closeGracefully(); - } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java new file mode 100644 index 000000000..87e7d6441 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java @@ -0,0 +1,58 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

+ * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProviderBase { + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 42d170db5..7b29ca651 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.spec; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java new file mode 100644 index 000000000..513d551cc --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -0,0 +1,27 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +import java.util.function.BiFunction; + +public interface McpStatelessServerTransport { + + void setRequestHandler( + BiFunction> message); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java new file mode 100644 index 000000000..47ee8c2c4 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -0,0 +1,337 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +public class McpStreamableServerSession implements McpLoggableSession { + + private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class); + + private final ConcurrentHashMap requestIdToStream = new ConcurrentHashMap<>(); + + private final String id; + + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + private final AtomicReference listeningStreamRef; + + private final MissingMcpTransportSession missingMcpTransportSession; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, Duration requestTimeout, + Map> requestHandlers, + Map notificationHandlers) { + this.id = id; + this.missingMcpTransportSession = new MissingMcpTransportSession(id); + this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession); + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + this.requestTimeout = requestTimeout; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + + public String getId() { + return this.id; + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendRequest(method, requestParams, typeRef); + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendNotification(method, params); + }); + } + + public Mono delete() { + return this.closeGracefully().then(Mono.fromRunnable(() -> { + // delete history, etc. + })); + } + + public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); + this.listeningStreamRef.set(listeningStream); + return listeningStream; + } + + // TODO: keep track of history by keeping a map from eventId to stream and then + // iterate over the events using the lastEventId + public Flux replay(Object lastEventId) { + return Flux.empty(); + } + + public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers + .get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close + // remove itself from the registry and also close the underlying transport + // (sink) + if (requestHandler == null) { + MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); + return transport + .sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + return requestHandler + .handle(new McpAsyncServerExchange(this.id, stream, clientCapabilities.get(), clientInfo.get(), + transportContext), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, + null)) + .onErrorResume(e -> { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), + null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, + e.getMessage(), null)); + return Mono.just(errorResponse); + }) + .flatMap(transport::sendMessage) + .then(transport.closeGracefully()); + }); + } + + public Mono accept(McpSchema.JSONRPCNotification notification) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, + this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); + }); + + } + + public Mono accept(McpSchema.JSONRPCResponse response) { + return Mono.defer(() -> { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO + // JSONize + } + // TODO: encapsulate this inside the stream itself + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO + // JSONize + } + else { + sink.success(response); + } + return Mono.empty(); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + return listeningStream.closeGracefully(); + // TODO: Also close all the open streams + }); + } + + @Override + public void close() { + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + if (listeningStream != null) { + listeningStream.close(); + } + // TODO: Also close all open streams + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + public interface Factory { + + McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + + } + + public record McpStreamableServerSessionInit(McpStreamableServerSession session, + Mono initResult) { + } + + public final class McpStreamableServerSessionStream implements McpLoggableSession { + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final McpStreamableServerTransport transport; + + private final String transportId; + + private final Supplier uuidGenerator; + + public McpStreamableServerSessionStream(McpStreamableServerTransport transport) { + this.transport = transport; + this.transportId = UUID.randomUUID().toString(); + // This ID design allows for a constant-time extraction of the history by + // precisely identifying the SSE stream using the first component + this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID(); + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel); + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = McpStreamableServerSession.this.generateRequestId(); + + McpStreamableServerSession.this.requestIdToStream.put(requestId, this); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + method, requestId, requestParams); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(v -> { + }, sink::error); + }).timeout(requestTimeout).doOnError(e -> { + this.pendingResponses.remove(requestId); + McpStreamableServerSession.this.requestIdToStream.remove(requestId); + }).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, method, params); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + return this.transport.sendMessage(jsonrpcNotification, messageId); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + return this.transport.closeGracefully(); + }); + } + + @Override + public void close() { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + this.transport.close(); + } + + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java new file mode 100644 index 000000000..e49ba9c84 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransport extends McpServerTransport { + + Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java new file mode 100644 index 000000000..48b9cd75e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -0,0 +1,67 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} + * or + * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. + * As a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

+ * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransportProvider extends McpServerTransportProviderBase { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpStreamableServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java new file mode 100644 index 000000000..bfffeccd6 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.spec; + +public interface McpTransportContext { + + String KEY = "MCP_TRANSPORT_CONTEXT"; + + McpTransportContext EMPTY = new DefaultMcpTransportContext(); + + Object get(String key); + + void put(String key, Object value); + + McpTransportContext copy(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java new file mode 100644 index 000000000..f41c8768e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -0,0 +1,47 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +public class MissingMcpTransportSession implements McpLoggableSession { + + private final String sessionId; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + public MissingMcpTransportSession(String sessionId) { + this.sessionId = sessionId; + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + } + + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java index a101f0177..e9356d0c0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpAsyncClientTests.java @@ -37,7 +37,7 @@ protected McpClientTransport createMcpTransport() { } protected Duration getInitializationTimeout() { - return Duration.ofSeconds(10); + return Duration.ofSeconds(20); } } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index b5841e755..7a1e90770 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -21,6 +21,8 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -43,7 +45,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -64,28 +66,29 @@ void tearDown() { // Server Lifecycle Tests // --------------------------------------- - @Test - void testConstructorWithInvalidArguments() { + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "sse", "streamable" }) + void testConstructorWithInvalidArguments(String serverType) { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -105,8 +108,7 @@ void testImmediateClose() { @Deprecated void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -120,8 +122,7 @@ void testAddTool() { @Test void testAddToolCall() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -138,8 +139,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -159,8 +159,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -181,8 +180,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! @@ -204,8 +202,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -216,8 +213,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) @@ -236,8 +232,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -249,8 +244,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -265,8 +259,7 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -282,7 +275,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -291,7 +284,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier .create(mcpAsyncServer @@ -303,8 +296,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -320,8 +312,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -336,9 +327,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -354,9 +343,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -370,7 +357,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -379,8 +366,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -393,9 +379,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( @@ -411,9 +395,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -430,8 +412,7 @@ void testRemovePrompt() { prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -443,8 +424,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -467,8 +447,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -487,8 +466,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -501,8 +479,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -514,9 +491,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java index 208d2e749..1e2c94fe5 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpSyncServerTests.java @@ -41,7 +41,7 @@ public abstract class AbstractMcpSyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); protected void onStart() { } @@ -69,28 +69,28 @@ void testConstructorWithInvalidArguments() { .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()).serverInfo(null)) + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.closeGracefully()).doesNotThrowAnyException(); } @Test void testImmediateClose() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.close()).doesNotThrowAnyException(); } @Test void testGetAsyncServer() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(mcpSyncServer.getAsyncServer()).isNotNull(); @@ -112,8 +112,7 @@ void testGetAsyncServer() { @Test @Deprecated void testAddTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -127,8 +126,7 @@ void testAddTool() { @Test void testAddToolCall() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -146,8 +144,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -164,8 +161,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .build(); @@ -184,8 +180,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) .toolCall(duplicateTool, (exchange, request) -> new CallToolResult(List.of(), false)) // Duplicate! @@ -207,8 +202,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -219,8 +213,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.SyncToolSpecification.builder() .tool(duplicateTool) @@ -239,8 +232,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool tool = new McpSchema.Tool(TEST_TOOL_NAME, "Test tool", emptyJsonSchema); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(tool, (exchange, args) -> new CallToolResult(List.of(), false)) .build(); @@ -252,8 +244,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -265,7 +256,7 @@ void testRemoveNonexistentTool() { @Test void testNotifyToolsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyToolsListChanged()).doesNotThrowAnyException(); @@ -278,7 +269,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyResourcesListChanged()).doesNotThrowAnyException(); @@ -287,7 +278,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer .notifyResourcesUpdated(new McpSchema.ResourcesUpdatedNotification(TEST_RESOURCE_URI))) @@ -298,8 +289,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -315,8 +305,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -329,9 +318,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -344,9 +331,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { - var serverWithoutResources = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutResources = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutResources.removeResource(TEST_RESOURCE_URI)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with resource capabilities"); @@ -358,7 +343,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpSyncServer.notifyPromptsListChanged()).doesNotThrowAnyException(); @@ -367,8 +352,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -379,9 +363,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.SyncPromptSpecification specification = new McpServerFeatures.SyncPromptSpecification(prompt, @@ -394,9 +376,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { - var serverWithoutPrompts = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var serverWithoutPrompts = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatThrownBy(() -> serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).isInstanceOf(McpError.class) .hasMessage("Server must be configured with prompt capabilities"); @@ -409,8 +389,7 @@ void testRemovePrompt() { (exchange, req) -> new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content"))))); - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -422,8 +401,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpSyncServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpSyncServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -443,8 +421,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -462,8 +439,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -475,8 +451,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.sync(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -487,7 +462,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.sync(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var noConsumersServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully()).doesNotThrowAnyException(); diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java index 39066a9a2..e1ca584ee 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpAsyncServerExchangeTests.java @@ -10,6 +10,7 @@ import java.util.Map; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.DefaultMcpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; @@ -54,7 +55,8 @@ void setUp() { clientInfo = new McpSchema.Implementation("test-client", "1.0.0"); - exchange = new McpAsyncServerExchange(mockSession, clientCapabilities, clientInfo); + exchange = new McpAsyncServerExchange("testSessionId", mockSession, clientCapabilities, clientInfo, + new DefaultMcpTransportContext()); } @Test @@ -219,27 +221,33 @@ void testLoggingNotificationWithNullMessage() { } @Test - void testLoggingNotificationWithAllowedLevel() { + void testSetMinLoggingLevelWithNullValue() { + assertThatThrownBy(() -> exchange.setMinLoggingLevel(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("minLoggingLevel must not be null"); + } + @Test + void testLoggingNotificationWithAllowedLevel() { McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.ERROR) .logger("test-logger") .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); - // Verify that sendNotification was called exactly once + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); } @Test void testLoggingNotificationWithFilteredLevel() { - // Given - Set minimum level to WARNING, send DEBUG message - exchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); + exchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(eq(McpSchema.LoggingLevel.DEBUG)); McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.DEBUG) @@ -247,104 +255,38 @@ void testLoggingNotificationWithFilteredLevel() { .data("Debug message that should be filtered") .build(); - // When & Then - Should complete without sending notification - StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); - - // Verify that sendNotification was never called for filtered DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - } - - @Test - void testLoggingNotificationLevelFiltering() { - // Given - Set minimum level to WARNING - exchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - // Test DEBUG (should be filtered) - McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build(); + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); StepVerifier.create(exchange.loggingNotification(debugNotification)).verifyComplete(); - // Verify that sendNotification was never called for DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - - // Test INFO (should be filtered) - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - StepVerifier.create(exchange.loggingNotification(infoNotification)).verifyComplete(); - - // Verify that sendNotification was never called for INFO level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); - - reset(mockSession); + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG)); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); - // Test WARNING (should be sent) McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.WARNING) .logger("test-logger") - .data("Warning message") + .data("Debug message that should be filtered") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification))) - .thenReturn(Mono.empty()); - StepVerifier.create(exchange.loggingNotification(warningNotification)).verifyComplete(); - // Verify that sendNotification was called exactly once for WARNING level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.WARNING)); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification)); - - // Test ERROR (should be sent) - McpSchema.LoggingMessageNotification errorNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification))) - .thenReturn(Mono.empty()); - - StepVerifier.create(exchange.loggingNotification(errorNotification)).verifyComplete(); - - // Verify that sendNotification was called exactly once for ERROR level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(errorNotification)); - } - - @Test - void testLoggingNotificationWithDefaultLevel() { - - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification))) - .thenReturn(Mono.empty()); - - StepVerifier.create(exchange.loggingNotification(infoNotification)).verifyComplete(); - - // Verify that sendNotification was called exactly once for default level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); } @Test void testLoggingNotificationWithSessionError() { - McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.ERROR) .logger("test-logger") .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.error(new RuntimeException("Session error"))); @@ -353,44 +295,6 @@ void testLoggingNotificationWithSessionError() { }); } - @Test - void testSetMinLoggingLevelWithNullValue() { - // When & Then - assertThatThrownBy(() -> exchange.setMinLoggingLevel(null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("minLoggingLevel must not be null"); - } - - @Test - void testLoggingLevelHierarchy() { - // Test all logging levels to ensure proper hierarchy - McpSchema.LoggingLevel[] levels = { McpSchema.LoggingLevel.DEBUG, McpSchema.LoggingLevel.INFO, - McpSchema.LoggingLevel.NOTICE, McpSchema.LoggingLevel.WARNING, McpSchema.LoggingLevel.ERROR, - McpSchema.LoggingLevel.CRITICAL, McpSchema.LoggingLevel.ALERT, McpSchema.LoggingLevel.EMERGENCY }; - - // Set minimum level to WARNING - exchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - for (McpSchema.LoggingLevel level : levels) { - McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message for " + level) - .build(); - - if (level.level() >= McpSchema.LoggingLevel.WARNING.level()) { - // Should be sent - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) - .thenReturn(Mono.empty()); - - StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); - } - else { - // Should be filtered (completes without sending) - StepVerifier.create(exchange.loggingNotification(notification)).verifyComplete(); - } - } - } - // --------------------------------------- // Create Elicitation Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java index 66d7695e8..63d827013 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/McpSyncServerExchangeTests.java @@ -226,19 +226,21 @@ void testLoggingNotificationWithAllowedLevel() { .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.empty()); exchange.loggingNotification(notification); // Verify that sendNotification was called exactly once + verify(mockSession, times(1)).isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.ERROR)); verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification)); } @Test void testLoggingNotificationWithFilteredLevel() { - // Given - Set minimum level to WARNING, send DEBUG message - asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); + asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).setMinLoggingLevel(McpSchema.LoggingLevel.DEBUG); McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.DEBUG) @@ -246,93 +248,27 @@ void testLoggingNotificationWithFilteredLevel() { .data("Debug message that should be filtered") .build(); - // When & Then - Should complete without sending notification - exchange.loggingNotification(debugNotification); - - // Verify that sendNotification was never called for filtered DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - } - - @Test - void testLoggingNotificationLevelFiltering() { - // Given - Set minimum level to WARNING - asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - // Test DEBUG (should be filtered) - McpSchema.LoggingMessageNotification debugNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build(); + when(mockSession.isNotificationForLevelAllowed(eq(McpSchema.LoggingLevel.DEBUG))).thenReturn(Boolean.TRUE); + when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification))) + .thenReturn(Mono.empty()); exchange.loggingNotification(debugNotification); - // Verify that sendNotification was never called for DEBUG level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(debugNotification)); - - // Test INFO (should be filtered) - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - exchange.loggingNotification(infoNotification); - - // Verify that sendNotification was never called for INFO level - verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); - - reset(mockSession); + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.DEBUG); + verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + eq(debugNotification)); - // Test WARNING (should be sent) McpSchema.LoggingMessageNotification warningNotification = McpSchema.LoggingMessageNotification.builder() .level(McpSchema.LoggingLevel.WARNING) .logger("test-logger") - .data("Warning message") + .data("Debug message that should be filtered") .build(); - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification))) - .thenReturn(Mono.empty()); - exchange.loggingNotification(warningNotification); - // Verify that sendNotification was called exactly once for WARNING level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), + verify(mockSession, times(1)).isNotificationForLevelAllowed(McpSchema.LoggingLevel.WARNING); + verify(mockSession, never()).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(warningNotification)); - - // Test ERROR (should be sent) - McpSchema.LoggingMessageNotification errorNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(errorNotification))) - .thenReturn(Mono.empty()); - - exchange.loggingNotification(errorNotification); - - // Verify that sendNotification was called exactly once for ERROR level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), - eq(errorNotification)); - } - - @Test - void testLoggingNotificationWithDefaultLevel() { - - McpSchema.LoggingMessageNotification infoNotification = McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Info message") - .build(); - - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification))) - .thenReturn(Mono.empty()); - - exchange.loggingNotification(infoNotification); - - // Verify that sendNotification was called exactly once for default level - verify(mockSession, times(1)).sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(infoNotification)); } @Test @@ -344,6 +280,7 @@ void testLoggingNotificationWithSessionError() { .data("Test error message") .build(); + when(mockSession.isNotificationForLevelAllowed(any())).thenReturn(Boolean.TRUE); when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) .thenReturn(Mono.error(new RuntimeException("Session error"))); @@ -351,37 +288,6 @@ void testLoggingNotificationWithSessionError() { .hasMessage("Session error"); } - @Test - void testLoggingLevelHierarchy() { - // Test all logging levels to ensure proper hierarchy - McpSchema.LoggingLevel[] levels = { McpSchema.LoggingLevel.DEBUG, McpSchema.LoggingLevel.INFO, - McpSchema.LoggingLevel.NOTICE, McpSchema.LoggingLevel.WARNING, McpSchema.LoggingLevel.ERROR, - McpSchema.LoggingLevel.CRITICAL, McpSchema.LoggingLevel.ALERT, McpSchema.LoggingLevel.EMERGENCY }; - - // Set minimum level to WARNING - asyncExchange.setMinLoggingLevel(McpSchema.LoggingLevel.WARNING); - - for (McpSchema.LoggingLevel level : levels) { - McpSchema.LoggingMessageNotification notification = McpSchema.LoggingMessageNotification.builder() - .level(level) - .logger("test-logger") - .data("Test message for " + level) - .build(); - - if (level.level() >= McpSchema.LoggingLevel.WARNING.level()) { - // Should be sent - when(mockSession.sendNotification(eq(McpSchema.METHOD_NOTIFICATION_MESSAGE), eq(notification))) - .thenReturn(Mono.empty()); - - exchange.loggingNotification(notification); - } - else { - // Should be filtered (completes without sending) - exchange.loggingNotification(notification); - } - } - } - // --------------------------------------- // Create Elicitation Tests // --------------------------------------- diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java index 81d904292..8906adfe0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpAsyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java index 154cf3a61..7b77f9241 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/ServletSseMcpSyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class ServletSseMcpSyncServerTests extends AbstractMcpSyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return HttpServletSseServerTransportProvider.builder().messageEndpoint("/mcp/message").build(); } + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java index 0381a43bd..97db5fa06 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpAsyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpAsyncServerTests extends AbstractMcpAsyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return new StdioServerTransportProvider(); } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java index a71c38493..1e01962e9 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/StdioMcpSyncServerTests.java @@ -16,9 +16,13 @@ @Timeout(15) // Giving extra time beyond the client timeout class StdioMcpSyncServerTests extends AbstractMcpSyncServerTests { - @Override protected McpServerTransportProvider createMcpTransportProvider() { return new StdioServerTransportProvider(); } + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + }