From 98990e95be5d8bc73b7d6c312939f7771ecbc3d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 18 Jul 2025 12:01:54 +0200 Subject: [PATCH 1/8] WIP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseServerTransportProvider.java | 3 +- ...FluxStreamableServerTransportProvider.java | 400 +++++++++ .../WebMvcSseServerTransportProvider.java | 3 +- .../server/McpAsyncServer.java | 100 ++- .../server/McpAsyncServerExchange.java | 8 +- .../server/McpInitRequestHandler.java | 19 + .../server/McpNotificationHandler.java | 20 + .../server/McpRequestHandler.java | 23 + .../server/McpServer.java | 17 + .../server/McpStatelessAsyncServer.java | 758 ++++++++++++++++++ ...HttpServletSseServerTransportProvider.java | 3 +- .../StdioServerTransportProvider.java | 3 +- ...aultMcpStreamableServerSessionFactory.java | 30 + .../spec/DisconnectedMcpSession.java | 28 + .../spec/McpServerSession.java | 21 +- .../spec/McpServerTransportProvider.java | 8 - ...pSingleSessionServerTransportProvider.java | 12 + .../spec/McpStatelessServerTransport.java | 26 + .../spec/McpStreamableServerSession.java | 271 +++++++ .../McpStreamableServerTransportProvider.java | 66 ++ .../MockMcpServerTransportProvider.java | 3 +- 21 files changed, 1767 insertions(+), 55 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9aa..d52902e32 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -11,6 +11,7 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,7 +64,7 @@ * @see McpServerTransport * @see ServerSentEvent */ -public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { +public class WebFluxSseServerTransportProvider implements McpSingleSessionServerTransportProvider { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java new file mode 100644 index 000000000..3dd279ff7 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -0,0 +1,400 @@ +package io.modelcontextprotocol.server.transport; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; + +public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebFluxStreamableServerTransportProvider.class); + + public static final String MESSAGE_EVENT_TYPE = "message"; + + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + public static final String DEFAULT_BASE_URL = ""; + + private final ObjectMapper objectMapper; + + private final String baseUrl; + + private final String mcpEndpoint; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebFlux SSE server transport provider instance with the default + * SSE endpoint. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint) { + this(objectMapper, DEFAULT_BASE_URL, mcpEndpoint); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String mcpEndpoint) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(baseUrl, "Message base path must not be null"); + Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); + + this.objectMapper = objectMapper; + this.baseUrl = baseUrl; + this.mcpEndpoint = mcpEndpoint; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .build(); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a JSON-RPC message to all connected clients through their SSE + * connections. The message is serialized to JSON and sent as a server-sent event to + * each active session. + * + *

+ * The method: + *

    + *
  • Serializes the message to JSON
  • + *
  • Creates a server-sent event with the message data
  • + *
  • Attempts to send the event to all active sessions
  • + *
  • Tracks and reports any delivery failures
  • + *
+ * @param method The JSON-RPC method to send to clients + * @param params The method parameters to send to clients + * @return A Mono that completes when the message has been sent to all sessions, or + * errors if any session fails to receive the message + */ + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + // FIXME: This javadoc makes claims about using isClosing flag but it's not + // actually + // doing that. + /** + * Initiates a graceful shutdown of all the sessions. This method ensures all active + * sessions are properly closed and cleaned up. + * + *

+ * The shutdown process: + *

    + *
  • Marks the transport as closing to prevent new connections
  • + *
  • Closes each active session
  • + *
  • Removes closed sessions from the sessions map
  • + *
  • Times out after 5 seconds if shutdown takes too long
  • + *
+ * @return A Mono that completes when all sessions have been closed + */ + @Override + public Mono closeGracefully() { + return Flux.fromIterable(sessions.values()) + .doFirst(() -> logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size())) + .flatMap(McpStreamableServerSession::closeGracefully) + .then(); + } + + /** + * Returns the WebFlux router function that defines the transport's HTTP endpoints. + * This router function should be integrated into the application's web configuration. + * + *

+ * The router function defines two endpoints: + *

    + *
  • GET {sseEndpoint} - For establishing SSE connections
  • + *
  • POST {messageEndpoint} - For receiving client messages
  • + *
+ * @return The configured {@link RouterFunction} for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Handles new SSE connection requests from clients. Creates a new session for each + * connection and sets up the SSE event stream. + * @param request The incoming server request + * @return A Mono which emits a response with the SSE event stream + */ + private Mono handleGet(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + return Mono.defer(() -> { + if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { + return ServerResponse.badRequest().build(); // TODO: say we need a session id + } + + String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + if (request.headers().asHttpHeaders().containsKey("mcp-last-id")) { + String lastId = request.headers().asHttpHeaders().getFirst("mcp-last-id"); + return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(session.replay(lastId), ServerSentEvent.class); + } + + return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport(sink); + McpStreamableServerSession.McpStreamableServerSessionStream genericStream = session.newStream(sessionTransport); + sink.onDispose(genericStream::close); + }), ServerSentEvent.class); + + }); + } + + /** + * Handles incoming JSON-RPC messages from clients. Deserializes the message and + * processes it through the configured message handler. + * + *

+ * The handler: + *

    + *
  • Deserializes the incoming JSON-RPC message
  • + *
  • Passes it through the message handler chain
  • + *
  • Returns appropriate HTTP responses based on processing results
  • + *
  • Handles various error conditions with appropriate error responses
  • + *
+ * @param request The incoming server request containing the JSON-RPC message + * @return A Mono emitting the response indicating the message processing result + */ + private Mono handlePost(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + return request.bodyToMono(String.class).flatMap(body -> { + try { + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), new TypeReference() {}); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory.startSession(initializeRequest); + sessions.put(init.session().getId(), init.session()); + return init.initResult().flatMap(initResult -> ServerResponse.ok().header("mcp-session-id", init.session().getId()).bodyValue(initResult)); + } + + if (!request.headers().asHttpHeaders().containsKey("sessionId")) { + return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst("sessionId"); + McpStreamableServerSession session = sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .bodyValue(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); + } else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); + } else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); + Mono stream = session.handleStream(jsonrpcRequest, st); + Disposable streamSubscription = stream + .doOnError(err -> sink.error(err)) + .contextWrite(sink.contextView()) + .subscribe(); + sink.onCancel(streamSubscription); + }), ServerSentEvent.class); + } else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).bodyValue(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); + } + }); + } + + private class WebFluxStreamableMcpSessionTransport implements McpServerTransport { + + private final FluxSink> sink; + + public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { + this.sink = sink; + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return Mono.fromSupplier(() -> { + try { + return objectMapper.writeValueAsString(message); + } + catch (IOException e) { + throw Exceptions.propagate(e); + } + }).doOnNext(jsonText -> { + ServerSentEvent event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(jsonText) + .build(); + sink.next(event); + }).doOnError(e -> { + // TODO log with sessionid + Throwable exception = Exceptions.unwrap(e); + sink.error(exception); + }).then(); + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(sink::complete); + } + + @Override + public void close() { + sink.complete(); + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebFluxStreamableServerTransportProvider}. + *

+ * This builder provides a fluent API for configuring and creating instances of + * WebFluxSseServerTransportProvider with custom settings. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String baseUrl = DEFAULT_BASE_URL; + + private String mcpEndpoint = "/mcp"; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the project basePath as endpoint prefix where clients should send their + * JSON-RPC messages + * @param baseUrl the message basePath . Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if basePath is null + */ + public Builder basePath(String baseUrl) { + Assert.notNull(baseUrl, "basePath must not be null"); + this.baseUrl = baseUrl; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param messageEndpoint The message endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if messageEndpoint is null + */ + public Builder messageEndpoint(String messageEndpoint) { + Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + this.mcpEndpoint = messageEndpoint; + return this; + } + + /** + * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with the + * configured settings. + * @return A new WebFluxSseServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebFluxStreamableServerTransportProvider build() { + Assert.notNull(objectMapper, "ObjectMapper must be set"); + Assert.notNull(mcpEndpoint, "Message endpoint must be set"); + + return new WebFluxStreamableServerTransportProvider(objectMapper, baseUrl, mcpEndpoint); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa0..1e89eea8f 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -17,6 +17,7 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,7 +67,7 @@ * @see McpServerTransportProvider * @see RouterFunction */ -public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { +public class WebMvcSseServerTransportProvider implements McpSingleSessionServerTransportProvider { private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index d873a7fde..1106459ca 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -15,6 +15,9 @@ import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; +import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; +import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -123,9 +126,9 @@ public class McpAsyncServer { * @param features The MCP server supported features. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ - McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpAsyncServer(McpSingleSessionServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -139,7 +142,56 @@ public class McpAsyncServer { this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; - Map> requestHandlers = new HashMap<>(); + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + mcpTransportProvider.setSessionFactory( + transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, + this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + } + + McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(withStructuredOutputHandling(jsonSchemaValidator, features.tools())); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; + + Map> requestHandlers = prepareRequestHandlers(); + Map notificationHandlers = prepareNotificationHandlers(features); + + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + } + + private Map prepareNotificationHandlers(McpServerFeatures.Async features) { + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + return notificationHandlers; + } + + private Map> prepareRequestHandlers() { + Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods @@ -174,25 +226,7 @@ public class McpAsyncServer { if (this.serverCapabilities.completions() != null) { requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); } - - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + return requestHandlers; } // --------------------------------------- @@ -258,7 +292,7 @@ public void close() { this.mcpTransportProvider.close(); } - private McpServerSession.NotificationHandler asyncRootsListChangedNotificationHandler( + private McpNotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return (exchange, params) -> exchange.listRoots() .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) @@ -448,7 +482,7 @@ public Mono notifyToolsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); } - private McpServerSession.RequestHandler toolsListRequestHandler() { + private McpRequestHandler toolsListRequestHandler() { return (exchange, params) -> { List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); @@ -456,7 +490,7 @@ private McpServerSession.RequestHandler toolsListRequ }; } - private McpServerSession.RequestHandler toolsCallRequestHandler() { + private McpRequestHandler toolsCallRequestHandler() { return (exchange, params) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { @@ -549,7 +583,7 @@ public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification); } - private McpServerSession.RequestHandler resourcesListRequestHandler() { + private McpRequestHandler resourcesListRequestHandler() { return (exchange, params) -> { var resourceList = this.resources.values() .stream() @@ -559,7 +593,7 @@ private McpServerSession.RequestHandler resources }; } - private McpServerSession.RequestHandler resourceTemplateListRequestHandler() { + private McpRequestHandler resourceTemplateListRequestHandler() { return (exchange, params) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); @@ -583,7 +617,7 @@ private List getResourceTemplates() { return list; } - private McpServerSession.RequestHandler resourcesReadRequestHandler() { + private McpRequestHandler resourcesReadRequestHandler() { return (exchange, params) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { @@ -676,7 +710,7 @@ public Mono notifyPromptsListChanged() { return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); } - private McpServerSession.RequestHandler promptsListRequestHandler() { + private McpRequestHandler promptsListRequestHandler() { return (exchange, params) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, @@ -692,7 +726,7 @@ private McpServerSession.RequestHandler promptsList }; } - private McpServerSession.RequestHandler promptsGetRequestHandler() { + private McpRequestHandler promptsGetRequestHandler() { return (exchange, params) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { @@ -738,7 +772,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN loggingMessageNotification); } - private McpServerSession.RequestHandler setLoggerRequestHandler() { + private McpRequestHandler setLoggerRequestHandler() { return (exchange, params) -> { return Mono.defer(() -> { @@ -757,7 +791,7 @@ private McpServerSession.RequestHandler setLoggerRequestHandler() { }; } - private McpServerSession.RequestHandler completionCompleteRequestHandler() { + private McpRequestHandler completionCompleteRequestHandler() { return (exchange, params) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index e56c695fa..7362033f9 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -12,7 +12,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -25,7 +25,7 @@ */ public class McpAsyncServerExchange { - private final McpServerSession session; + private final McpSession session; private final McpSchema.ClientCapabilities clientCapabilities; @@ -52,8 +52,8 @@ public class McpAsyncServerExchange { * features and functionality. * @param clientInfo The client implementation information. */ - public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo) { + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java new file mode 100644 index 000000000..a6063a8b2 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java @@ -0,0 +1,19 @@ +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.publisher.Mono; + +/** + * Request handler for the initialization request. + */ +public interface McpInitRequestHandler { + + /** + * Handles the initialization request. + * + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java new file mode 100644 index 000000000..492454908 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java @@ -0,0 +1,20 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated notifications. + */ +public interface McpNotificationHandler { + + /** + * Handles a notification from the client. + * + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java new file mode 100644 index 000000000..c95af472a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java @@ -0,0 +1,23 @@ +package io.modelcontextprotocol.server; + +import reactor.core.publisher.Mono; + +/** + * A handler for client-initiated requests. + * + * @param the type of the response that is expected as a result of handling the + * request. + */ +public interface McpRequestHandler { + + /** + * Handles a request from the client. + * + * @param exchange the exchange associated with the client that allows calling + * back to the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d4b8addf4..05734c272 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -21,6 +21,7 @@ import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; @@ -155,6 +156,22 @@ static AsyncSpecification async(McpServerTransportProvider transportProvider) { return new AsyncSpecification(transportProvider); } + static StatelessAsyncSpecification async(McpStatelessServerTransport transportProvider) { + // TODO + } + + static StatelessSyncSpecification sync(McpStatelessServerTransport transportProvider) { + // TODO + } + + class StatelessAsyncSpecification { + // TODO + } + + class StatelessSyncSpecification { + // TODO + } + /** * Asynchronous server specification. */ diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java new file mode 100644 index 000000000..94ef1a6df --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -0,0 +1,758 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; +import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * The Model Context Protocol (MCP) server implementation that provides asynchronous + * communication using Project Reactor's Mono and Flux types. + * + *

+ * This server implements the MCP specification, enabling AI models to expose tools, + * resources, and prompts through a standardized interface. Key features include: + *

    + *
  • Asynchronous communication using reactive programming patterns + *
  • Dynamic tool registration and management + *
  • Resource handling with URI-based addressing + *
  • Prompt template management + *
  • Real-time client notifications for state changes + *
  • Structured logging with configurable severity levels + *
  • Support for client-side AI model sampling + *
+ * + *

+ * The server follows a lifecycle: + *

    + *
  1. Initialization - Accepts client connections and negotiates capabilities + *
  2. Normal Operation - Handles client requests and sends notifications + *
  3. Graceful Shutdown - Ensures clean connection termination + *
+ * + *

+ * This implementation uses Project Reactor for non-blocking operations, making it + * suitable for high-throughput scenarios and reactive applications. All operations return + * Mono or Flux types that can be composed into reactive pipelines. + * + *

+ * The server supports runtime modification of its capabilities through methods like + * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying + * connected clients of changes when configured to do so. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @author Jihoon Kim + * @see McpServer + * @see McpSchema + * @see McpClientSession + */ +public class McpStatelessAsyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessAsyncServer.class); + + private final McpStatelessServerTransport mcpTransportProvider; + + private final ObjectMapper objectMapper; + + private final McpSchema.ServerCapabilities serverCapabilities; + + private final McpSchema.Implementation serverInfo; + + private final String instructions; + + // TODO: all simple ones + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + + private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); + + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + + // FIXME: this field is deprecated and should be remvoed together with the + // broadcasting loggingNotification. + private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; + + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + + private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); + + private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + /** + * Create a new McpAsyncServer with the given transport provider and capabilities. + * @param mcpTransportProvider The transport layer implementation for MCP + * communication. + * @param features The MCP server supported features. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + */ + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransportProvider, ObjectMapper objectMapper, + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory) { + this.mcpTransportProvider = mcpTransportProvider; + this.objectMapper = objectMapper; + this.serverInfo = features.serverInfo(); + this.serverCapabilities = features.serverCapabilities(); + this.instructions = features.instructions(); + this.tools.addAll(features.tools()); + this.resources.putAll(features.resources()); + this.resourceTemplates.addAll(features.resourceTemplates()); + this.prompts.putAll(features.prompts()); + this.completions.putAll(features.completions()); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + + Map> requestHandlers = new HashMap<>(); + + // Initialize request handlers for standard MCP methods + + // Ping MUST respond with an empty data, but not NULL response. + requestHandlers.put(McpSchema.METHOD_PING, params -> Mono.just(Map.of())); + + // Add tools API handlers if the tool capability is enabled + if (this.serverCapabilities.tools() != null) { + requestHandlers.put(McpSchema.METHOD_TOOLS_LIST, toolsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_TOOLS_CALL, toolsCallRequestHandler()); + } + + // Add resources API handlers if provided + if (this.serverCapabilities.resources() != null) { + requestHandlers.put(McpSchema.METHOD_RESOURCES_LIST, resourcesListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_READ, resourcesReadRequestHandler()); + requestHandlers.put(McpSchema.METHOD_RESOURCES_TEMPLATES_LIST, resourceTemplateListRequestHandler()); + } + + // Add prompts API handlers if provider exists + if (this.serverCapabilities.prompts() != null) { + requestHandlers.put(McpSchema.METHOD_PROMPT_LIST, promptsListRequestHandler()); + requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); + } + + // Add logging API handlers if the logging capability is enabled + if (this.serverCapabilities.logging() != null) { + requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); + } + + // Add completion API handlers if the completion capability is enabled + if (this.serverCapabilities.completions() != null) { + requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); + } + + Map notificationHandlers = new HashMap<>(); + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); + + List, Mono>> rootsChangeConsumers = features + .rootsChangeConsumers(); + + if (Utils.isEmpty(rootsChangeConsumers)) { + rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger + .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); + } + + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, + asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); + + mcpTransportProvider.setHandler(request -> + requestHandlers.get(request.method()).apply(request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(t -> + Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, t.getMessage(), null))) + )); + ); + } + + // --------------------------------------- + // Lifecycle Management + // --------------------------------------- + private Mono asyncInitializeRequestHandler( + McpSchema.InitializeRequest initializeRequest) { + return Mono.defer(() -> { + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", + initializeRequest.protocolVersion(), initializeRequest.capabilities(), + initializeRequest.clientInfo()); + + // The server MUST respond with the highest protocol version it supports + // if + // it does not support the requested (e.g. Client) version. + String serverProtocolVersion = this.protocolVersions.get(this.protocolVersions.size() - 1); + + if (this.protocolVersions.contains(initializeRequest.protocolVersion())) { + // If the server supports the requested protocol version, it MUST + // respond + // with the same version. + serverProtocolVersion = initializeRequest.protocolVersion(); + } + else { + logger.warn( + "Client requested unsupported protocol version: {}, so the server will suggest the {} version instead", + initializeRequest.protocolVersion(), serverProtocolVersion); + } + + return Mono.just(new McpSchema.InitializeResult(serverProtocolVersion, this.serverCapabilities, + this.serverInfo, this.instructions)); + }); + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.serverCapabilities; + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.serverInfo; + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.mcpTransportProvider.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.mcpTransportProvider.close(); + } + + private McpNotificationHandler asyncRootsListChangedNotificationHandler( + List, Mono>> rootsChangeConsumers) { + return (exchange, params) -> exchange.listRoots() + .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) + .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .onErrorResume(error -> { + logger.error("Error handling roots list change notification", error); + return Mono.empty(); + }) + .then()); + } + + // --------------------------------------- + // Tool Management + // --------------------------------------- + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + if (toolSpecification == null) { + return Mono.error(new McpError("Tool specification must not be null")); + } + if (toolSpecification.tool() == null) { + return Mono.error(new McpError("Tool must not be null")); + } + if (toolSpecification.call() == null) { + return Mono.error(new McpError("Tool call handler must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + // Check for duplicate tool names + if (this.tools.stream().anyMatch(th -> th.tool().name().equals(toolSpecification.tool().name()))) { + return Mono + .error(new McpError("Tool with name '" + toolSpecification.tool().name() + "' already exists")); + } + + this.tools.add(toolSpecification); + logger.debug("Added tool handler: {}", toolSpecification.tool().name()); + + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeTool(String toolName) { + if (toolName == null) { + return Mono.error(new McpError("Tool name must not be null")); + } + if (this.serverCapabilities.tools() == null) { + return Mono.error(new McpError("Server must be configured with tool capabilities")); + } + + return Mono.defer(() -> { + boolean removed = this.tools + .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); + if (removed) { + logger.debug("Removed tool handler: {}", toolName); + if (this.serverCapabilities.tools().listChanged()) { + return notifyToolsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available tools has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyToolsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); + } + + private Function> toolsListRequestHandler() { + return params -> { + List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + + return Mono.just(new McpSchema.ListToolsResult(tools, null)); + }; + } + + private McpRequestHandler toolsCallRequestHandler() { + return (exchange, params) -> { + McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + Optional toolSpecification = this.tools.stream() + .filter(tr -> callToolRequest.name().equals(tr.tool().name())) + .findAny(); + + if (toolSpecification.isEmpty()) { + return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); + } + + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); + }; + } + + // --------------------------------------- + // Resource Management + // --------------------------------------- + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + if (resourceSpecification == null || resourceSpecification.resource() == null) { + return Mono.error(new McpError("Resource must not be null")); + } + + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + if (this.resources.putIfAbsent(resourceSpecification.resource().uri(), resourceSpecification) != null) { + return Mono.error(new McpError( + "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); + } + logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removeResource(String resourceUri) { + if (resourceUri == null) { + return Mono.error(new McpError("Resource URI must not be null")); + } + if (this.serverCapabilities.resources() == null) { + return Mono.error(new McpError("Server must be configured with resource capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + if (removed != null) { + logger.debug("Removed resource handler: {}", resourceUri); + if (this.serverCapabilities.resources().listChanged()) { + return notifyResourcesListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); + }); + } + + /** + * Notifies clients that the list of available resources has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); + } + + /** + * Notifies clients that the resources have updated. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, + resourcesUpdatedNotification); + } + + private McpRequestHandler resourcesListRequestHandler() { + return (exchange, params) -> { + var resourceList = this.resources.values() + .stream() + .map(McpServerFeatures.AsyncResourceSpecification::resource) + .toList(); + return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + }; + } + + private McpRequestHandler resourceTemplateListRequestHandler() { + return (exchange, params) -> Mono + .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + + } + + private List getResourceTemplates() { + var list = new ArrayList<>(this.resourceTemplates); + List resourceTemplates = this.resources.keySet() + .stream() + .filter(uri -> uri.contains("{")) + .map(uri -> { + var resource = this.resources.get(uri).resource(); + var template = new ResourceTemplate(resource.uri(), resource.name(), resource.title(), + resource.description(), resource.mimeType(), resource.annotations()); + return template; + }) + .toList(); + + list.addAll(resourceTemplates); + + return list; + } + + private McpRequestHandler resourcesReadRequestHandler() { + return (exchange, params) -> { + McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + var resourceUri = resourceRequest.uri(); + + McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + .stream() + .filter(resourceSpecification -> this.uriTemplateManagerFactory + .create(resourceSpecification.resource().uri()) + .matches(resourceUri)) + .findFirst() + .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); + + return specification.readHandler().apply(exchange, resourceRequest); + }; + } + + // --------------------------------------- + // Prompt Management + // --------------------------------------- + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + if (promptSpecification == null) { + return Mono.error(new McpError("Prompt specification must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification specification = this.prompts + .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); + if (specification != null) { + return Mono.error( + new McpError("Prompt with name '" + promptSpecification.prompt().name() + "' already exists")); + } + + logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); + + // Servers that declared the listChanged capability SHOULD send a + // notification, + // when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return notifyPromptsListChanged(); + } + return Mono.empty(); + }); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public Mono removePrompt(String promptName) { + if (promptName == null) { + return Mono.error(new McpError("Prompt name must not be null")); + } + if (this.serverCapabilities.prompts() == null) { + return Mono.error(new McpError("Server must be configured with prompt capabilities")); + } + + return Mono.defer(() -> { + McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + + if (removed != null) { + logger.debug("Removed prompt handler: {}", promptName); + // Servers that declared the listChanged capability SHOULD send a + // notification, when the list of available prompts changes + if (this.serverCapabilities.prompts().listChanged()) { + return this.notifyPromptsListChanged(); + } + return Mono.empty(); + } + return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); + }); + } + + /** + * Notifies clients that the list of available prompts has changed. + * @return A Mono that completes when all clients have been notified + */ + public Mono notifyPromptsListChanged() { + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); + } + + private McpRequestHandler promptsListRequestHandler() { + return (exchange, params) -> { + // TODO: Implement pagination + // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + // new TypeReference() { + // }); + + var promptList = this.prompts.values() + .stream() + .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .toList(); + + return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + }; + } + + private McpRequestHandler promptsGetRequestHandler() { + return (exchange, params) -> { + McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, + new TypeReference() { + }); + + // Implement prompt retrieval logic here + McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + if (specification == null) { + return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); + } + + return specification.promptHandler().apply(exchange, promptRequest); + }; + } + + // --------------------------------------- + // Logging Management + // --------------------------------------- + + /** + * This implementation would, incorrectly, broadcast the logging message to all + * connected clients, using a single minLoggingLevel for all of them. Similar to the + * sampling and roots, the logging level should be set per client session and use the + * ServerExchange to send the logging message to the right client. + * @param loggingMessageNotification The logging message to send + * @return A Mono that completes when the notification has been sent + * @deprecated Use + * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} + * instead. + */ + @Deprecated + public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { + + if (loggingMessageNotification == null) { + return Mono.error(new McpError("Logging message must not be null")); + } + + if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { + return Mono.empty(); + } + + return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, + loggingMessageNotification); + } + + private McpRequestHandler setLoggerRequestHandler() { + return (exchange, params) -> { + return Mono.defer(() -> { + + SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, + new TypeReference() { + }); + + exchange.setMinLoggingLevel(newMinLoggingLevel.level()); + + // FIXME: this field is deprecated and should be removed together + // with the broadcasting loggingNotification. + this.minLoggingLevel = newMinLoggingLevel.level(); + + return Mono.just(Map.of()); + }); + }; + } + + private McpRequestHandler completionCompleteRequestHandler() { + return (exchange, params) -> { + McpSchema.CompleteRequest request = parseCompletionParams(params); + + if (request.ref() == null) { + return Mono.error(new McpError("ref must not be null")); + } + + if (request.ref().type() == null) { + return Mono.error(new McpError("type must not be null")); + } + + String type = request.ref().type(); + + String argumentName = request.argument().name(); + + // check if the referenced resource exists + if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { + McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + if (promptSpec == null) { + return Mono.error(new McpError("Prompt not found: " + promptReference.name())); + } + if (!promptSpec.prompt() + .arguments() + .stream() + .filter(arg -> arg.name().equals(argumentName)) + .findFirst() + .isPresent()) { + + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + } + + if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { + McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + if (resourceSpec == null) { + return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); + } + if (!uriTemplateManagerFactory.create(resourceSpec.resource().uri()) + .getVariableNames() + .contains(argumentName)) { + return Mono.error(new McpError("Argument not found: " + argumentName)); + } + + } + + McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + + if (specification == null) { + return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); + } + + return specification.completionHandler().apply(exchange, request); + }; + } + + /** + * Parses the raw JSON-RPC request parameters into a {@link McpSchema.CompleteRequest} + * object. + *

+ * This method manually extracts the `ref` and `argument` fields from the input map, + * determines the correct reference type (either prompt or resource), and constructs a + * fully-typed {@code CompleteRequest} instance. + * @param object the raw request parameters, expected to be a Map containing "ref" and + * "argument" entries. + * @return a {@link McpSchema.CompleteRequest} representing the structured completion + * request. + * @throws IllegalArgumentException if the "ref" type is not recognized. + */ + @SuppressWarnings("unchecked") + private McpSchema.CompleteRequest parseCompletionParams(Object object) { + Map params = (Map) object; + Map refMap = (Map) params.get("ref"); + Map argMap = (Map) params.get("argument"); + + String refType = (String) refMap.get("type"); + + McpSchema.CompleteReference ref = switch (refType) { + case "ref/prompt" -> new McpSchema.PromptReference(refType, (String) refMap.get("name"), + refMap.get("title") != null ? (String) refMap.get("title") : null); + case "ref/resource" -> new McpSchema.ResourceReference(refType, (String) refMap.get("uri")); + default -> throw new IllegalArgumentException("Invalid ref type: " + refType); + }; + + String argName = (String) argMap.get("name"); + String argValue = (String) argMap.get("value"); + McpSchema.CompleteRequest.CompleteArgument argument = new McpSchema.CompleteRequest.CompleteArgument(argName, + argValue); + + return new McpSchema.CompleteRequest(ref, argument); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.protocolVersions = protocolVersions; + } + + static interface RequestHandler extends Function> { + + } +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index afdbff472..552ef7f17 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -18,6 +18,7 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; @@ -60,7 +61,7 @@ */ @WebServlet(asyncSupported = true) -public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { +public class HttpServletSseServerTransportProvider extends HttpServlet implements McpSingleSessionServerTransportProvider { /** Logger for this class */ private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 819da9777..fc9e317bc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -24,6 +24,7 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,7 +41,7 @@ * * @author Christian Tzolov */ -public class StdioServerTransportProvider implements McpServerTransportProvider { +public class StdioServerTransportProvider implements McpSingleSessionServerTransportProvider { private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java new file mode 100644 index 000000000..c6f1f219c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -0,0 +1,30 @@ +package io.modelcontextprotocol.spec; + +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.Map; +import java.util.UUID; + +public class DefaultMcpStreamableServerSessionFactory implements McpStreamableServerSession.Factory { + Duration requestTimeout; + McpStreamableServerSession.InitRequestHandler initRequestHandler; + Map> requestHandlers; + Map notificationHandlers; + + public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, McpStreamableServerSession.InitRequestHandler initRequestHandler, Map> requestHandlers, Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initRequestHandler = initRequestHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest) { + return new McpStreamableServerSession.McpStreamableServerSessionInit(new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), initializeRequest.clientInfo(), requestTimeout, + Mono::empty, requestHandlers, notificationHandlers), this.initRequestHandler.handle(initializeRequest)); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java new file mode 100644 index 000000000..998dc4a65 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java @@ -0,0 +1,28 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import reactor.core.publisher.Mono; + +public class DisconnectedMcpSession implements McpSession { + + public static final DisconnectedMcpSession INSTANCE = new DisconnectedMcpSession(); + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.error(new IllegalStateException("Stream unavailable")); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.error(new IllegalStateException("Stream unavailable")); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + } +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 86906d859..381bd3675 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -9,6 +9,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpInitRequestHandler; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -32,13 +35,13 @@ public class McpServerSession implements McpSession { private final AtomicLong requestCounter = new AtomicLong(0); - private final InitRequestHandler initRequestHandler; + private final McpInitRequestHandler initRequestHandler; private final InitNotificationHandler initNotificationHandler; - private final Map> requestHandlers; + private final Map> requestHandlers; - private final Map notificationHandlers; + private final Map notificationHandlers; private final McpServerTransport transport; @@ -70,8 +73,8 @@ public class McpServerSession implements McpSession { * @param notificationHandlers map of notification handlers to use */ public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - InitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; @@ -264,17 +267,21 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { + // TODO: clear pendingResponses and emit errors? return this.transport.closeGracefully(); } @Override public void close() { + // TODO: clear pendingResponses and emit errors? this.transport.close(); } /** * Request handler for the initialization request. + * @deprecated Use {@link McpInitRequestHandler} */ + @Deprecated public interface InitRequestHandler { /** @@ -301,7 +308,9 @@ public interface InitNotificationHandler { /** * A handler for client-initiated notifications. + * @deprecated Use {@link McpNotificationHandler} */ + @Deprecated public interface NotificationHandler { /** @@ -320,7 +329,9 @@ public interface NotificationHandler { * * @param the type of the response that is expected as a result of handling the * request. + * @deprecated Use {@link McpRequestHandler} */ + @Deprecated public interface RequestHandler { /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 5fdbd7ab6..a90c615f7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -31,14 +31,6 @@ */ public interface McpServerTransportProvider { - /** - * Sets the session factory that will be used to create sessions for new clients. An - * implementation of the MCP server MUST call this method before any MCP interactions - * take place. - * @param sessionFactory the session factory to be used for initiating client sessions - */ - void setSessionFactory(McpServerSession.Factory sessionFactory); - /** * Sends a notification to all connected clients. * @param method the name of the notification method to be called on the clients diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java new file mode 100644 index 000000000..762968dc9 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.spec; + +public interface McpSingleSessionServerTransportProvider extends McpServerTransportProvider { + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java new file mode 100644 index 000000000..6c184b724 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -0,0 +1,26 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +import java.util.function.Function; + +public interface McpStatelessServerTransport { + + void setHandler(Function> message); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java new file mode 100644 index 000000000..53f242955 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -0,0 +1,271 @@ +package io.modelcontextprotocol.spec; + +import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.server.McpAsyncServerExchange; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; + +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +public class McpStreamableServerSession implements McpSession { + + private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class); + + private final ConcurrentHashMap requestIdToStream = new ConcurrentHashMap<>(); + + private final String id; + + private final Duration requestTimeout; + + private final AtomicLong requestCounter = new AtomicLong(0); + + private final InitNotificationHandler initNotificationHandler; + + private final Map> requestHandlers; + + private final Map notificationHandlers; + + private final AtomicReference clientCapabilities = new AtomicReference<>(); + + private final AtomicReference clientInfo = new AtomicReference<>(); + + private static final int STATE_UNINITIALIZED = 0; + + private static final int STATE_INITIALIZING = 1; + + private static final int STATE_INITIALIZED = 2; + + private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + + private final AtomicReference genericStreamRef = new AtomicReference<>(); + + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + Duration requestTimeout, + InitNotificationHandler initNotificationHandler, + Map> requestHandlers, Map notificationHandlers) { + this.id = id; + this.clientCapabilities.lazySet(clientCapabilities); + this.clientInfo.lazySet(clientInfo); + this.requestTimeout = requestTimeout; + this.initNotificationHandler = initNotificationHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + public String getId() { + return this.id; + } + + private String generateRequestId() { + return this.id + "-" + this.requestCounter.getAndIncrement(); + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.defer(() -> { + McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); + return genericStream != null ? genericStream.sendRequest(method, requestParams, typeRef) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.defer(() -> { + McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); + return genericStream != null ? genericStream.sendNotification(method, params) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + }); + } + + public McpStreamableServerSessionStream newStream(McpServerTransport transport) { + McpStreamableServerSessionStream genericStream = new McpStreamableServerSessionStream(transport); + this.genericStreamRef.set(genericStream); + return genericStream; + } + + // TODO: keep track of history by keeping a map from eventId to stream and then iterate over the events using the lastEventId + public Flux replay(Object lastEventId) { + return Flux.empty(); + } + + public Mono handleStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpServerTransport transport) { + McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers.get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close remove itself from the registry and also close the underlying transport (sink) + if (requestHandler == null) { + MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); + return transport.sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + return requestHandler.handle(new McpAsyncServerExchange(stream, clientCapabilities.get(), clientInfo.get()), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) + .flatMap(transport::sendMessage).then(transport.closeGracefully()); + } + public Mono accept(McpSchema.JSONRPCNotification notification) { + return Mono.defer(() -> { + McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); + if (notificationHandler == null) { + logger.error("No handler registered for notification method: {}", notification.method()); + return Mono.empty(); + } + McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(genericStream != null ? genericStream : DisconnectedMcpSession.INSTANCE, this.clientCapabilities.get(), this.clientInfo.get()), notification.params()); + }); + + } + public Mono accept(McpSchema.JSONRPCResponse response) { + return Mono.defer(() -> { + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO JSONize + } + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO JSONize + } else { + sink.success(response); + } + return Mono.empty(); + }); + } + + record MethodNotFoundError(String method, String message, Object data) { + } + + private MethodNotFoundError getMethodNotFoundError(String method) { + return new MethodNotFoundError(method, "Method not found: " + method, null); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); + return genericStream != null ? genericStream.closeGracefully() : Mono.empty(); // TODO: Also close all the open streams + }); + } + + @Override + public void close() { + McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); + if (genericStream != null) { + genericStream.close(); + } + // TODO: Also close all open streams + } + + /** + * Request handler for the initialization request. + */ + public interface InitRequestHandler { + + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); + + } + + /** + * Notification handler for the initialization notification from the client. + */ + public interface InitNotificationHandler { + + /** + * Specifies an action to take upon successful initialization. + * @return a Mono that will complete when the initialization is acted upon. + */ + Mono handle(); + + } + + public interface Factory { + McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + } + + public record McpStreamableServerSessionInit(McpStreamableServerSession session, Mono initResult) {} + + public final class McpStreamableServerSessionStream implements McpSession { + + private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); + + private final McpServerTransport transport; + + public McpStreamableServerSessionStream(McpServerTransport transport) { + this.transport = transport; + } + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + String requestId = McpStreamableServerSession.this.generateRequestId(); + + McpStreamableServerSession.this.requestIdToStream.put(requestId, this); + + return Mono.create(sink -> { + this.pendingResponses.put(requestId, sink); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, + requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> {}, sink::error); + }) + .timeout(requestTimeout) + .doOnError(e -> { + this.pendingResponses.remove(requestId); + McpStreamableServerSession.this.requestIdToStream.remove(requestId); + }) + .handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); + } + + @Override + public Mono sendNotification(String method, Object params) { + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, + method, params); + return this.transport.sendMessage(jsonrpcNotification); + } + + @Override + public Mono closeGracefully() { + return Mono.defer(() -> { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.genericStreamRef.compareAndExchange(this, null); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + return this.transport.closeGracefully(); + }); + } + + @Override + public void close() { + this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); + this.pendingResponses.clear(); + // If this was the generic stream, reset it + McpStreamableServerSession.this.genericStreamRef.compareAndExchange(this, null); + McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); + this.transport.close(); + } + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java new file mode 100644 index 000000000..22b618440 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -0,0 +1,66 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +import java.util.Map; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

+ * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransportProvider extends McpServerTransportProvider { + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpStreamableServerSession.Factory sessionFactory); + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 7ba35bbf0..71e090890 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -19,12 +19,13 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerSession.Factory; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import reactor.core.publisher.Mono; /** * @author Christian Tzolov */ -public class MockMcpServerTransportProvider implements McpServerTransportProvider { +public class MockMcpServerTransportProvider implements McpSingleSessionServerTransportProvider { private McpServerSession session; From f83db452e6045ccbcb3cb323a57b1b1adf48c53a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 22 Jul 2025 15:36:08 +0200 Subject: [PATCH 2/8] More WIP: renaming, steteless server abstractions, client context extraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- ...FluxStreamableServerTransportProvider.java | 20 +- .../server/McpAsyncServer.java | 6 +- .../server/McpAsyncServerExchange.java | 25 + .../server/McpServer.java | 1380 ++++++++++++++--- .../server/McpStatelessAsyncServer.java | 314 +--- .../server/McpStatelessServerFeatures.java | 377 +++++ .../server/McpStatelessSyncServer.java | 152 ++ .../server/McpSyncServerExchange.java | 4 + ...HttpServletSseServerTransportProvider.java | 3 +- .../spec/DefaultMcpTransportContext.java | 30 + .../spec/McpServerTransportProvider.java | 64 +- .../spec/McpServerTransportProviderBase.java | 58 + ...pSingleSessionServerTransportProvider.java | 12 - .../spec/McpStatelessServerTransport.java | 4 +- .../spec/McpStreamableServerSession.java | 69 +- .../McpStreamableServerTransportProvider.java | 2 +- .../spec/McpTransportContext.java | 12 + ...n.java => MissingMcpTransportSession.java} | 4 +- 18 files changed, 1989 insertions(+), 547 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java delete mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java rename mcp/src/main/java/io/modelcontextprotocol/spec/{DisconnectedMcpSession.java => MissingMcpTransportSession.java} (79%) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index 3dd279ff7..ffeb25e2b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -2,11 +2,13 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.DefaultMcpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerSession; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,6 +27,7 @@ import java.io.IOException; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider { @@ -48,6 +51,9 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + // TODO: add means to specify this + private Function contextExtractor = req -> new DefaultMcpTransportContext(); + /** * Flag indicating if the transport is shutting down. */ @@ -183,6 +189,8 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + McpTransportContext transportContext = this.contextExtractor.apply(request); + return Mono.defer(() -> { if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { return ServerResponse.badRequest().build(); // TODO: say we need a session id @@ -204,11 +212,11 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport(sink); - McpStreamableServerSession.McpStreamableServerSessionStream genericStream = session.newStream(sessionTransport); - sink.onDispose(genericStream::close); + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session.listeningStream(sessionTransport); + sink.onDispose(listeningStream::close); }), ServerSentEvent.class); - }); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } /** @@ -231,6 +239,8 @@ private Mono handlePost(ServerRequest request) { return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); } + McpTransportContext transportContext = this.contextExtractor.apply(request); + return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); @@ -261,7 +271,7 @@ private Mono handlePost(ServerRequest request) { return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM) .body(Flux.>create(sink -> { WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); - Mono stream = session.handleStream(jsonrpcRequest, st); + Mono stream = session.responseStream(jsonrpcRequest, st); Disposable streamSubscription = stream .doOnError(err -> sink.error(err)) .contextWrite(sink.contextView()) @@ -276,7 +286,7 @@ private Mono handlePost(ServerRequest request) { logger.error("Failed to deserialize message: {}", e.getMessage()); return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); } - }); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } private class WebFluxStreamableMcpSessionTransport implements McpServerTransport { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 7efa02591..b4efad266 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -16,7 +16,7 @@ import java.util.function.BiFunction; import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; -import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProviderBase; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -89,7 +89,7 @@ public class McpAsyncServer { private static final Logger logger = LoggerFactory.getLogger(McpAsyncServer.class); - private final McpServerTransportProvider mcpTransportProvider; + private final McpServerTransportProviderBase mcpTransportProvider; private final ObjectMapper objectMapper; @@ -126,7 +126,7 @@ public class McpAsyncServer { * @param features The MCP server supported features. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ - McpAsyncServer(McpSingleSessionServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, + McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, McpServerFeatures.Async features, Duration requestTimeout, McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index bf36ebb80..0d2a0a37e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -8,11 +8,13 @@ import java.util.Collections; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.spec.DefaultMcpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSession; +import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; @@ -31,6 +33,8 @@ public class McpAsyncServerExchange { private final McpSchema.Implementation clientInfo; + private final McpTransportContext transportContext; + private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { @@ -57,6 +61,23 @@ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities c this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; + this.transportContext = new DefaultMcpTransportContext(); + } + + /** + * Create a new asynchronous exchange with the client. + * @param session The server session representing a 1-1 interaction. + * @param clientCapabilities The client capabilities that define the supported + * features and functionality. + * @param transportContext context associated with the client as extracted from the transport + * @param clientInfo The client implementation information. + */ + public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, McpTransportContext transportContext) { + this.session = session; + this.clientCapabilities = clientCapabilities; + this.clientInfo = clientInfo; + this.transportContext = transportContext; } /** @@ -75,6 +96,10 @@ public McpSchema.Implementation getClientInfo() { return this.clientInfo; } + public McpTransportContext transportContext() { + return this.transportContext; + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index 05734c272..af2b8d92e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -22,6 +22,8 @@ import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; @@ -132,6 +134,8 @@ */ public interface McpServer { + McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", "1.0.0"); + /** * Starts building a synchronous MCP server that provides blocking operations. * Synchronous servers block the current Thread's execution upon each request before @@ -140,8 +144,8 @@ public interface McpServer { * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link SyncSpecification} for configuring the server. */ - static SyncSpecification sync(McpServerTransportProvider transportProvider) { - return new SyncSpecification(transportProvider); + static SingleSessionSyncSpecification sync(McpServerTransportProvider transportProvider) { + return new SingleSessionSyncSpecification(transportProvider); } /** @@ -152,47 +156,107 @@ static SyncSpecification sync(McpServerTransportProvider transportProvider) { * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - static AsyncSpecification async(McpServerTransportProvider transportProvider) { - return new AsyncSpecification(transportProvider); + static SingleSessionAsyncSpecification async(McpServerTransportProvider transportProvider) { + return new SingleSessionAsyncSpecification(transportProvider); + } + + /** + * Starts building a synchronous MCP server that provides blocking operations. + * Synchronous servers block the current Thread's execution upon each request before + * giving the control back to the caller, making them simpler to implement but + * potentially less scalable for concurrent operations. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link SyncSpecification} for configuring the server. + */ + static StreamableSyncSpecification sync(McpStreamableServerTransportProvider transportProvider) { + return new StreamableSyncSpecification(transportProvider); } - static StatelessAsyncSpecification async(McpStatelessServerTransport transportProvider) { - // TODO + /** + * Starts building an asynchronous MCP server that provides non-blocking operations. + * Asynchronous servers can handle multiple requests concurrently on a single Thread + * using a functional paradigm with non-blocking server transports, making them more + * scalable for high-concurrency scenarios but more complex to implement. + * @param transportProvider The transport layer implementation for MCP communication. + * @return A new instance of {@link AsyncSpecification} for configuring the server. + */ + static StreamableServerAsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { + return new StreamableServerAsyncSpecification(transportProvider); } - static StatelessSyncSpecification sync(McpStatelessServerTransport transportProvider) { - // TODO + static StatelessAsyncSpecification async(McpStatelessServerTransport transport) { + return new StatelessAsyncSpecification(transport); } - class StatelessAsyncSpecification { - // TODO + static StatelessSyncSpecification sync(McpStatelessServerTransport transport) { + return new StatelessSyncSpecification(transport); } - class StatelessSyncSpecification { - // TODO + class SingleSessionAsyncSpecification extends AsyncSpecification { + private final McpServerTransportProvider transportProvider; + + private SingleSessionAsyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } + } + + class StreamableServerAsyncSpecification extends AsyncSpecification { + private final McpStreamableServerTransportProvider transportProvider; + + public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider transportProvider) { + this.transportProvider = transportProvider; + } + + /** + * Builds an asynchronous MCP server that provides non-blocking operations. + * @return A new instance of {@link McpAsyncServer} configured with this builder's + * settings. + */ + public McpAsyncServer build() { + var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, + this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + } } /** * Asynchronous server specification. */ - class AsyncSpecification { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); + class AsyncSpecification> { - private final McpServerTransportProvider transportProvider; - - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - private ObjectMapper objectMapper; + ObjectMapper objectMapper; - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - private McpSchema.ServerCapabilities serverCapabilities; + McpSchema.ServerCapabilities serverCapabilities; - private JsonSchemaValidator jsonSchemaValidator; + JsonSchemaValidator jsonSchemaValidator; - private String instructions; + String instructions; /** * The Model Context Protocol (MCP) allows servers to expose tools that can be @@ -201,7 +265,7 @@ class AsyncSpecification { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -210,9 +274,9 @@ class AsyncSpecification { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + final Map resources = new HashMap<>(); - private final List resourceTemplates = new ArrayList<>(); + final List resourceTemplates = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -221,17 +285,17 @@ class AsyncSpecification { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + final Map prompts = new HashMap<>(); - private final Map completions = new HashMap<>(); + final Map completions = new HashMap<>(); - private final List, Mono>> rootsChangeHandlers = new ArrayList<>(); + final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout - private AsyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; + @SuppressWarnings("unchecked") + S self() { + return (S) this; } /** @@ -241,10 +305,10 @@ private AsyncSpecification(McpServerTransportProvider transportProvider) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public S uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; - return this; + return self(); } /** @@ -256,10 +320,10 @@ public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory * @return This builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public AsyncSpecification requestTimeout(Duration requestTimeout) { + public S requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; - return this; + return self(); } /** @@ -271,10 +335,10 @@ public AsyncSpecification requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + public S serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; - return this; + return self(); } /** @@ -287,11 +351,11 @@ public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public AsyncSpecification serverInfo(String name, String version) { + public S serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); - return this; + return self(); } /** @@ -301,9 +365,9 @@ public AsyncSpecification serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public AsyncSpecification instructions(String instructions) { + public S instructions(String instructions) { this.instructions = instructions; - return this; + return self(); } /** @@ -320,10 +384,10 @@ public AsyncSpecification instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public S capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; - return this; + return self(); } /** @@ -351,15 +415,15 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * calls that require a request object. */ @Deprecated - public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { + public S tool(McpSchema.Tool tool, + BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); assertNoDuplicateTool(tool.name()); this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); - return this; + return self(); } /** @@ -375,8 +439,8 @@ public AsyncSpecification tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public AsyncSpecification toolCall(McpSchema.Tool tool, - BiFunction> callHandler) { + public S toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); @@ -385,7 +449,7 @@ public AsyncSpecification toolCall(McpSchema.Tool tool, this.tools .add(McpServerFeatures.AsyncToolSpecification.builder().tool(tool).callHandler(callHandler).build()); - return this; + return self(); } /** @@ -398,7 +462,7 @@ public AsyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.AsyncToolSpecification...) */ - public AsyncSpecification tools(List toolSpecifications) { + public S tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -406,7 +470,7 @@ public AsyncSpecification tools(List t this.tools.add(tool); } - return this; + return self(); } /** @@ -425,14 +489,14 @@ public AsyncSpecification tools(List t * @return This builder instance for method chaining * @throws IllegalArgumentException if toolSpecifications is null */ - public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + public S tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } - return this; + return self(); } private void assertNoDuplicateTool(String toolName) { @@ -451,11 +515,11 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpecification resources( + public S resources( Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); - return this; + return self(); } /** @@ -467,12 +531,12 @@ public AsyncSpecification resources( * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public AsyncSpecification resources(List resourceSpecifications) { + public S resources(List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } - return this; + return self(); } /** @@ -492,12 +556,12 @@ public AsyncSpecification resources(List resourceTemplates) { + public S resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); - return this; + return self(); } /** @@ -531,12 +595,12 @@ public AsyncSpecification resourceTemplates(List resourceTempl * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public S resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } - return this; + return self(); } /** @@ -556,10 +620,10 @@ public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplate * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public AsyncSpecification prompts(Map prompts) { + public S prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); - return this; + return self(); } /** @@ -570,12 +634,12 @@ public AsyncSpecification prompts(Map prompts) { + public S prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } - return this; + return self(); } /** @@ -594,12 +658,12 @@ public AsyncSpecification prompts(List completions) { + public S completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); } - return this; + return self(); } /** @@ -624,12 +688,12 @@ public AsyncSpecification completions(List, Mono> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); - return this; + return self(); } /** @@ -658,11 +722,11 @@ public AsyncSpecification rootsChangeHandler( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandler(BiFunction) */ - public AsyncSpecification rootsChangeHandlers( + public S rootsChangeHandlers( List, Mono>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); - return this; + return self(); } /** @@ -674,7 +738,7 @@ public AsyncSpecification rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public AsyncSpecification rootsChangeHandlers( + public S rootsChangeHandlers( @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(Arrays.asList(handlers)); @@ -686,10 +750,10 @@ public AsyncSpecification rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public AsyncSpecification objectMapper(ObjectMapper objectMapper) { + public S objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; - return this; + return self(); } /** @@ -700,49 +764,90 @@ public AsyncSpecification objectMapper(ObjectMapper objectMapper) { * @return This builder instance for method chaining * @throws IllegalArgumentException if jsonSchemaValidator is null */ - public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public S jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; - return this; + return self(); + } + + } + + class SingleSessionSyncSpecification extends SyncSpecification { + private final McpServerTransportProvider transportProvider; + + private SingleSessionSyncSpecification(McpServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; } + /** - * Builds an asynchronous MCP server that provides non-blocking operations. - * @return A new instance of {@link McpAsyncServer} configured with this builder's + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's * settings. */ - public McpAsyncServer build() { - var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, - this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, - this.instructions); + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); - return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); + + return new McpSyncServer(asyncServer, this.immediateExecution); } + } + + class StreamableSyncSpecification extends SyncSpecification { + private final McpStreamableServerTransportProvider transportProvider; + + private StreamableSyncSpecification(McpStreamableServerTransportProvider transportProvider) { + Assert.notNull(transportProvider, "Transport provider must not be null"); + this.transportProvider = transportProvider; + } + + + /** + * Builds a synchronous MCP server that provides blocking operations. + * @return A new instance of {@link McpSyncServer} configured with this builder's + * settings. + */ + public McpSyncServer build() { + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpSyncServer(asyncServer, this.immediateExecution); + } } /** * Synchronous server specification. */ - class SyncSpecification { - - private static final McpSchema.Implementation DEFAULT_SERVER_INFO = new McpSchema.Implementation("mcp-server", - "1.0.0"); + class SyncSpecification> { - private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - private final McpServerTransportProvider transportProvider; - - private ObjectMapper objectMapper; + ObjectMapper objectMapper; - private McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; - private McpSchema.ServerCapabilities serverCapabilities; + McpSchema.ServerCapabilities serverCapabilities; - private String instructions; + String instructions; /** * The Model Context Protocol (MCP) allows servers to expose tools that can be @@ -751,7 +856,7 @@ class SyncSpecification { * Each tool is uniquely identified by a name and includes metadata describing its * schema. */ - private final List tools = new ArrayList<>(); + final List tools = new ArrayList<>(); /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -760,11 +865,11 @@ class SyncSpecification { * application-specific information. Each resource is uniquely identified by a * URI. */ - private final Map resources = new HashMap<>(); + final Map resources = new HashMap<>(); - private final List resourceTemplates = new ArrayList<>(); + final List resourceTemplates = new ArrayList<>(); - private JsonSchemaValidator jsonSchemaValidator; + JsonSchemaValidator jsonSchemaValidator; /** * The Model Context Protocol (MCP) provides a standardized way for servers to @@ -773,19 +878,19 @@ class SyncSpecification { * discover available prompts, retrieve their contents, and provide arguments to * customize them. */ - private final Map prompts = new HashMap<>(); + final Map prompts = new HashMap<>(); - private final Map completions = new HashMap<>(); + final Map completions = new HashMap<>(); - private final List>> rootsChangeHandlers = new ArrayList<>(); + final List>> rootsChangeHandlers = new ArrayList<>(); - private Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout - private boolean immediateExecution = false; + boolean immediateExecution = false; - private SyncSpecification(McpServerTransportProvider transportProvider) { - Assert.notNull(transportProvider, "Transport provider must not be null"); - this.transportProvider = transportProvider; + @SuppressWarnings("unchecked") + protected S self() { + return (S) this; } /** @@ -795,10 +900,10 @@ private SyncSpecification(McpServerTransportProvider transportProvider) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public S uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; - return this; + return self(); } /** @@ -807,13 +912,13 @@ public SyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory * resource access, and prompt operations. * @param requestTimeout The duration to wait before timing out requests. Must not * be null. - * @return This builder instance for method chaining + * @return self() builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public SyncSpecification requestTimeout(Duration requestTimeout) { + public S requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; - return this; + return self(); } /** @@ -825,10 +930,10 @@ public SyncSpecification requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + public S serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; - return this; + return self(); } /** @@ -841,11 +946,11 @@ public SyncSpecification serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public SyncSpecification serverInfo(String name, String version) { + public S serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); - return this; + return self(); } /** @@ -855,9 +960,9 @@ public SyncSpecification serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public SyncSpecification instructions(String instructions) { + public S instructions(String instructions) { this.instructions = instructions; - return this; + return self(); } /** @@ -874,10 +979,10 @@ public SyncSpecification instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public S capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; - return this; + return self(); } /** @@ -904,7 +1009,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * calls that require a request object. */ @Deprecated - public SyncSpecification tool(McpSchema.Tool tool, + public S tool(McpSchema.Tool tool, BiFunction, McpSchema.CallToolResult> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -912,7 +1017,7 @@ public SyncSpecification tool(McpSchema.Tool tool, this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, handler)); - return this; + return self(); } /** @@ -928,7 +1033,7 @@ public SyncSpecification tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public SyncSpecification toolCall(McpSchema.Tool tool, + public S toolCall(McpSchema.Tool tool, BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -936,7 +1041,7 @@ public SyncSpecification toolCall(McpSchema.Tool tool, this.tools.add(new McpServerFeatures.SyncToolSpecification(tool, null, handler)); - return this; + return self(); } /** @@ -949,7 +1054,7 @@ public SyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.SyncToolSpecification...) */ - public SyncSpecification tools(List toolSpecifications) { + public S tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -958,7 +1063,7 @@ public SyncSpecification tools(List too this.tools.add(tool); } - return this; + return self(); } /** @@ -978,14 +1083,14 @@ public SyncSpecification tools(List too * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(List) */ - public SyncSpecification tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { + public S tools(McpServerFeatures.SyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.SyncToolSpecification tool : toolSpecifications) { assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } - return this; + return self(); } private void assertNoDuplicateTool(String toolName) { @@ -1004,11 +1109,11 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpecification resources( + public S resources( Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); - return this; + return self(); } /** @@ -1020,12 +1125,12 @@ public SyncSpecification resources( * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public SyncSpecification resources(List resourceSpecifications) { + public S resources(List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.SyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } - return this; + return self(); } /** @@ -1045,12 +1150,12 @@ public SyncSpecification resources(List resourceTemplates) { + public S resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); - return this; + return self(); } /** @@ -1084,12 +1189,12 @@ public SyncSpecification resourceTemplates(List resourceTempla * @throws IllegalArgumentException if resourceTemplates is null * @see #resourceTemplates(List) */ - public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + public S resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } - return this; + return self(); } /** @@ -1110,10 +1215,10 @@ public SyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpecification prompts(Map prompts) { + public S prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); - return this; + return self(); } /** @@ -1124,12 +1229,12 @@ public SyncSpecification prompts(Map prompts) { + public S prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } - return this; + return self(); } /** @@ -1148,12 +1253,12 @@ public SyncSpecification prompts(List * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... prompts) { + public S prompts(McpServerFeatures.SyncPromptSpecification... prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.SyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } - return this; + return self(); } /** @@ -1164,12 +1269,12 @@ public SyncSpecification prompts(McpServerFeatures.SyncPromptSpecification... pr * @throws IllegalArgumentException if completions is null * @see #completions(McpServerFeatures.SyncCompletionSpecification...) */ - public SyncSpecification completions(List completions) { + public S completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.SyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); } - return this; + return self(); } /** @@ -1179,12 +1284,12 @@ public SyncSpecification completions(List> handler) { + public S rootsChangeHandler(BiConsumer> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); - return this; + return self(); } /** @@ -1212,11 +1317,11 @@ public SyncSpecification rootsChangeHandler(BiConsumer>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); - return this; + return self(); } /** @@ -1228,7 +1333,7 @@ public SyncSpecification rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public SyncSpecification rootsChangeHandlers( + public S rootsChangeHandlers( BiConsumer>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(List.of(handlers)); @@ -1240,16 +1345,16 @@ public SyncSpecification rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public SyncSpecification objectMapper(ObjectMapper objectMapper) { + public S objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; - return this; + return self(); } - public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public S jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; - return this; + return self(); } /** @@ -1263,32 +1368,943 @@ public SyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValid * @return This builder instance for method chaining. * */ - public SyncSpecification immediateExecution(boolean immediateExecution) { + public S immediateExecution(boolean immediateExecution) { this.immediateExecution = immediateExecution; - return this; + return self(); } + } + + class StatelessAsyncSpecification { + + private final McpStatelessServerTransport transport; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + ObjectMapper objectMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; /** - * Builds a synchronous MCP server that provides blocking operations. - * @return A new instance of {@link McpSyncServer} configured with this builder's - * settings. + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. */ - public McpSyncServer build() { - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, - this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); + final List tools = new ArrayList<>(); - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); - return new McpSyncServer(asyncServer, this.immediateExecution); + final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessAsyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessAsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessAsyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; } + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessAsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessAsyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *

    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessAsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.AsyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpAsyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, + BiFunction> callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.AsyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.AsyncToolSpecification...) + */ + public StatelessAsyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.AsyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessAsyncSpecification tools(McpStatelessServerFeatures.AsyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.AsyncResourceSpecification...) + */ + public StatelessAsyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.AsyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.AsyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessAsyncSpecification resources(McpStatelessServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

+ * Example usage:

{@code
+		 * .resourceTemplates(
+		 *     new ResourceTemplate("file://{path}", "Access files by path"),
+		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
+		 * )
+		 * }
+ * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public StatelessAsyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessAsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.AsyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.AsyncPromptSpecification...) + */ + public StatelessAsyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.AsyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.AsyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessAsyncSpecification prompts(McpStatelessServerFeatures.AsyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessAsyncSpecification completions(McpStatelessServerFeatures.AsyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public StatelessAsyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessAsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + public McpStatelessAsyncServer build() { + var features = new McpStatelessServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + return new McpStatelessAsyncServer(this.transport, mapper, features, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); + } + } + + class StatelessSyncSpecification { + + private final McpStatelessServerTransport transport; + + boolean immediateExecution = false; + + McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); + + ObjectMapper objectMapper; + + McpSchema.Implementation serverInfo = DEFAULT_SERVER_INFO; + + McpSchema.ServerCapabilities serverCapabilities; + + JsonSchemaValidator jsonSchemaValidator; + + String instructions; + + /** + * The Model Context Protocol (MCP) allows servers to expose tools that can be + * invoked by language models. Tools enable models to interact with external + * systems, such as querying databases, calling APIs, or performing computations. + * Each tool is uniquely identified by a name and includes metadata describing its + * schema. + */ + final List tools = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose resources to clients. Resources allow servers to share data that + * provides context to language models, such as files, database schemas, or + * application-specific information. Each resource is uniquely identified by a + * URI. + */ + final Map resources = new HashMap<>(); + + final List resourceTemplates = new ArrayList<>(); + + /** + * The Model Context Protocol (MCP) provides a standardized way for servers to + * expose prompt templates to clients. Prompts allow servers to provide structured + * messages and instructions for interacting with language models. Clients can + * discover available prompts, retrieve their contents, and provide arguments to + * customize them. + */ + final Map prompts = new HashMap<>(); + + final Map completions = new HashMap<>(); + + Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + + public StatelessSyncSpecification(McpStatelessServerTransport transport) { + this.transport = transport; + } + + /** + * Sets the URI template manager factory to use for creating URI templates. This + * allows for custom URI template parsing and variable extraction. + * @param uriTemplateManagerFactory The factory to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if uriTemplateManagerFactory is null + */ + public StatelessSyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); + this.uriTemplateManagerFactory = uriTemplateManagerFactory; + return this; + } + + /** + * Sets the duration to wait for server responses before timing out requests. This + * timeout applies to all requests made through the client, including tool calls, + * resource access, and prompt operations. + * @param requestTimeout The duration to wait before timing out requests. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if requestTimeout is null + */ + public StatelessSyncSpecification requestTimeout(Duration requestTimeout) { + Assert.notNull(requestTimeout, "Request timeout must not be null"); + this.requestTimeout = requestTimeout; + return this; + } + + /** + * Sets the server implementation information that will be shared with clients + * during connection initialization. This helps with version compatibility, + * debugging, and server identification. + * @param serverInfo The server implementation details including name and version. + * Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverInfo is null + */ + public StatelessSyncSpecification serverInfo(McpSchema.Implementation serverInfo) { + Assert.notNull(serverInfo, "Server info must not be null"); + this.serverInfo = serverInfo; + return this; + } + + /** + * Sets the server implementation information using name and version strings. This + * is a convenience method alternative to + * {@link #serverInfo(McpSchema.Implementation)}. + * @param name The server name. Must not be null or empty. + * @param version The server version. Must not be null or empty. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if name or version is null or empty + * @see #serverInfo(McpSchema.Implementation) + */ + public StatelessSyncSpecification serverInfo(String name, String version) { + Assert.hasText(name, "Name must not be null or empty"); + Assert.hasText(version, "Version must not be null or empty"); + this.serverInfo = new McpSchema.Implementation(name, version); + return this; + } + + /** + * Sets the server instructions that will be shared with clients during connection + * initialization. These instructions provide guidance to the client on how to + * interact with this server. + * @param instructions The instructions text. Can be null or empty. + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification instructions(String instructions) { + this.instructions = instructions; + return this; + } + + /** + * Sets the server capabilities that will be advertised to clients during + * connection initialization. Capabilities define what features the server + * supports, such as: + *
    + *
  • Tool execution + *
  • Resource access + *
  • Prompt handling + *
+ * @param serverCapabilities The server capabilities configuration. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if serverCapabilities is null + */ + public StatelessSyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { + Assert.notNull(serverCapabilities, "Server capabilities must not be null"); + this.serverCapabilities = serverCapabilities; + return this; + } + + /** + * Adds a single tool with its implementation handler to the server. This is a + * convenience method for registering individual tools without creating a + * {@link McpServerFeatures.SyncToolSpecification} explicitly. + * @param tool The tool definition including name, description, and schema. Must + * not be null. + * @param callHandler The function that implements the tool's logic. Must not be + * null. The function's first argument is an {@link McpSyncServerExchange} upon + * which the server can interact with the connected client. The second argument is + * the {@link McpSchema.CallToolRequest} object containing the tool call + * @return This builder instance for method chaining + * @throws IllegalArgumentException if tool or handler is null + */ + public StatelessSyncSpecification toolCall(McpSchema.Tool tool, + BiFunction callHandler) { + + Assert.notNull(tool, "Tool must not be null"); + Assert.notNull(callHandler, "Handler must not be null"); + assertNoDuplicateTool(tool.name()); + + this.tools.add(new McpStatelessServerFeatures.SyncToolSpecification(tool, callHandler)); + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using a List. This method + * is useful when tools are dynamically generated or loaded from a configuration + * source. + * @param toolSpecifications The list of tool specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + * @see #tools(McpStatelessServerFeatures.SyncToolSpecification...) + */ + public StatelessSyncSpecification tools(List toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + + return this; + } + + /** + * Adds multiple tools with their handlers to the server using varargs. This + * method provides a convenient way to register multiple tools inline. + * + *

+ * Example usage:

{@code
+		 * .tools(
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(calculatorTool).callTool(calculatorHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(weatherTool).callTool(weatherHandler).build(),
+		 *     McpServerFeatures.SyncToolSpecification.builder().tool(fileManagerTool).callTool(fileManagerHandler).build()
+		 * )
+		 * }
+ * @param toolSpecifications The tool specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if toolSpecifications is null + */ + public StatelessSyncSpecification tools(McpStatelessServerFeatures.SyncToolSpecification... toolSpecifications) { + Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); + + for (var tool : toolSpecifications) { + assertNoDuplicateTool(tool.tool().name()); + this.tools.add(tool); + } + return this; + } + + private void assertNoDuplicateTool(String toolName) { + if (this.tools.stream().anyMatch(toolSpec -> toolSpec.tool().name().equals(toolName))) { + throw new IllegalArgumentException("Tool with name '" + toolName + "' is already registered."); + } + } + + /** + * Registers multiple resources with their handlers using a Map. This method is + * useful when resources are dynamically generated or loaded from a configuration + * source. + * @param resourceSpecifications Map of resource name to specification. Must not + * be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources( + Map resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); + this.resources.putAll(resourceSpecifications); + return this; + } + + /** + * Registers multiple resources with their handlers using a List. This method is + * useful when resources need to be added in bulk from a collection. + * @param resourceSpecifications List of resource specifications. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + * @see #resources(McpStatelessServerFeatures.SyncResourceSpecification...) + */ + public StatelessSyncSpecification resources(List resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Registers multiple resources with their handlers using varargs. This method + * provides a convenient way to register multiple resources inline. + * + *

+ * Example usage:

{@code
+		 * .resources(
+		 *     new McpServerFeatures.SyncResourceSpecification(fileResource, fileHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(dbResource, dbHandler),
+		 *     new McpServerFeatures.SyncResourceSpecification(apiResource, apiHandler)
+		 * )
+		 * }
+ * @param resourceSpecifications The resource specifications to add. Must not be + * null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceSpecifications is null + */ + public StatelessSyncSpecification resources(McpStatelessServerFeatures.SyncResourceSpecification... resourceSpecifications) { + Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); + for (var resource : resourceSpecifications) { + this.resources.put(resource.resource().uri(), resource); + } + return this; + } + + /** + * Sets the resource templates that define patterns for dynamic resource access. + * Templates use URI patterns with placeholders that can be filled at runtime. + * + *

+ * Example usage:

{@code
+		 * .resourceTemplates(
+		 *     new ResourceTemplate("file://{path}", "Access files by path"),
+		 *     new ResourceTemplate("db://{table}/{id}", "Access database records")
+		 * )
+		 * }
+ * @param resourceTemplates List of resource templates. If null, clears existing + * templates. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(ResourceTemplate...) + */ + public StatelessSyncSpecification resourceTemplates(List resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + this.resourceTemplates.addAll(resourceTemplates); + return this; + } + + /** + * Sets the resource templates using varargs for convenience. This is an + * alternative to {@link #resourceTemplates(List)}. + * @param resourceTemplates The resource templates to set. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if resourceTemplates is null. + * @see #resourceTemplates(List) + */ + public StatelessSyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { + Assert.notNull(resourceTemplates, "Resource templates must not be null"); + for (ResourceTemplate resourceTemplate : resourceTemplates) { + this.resourceTemplates.add(resourceTemplate); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using a Map. This method is + * useful when prompts are dynamically generated or loaded from a configuration + * source. + * + *

+ * Example usage:

{@code
+		 * .prompts(Map.of("analysis", new McpServerFeatures.SyncPromptSpecification(
+		 *     new Prompt("analysis", "Code analysis template"),
+		 *     request -> Mono.fromSupplier(() -> generateAnalysisPrompt(request))
+		 *         .map(GetPromptResult::new)
+		 * )));
+		 * }
+ * @param prompts Map of prompt name to specification. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts(Map prompts) { + Assert.notNull(prompts, "Prompts map must not be null"); + this.prompts.putAll(prompts); + return this; + } + + /** + * Registers multiple prompts with their handlers using a List. This method is + * useful when prompts need to be added in bulk from a collection. + * @param prompts List of prompt specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + * @see #prompts(McpStatelessServerFeatures.SyncPromptSpecification...) + */ + public StatelessSyncSpecification prompts(List prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple prompts with their handlers using varargs. This method + * provides a convenient way to register multiple prompts inline. + * + *

+ * Example usage:

{@code
+		 * .prompts(
+		 *     new McpServerFeatures.SyncPromptSpecification(analysisPrompt, analysisHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(summaryPrompt, summaryHandler),
+		 *     new McpServerFeatures.SyncPromptSpecification(reviewPrompt, reviewHandler)
+		 * )
+		 * }
+ * @param prompts The prompt specifications to add. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if prompts is null + */ + public StatelessSyncSpecification prompts(McpStatelessServerFeatures.SyncPromptSpecification... prompts) { + Assert.notNull(prompts, "Prompts list must not be null"); + for (var prompt : prompts) { + this.prompts.put(prompt.prompt().name(), prompt); + } + return this; + } + + /** + * Registers multiple completions with their handlers using a List. This method is + * useful when completions need to be added in bulk from a collection. + * @param completions List of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions(List completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Registers multiple completions with their handlers using varargs. This method + * is useful when completions are defined inline and added directly. + * @param completions Array of completion specifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if completions is null + */ + public StatelessSyncSpecification completions(McpStatelessServerFeatures.SyncCompletionSpecification... completions) { + Assert.notNull(completions, "Completions list must not be null"); + for (var completion : completions) { + this.completions.put(completion.referenceKey(), completion); + } + return this; + } + + /** + * Sets the object mapper to use for serializing and deserializing JSON messages. + * @param objectMapper the instance to use. Must not be null. + * @return This builder instance for method chaining. + * @throws IllegalArgumentException if objectMapper is null + */ + public StatelessSyncSpecification objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the JSON schema validator to use for validating tool and resource schemas. + * This ensures that the server's tools and resources conform to the expected + * schema definitions. + * @param jsonSchemaValidator The validator to use. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if jsonSchemaValidator is null + */ + public StatelessSyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); + this.jsonSchemaValidator = jsonSchemaValidator; + return this; + } + + /** + * Enable on "immediate execution" of the operations on the underlying + * {@link McpStatelessAsyncServer}. Defaults to false, which does blocking code offloading + * to prevent accidental blocking of the non-blocking transport. + *

+ * Do NOT set to true if the underlying transport is a non-blocking + * implementation. + * @param immediateExecution When true, do not offload work asynchronously. + * @return This builder instance for method chaining. + * + */ + public StatelessSyncSpecification immediateExecution(boolean immediateExecution) { + this.immediateExecution = immediateExecution; + return this; + } + + public McpStatelessSyncServer build() { + /* + McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, + this.rootsChangeHandlers, this.instructions); + McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, + this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + + var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); + + return new McpSyncServer(asyncServer, this.immediateExecution); + */ + var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); + var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator + : new DefaultJsonSchemaValidator(mapper); + var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); + return new McpStatelessSyncServer(asyncServer, this.immediateExecution); + } } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index 94ef1a6df..fdeb517f2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -6,19 +6,16 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpClientSession; +import io.modelcontextprotocol.spec.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; -import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; -import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; -import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -33,49 +30,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.function.BiFunction; -import java.util.function.Function; /** - * The Model Context Protocol (MCP) server implementation that provides asynchronous - * communication using Project Reactor's Mono and Flux types. - * - *

- * This server implements the MCP specification, enabling AI models to expose tools, - * resources, and prompts through a standardized interface. Key features include: - *

    - *
  • Asynchronous communication using reactive programming patterns - *
  • Dynamic tool registration and management - *
  • Resource handling with URI-based addressing - *
  • Prompt template management - *
  • Real-time client notifications for state changes - *
  • Structured logging with configurable severity levels - *
  • Support for client-side AI model sampling - *
- * - *

- * The server follows a lifecycle: - *

    - *
  1. Initialization - Accepts client connections and negotiates capabilities - *
  2. Normal Operation - Handles client requests and sends notifications - *
  3. Graceful Shutdown - Ensures clean connection termination - *
- * - *

- * This implementation uses Project Reactor for non-blocking operations, making it - * suitable for high-throughput scenarios and reactive applications. All operations return - * Mono or Flux types that can be composed into reactive pipelines. - * - *

- * The server supports runtime modification of its capabilities through methods like - * {@link #addTool}, {@link #addResource}, and {@link #addPrompt}, automatically notifying - * connected clients of changes when configured to do so. - * - * @author Christian Tzolov * @author Dariusz Jędrzejczyk - * @author Jihoon Kim - * @see McpServer - * @see McpSchema - * @see McpClientSession */ public class McpStatelessAsyncServer { @@ -91,36 +48,27 @@ public class McpStatelessAsyncServer { private final String instructions; - // TODO: all simple ones - private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); + private final CopyOnWriteArrayList tools = new CopyOnWriteArrayList<>(); private final CopyOnWriteArrayList resourceTemplates = new CopyOnWriteArrayList<>(); - private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); + private final ConcurrentHashMap resources = new ConcurrentHashMap<>(); - private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); - // FIXME: this field is deprecated and should be remvoed together with the - // broadcasting loggingNotification. - private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; - - private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); + private final ConcurrentHashMap completions = new ConcurrentHashMap<>(); private List protocolVersions = List.of(McpSchema.LATEST_PROTOCOL_VERSION); private McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); - /** - * Create a new McpAsyncServer with the given transport provider and capabilities. - * @param mcpTransportProvider The transport layer implementation for MCP - * communication. - * @param features The MCP server supported features. - * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization - */ - McpStatelessAsyncServer(McpStatelessServerTransport mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory) { - this.mcpTransportProvider = mcpTransportProvider; + private final JsonSchemaValidator jsonSchemaValidator; + + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, ObjectMapper objectMapper, + McpStatelessServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, + JsonSchemaValidator jsonSchemaValidator) { + this.mcpTransportProvider = mcpTransport; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); this.serverCapabilities = features.serverCapabilities(); @@ -131,13 +79,16 @@ public class McpStatelessAsyncServer { this.prompts.putAll(features.prompts()); this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; + this.jsonSchemaValidator = jsonSchemaValidator; Map> requestHandlers = new HashMap<>(); // Initialize request handlers for standard MCP methods // Ping MUST respond with an empty data, but not NULL response. - requestHandlers.put(McpSchema.METHOD_PING, params -> Mono.just(Map.of())); + requestHandlers.put(McpSchema.METHOD_PING, (ctx, params) -> Mono.just(Map.of())); + + requestHandlers.put(McpSchema.METHOD_INITIALIZE, asyncInitializeRequestHandler()); // Add tools API handlers if the tool capability is enabled if (this.serverCapabilities.tools() != null) { @@ -158,46 +109,26 @@ public class McpStatelessAsyncServer { requestHandlers.put(McpSchema.METHOD_PROMPT_GET, promptsGetRequestHandler()); } - // Add logging API handlers if the logging capability is enabled - if (this.serverCapabilities.logging() != null) { - requestHandlers.put(McpSchema.METHOD_LOGGING_SET_LEVEL, setLoggerRequestHandler()); - } - // Add completion API handlers if the completion capability is enabled if (this.serverCapabilities.completions() != null) { requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); } - Map notificationHandlers = new HashMap<>(); - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_INITIALIZED, (exchange, params) -> Mono.empty()); - - List, Mono>> rootsChangeConsumers = features - .rootsChangeConsumers(); - - if (Utils.isEmpty(rootsChangeConsumers)) { - rootsChangeConsumers = List.of((exchange, roots) -> Mono.fromRunnable(() -> logger - .warn("Roots list changed notification, but no consumers provided. Roots list changed: {}", roots))); - } - - notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, - asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - - mcpTransportProvider.setHandler(request -> - requestHandlers.get(request.method()).apply(request.params()) + mcpTransport.setRequestHandler((context, request) -> + requestHandlers.get(request.method()).apply(context, request.params()) .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) .onErrorResume(t -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, t.getMessage(), null))) )); - ); } // --------------------------------------- // Lifecycle Management // --------------------------------------- - private Mono asyncInitializeRequestHandler( - McpSchema.InitializeRequest initializeRequest) { - return Mono.defer(() -> { + private RequestHandler asyncInitializeRequestHandler() { + return (ctx, req) -> Mono.defer(() -> { + McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, McpSchema.InitializeRequest.class); + logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", initializeRequest.protocolVersion(), initializeRequest.capabilities(), initializeRequest.clientInfo()); @@ -255,18 +186,6 @@ public void close() { this.mcpTransportProvider.close(); } - private McpNotificationHandler asyncRootsListChangedNotificationHandler( - List, Mono>> rootsChangeConsumers) { - return (exchange, params) -> exchange.listRoots() - .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) - .onErrorResume(error -> { - logger.error("Error handling roots list change notification", error); - return Mono.empty(); - }) - .then()); - } - // --------------------------------------- // Tool Management // --------------------------------------- @@ -276,14 +195,14 @@ private McpNotificationHandler asyncRootsListChangedNotificationHandler( * @param toolSpecification The tool specification to add * @return Mono that completes when clients have been notified of the change */ - public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecification) { + public Mono addTool(McpStatelessServerFeatures.AsyncToolSpecification toolSpecification) { if (toolSpecification == null) { return Mono.error(new McpError("Tool specification must not be null")); } if (toolSpecification.tool() == null) { return Mono.error(new McpError("Tool must not be null")); } - if (toolSpecification.call() == null) { + if (toolSpecification.callHandler() == null) { return Mono.error(new McpError("Tool call handler must not be null")); } if (this.serverCapabilities.tools() == null) { @@ -300,9 +219,6 @@ public Mono addTool(McpServerFeatures.AsyncToolSpecification toolSpecifica this.tools.add(toolSpecification); logger.debug("Added tool handler: {}", toolSpecification.tool().name()); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } return Mono.empty(); }); } @@ -325,38 +241,26 @@ public Mono removeTool(String toolName) { .removeIf(toolSpecification -> toolSpecification.tool().name().equals(toolName)); if (removed) { logger.debug("Removed tool handler: {}", toolName); - if (this.serverCapabilities.tools().listChanged()) { - return notifyToolsListChanged(); - } return Mono.empty(); } return Mono.error(new McpError("Tool with name '" + toolName + "' not found")); }); } - /** - * Notifies clients that the list of available tools has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyToolsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_TOOLS_LIST_CHANGED, null); - } - - private Function> toolsListRequestHandler() { - return params -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); - + private RequestHandler toolsListRequestHandler() { + return (ctx, params) -> { + List tools = this.tools.stream().map(McpStatelessServerFeatures.AsyncToolSpecification::tool).toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; } - private McpRequestHandler toolsCallRequestHandler() { - return (exchange, params) -> { + private RequestHandler toolsCallRequestHandler() { + return (ctx, params) -> { McpSchema.CallToolRequest callToolRequest = objectMapper.convertValue(params, new TypeReference() { }); - Optional toolSpecification = this.tools.stream() + Optional toolSpecification = this.tools.stream() .filter(tr -> callToolRequest.name().equals(tr.tool().name())) .findAny(); @@ -364,7 +268,7 @@ private McpRequestHandler toolsCallRequestHandler() { return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + return toolSpecification.map(tool -> tool.callHandler().apply(ctx, callToolRequest)) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -378,7 +282,7 @@ private McpRequestHandler toolsCallRequestHandler() { * @param resourceSpecification The resource handler to add * @return Mono that completes when clients have been notified of the change */ - public Mono addResource(McpServerFeatures.AsyncResourceSpecification resourceSpecification) { + public Mono addResource(McpStatelessServerFeatures.AsyncResourceSpecification resourceSpecification) { if (resourceSpecification == null || resourceSpecification.resource() == null) { return Mono.error(new McpError("Resource must not be null")); } @@ -393,9 +297,6 @@ public Mono addResource(McpServerFeatures.AsyncResourceSpecification resou "Resource with URI '" + resourceSpecification.resource().uri() + "' already exists")); } logger.debug("Added resource handler: {}", resourceSpecification.resource().uri()); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } return Mono.empty(); }); } @@ -414,47 +315,27 @@ public Mono removeResource(String resourceUri) { } return Mono.defer(() -> { - McpServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); + McpStatelessServerFeatures.AsyncResourceSpecification removed = this.resources.remove(resourceUri); if (removed != null) { logger.debug("Removed resource handler: {}", resourceUri); - if (this.serverCapabilities.resources().listChanged()) { - return notifyResourcesListChanged(); - } return Mono.empty(); } return Mono.error(new McpError("Resource with URI '" + resourceUri + "' not found")); }); } - /** - * Notifies clients that the list of available resources has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_LIST_CHANGED, null); - } - - /** - * Notifies clients that the resources have updated. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification resourcesUpdatedNotification) { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_RESOURCES_UPDATED, - resourcesUpdatedNotification); - } - - private McpRequestHandler resourcesListRequestHandler() { - return (exchange, params) -> { + private RequestHandler resourcesListRequestHandler() { + return (ctx, params) -> { var resourceList = this.resources.values() .stream() - .map(McpServerFeatures.AsyncResourceSpecification::resource) + .map(McpStatelessServerFeatures.AsyncResourceSpecification::resource) .toList(); return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); }; } - private McpRequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono + private RequestHandler resourceTemplateListRequestHandler() { + return (ctx, params) -> Mono .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); } @@ -477,14 +358,14 @@ private List getResourceTemplates() { return list; } - private McpRequestHandler resourcesReadRequestHandler() { - return (exchange, params) -> { + private RequestHandler resourcesReadRequestHandler() { + return (ctx, params) -> { McpSchema.ReadResourceRequest resourceRequest = objectMapper.convertValue(params, new TypeReference() { }); var resourceUri = resourceRequest.uri(); - McpServerFeatures.AsyncResourceSpecification specification = this.resources.values() + McpStatelessServerFeatures.AsyncResourceSpecification specification = this.resources.values() .stream() .filter(resourceSpecification -> this.uriTemplateManagerFactory .create(resourceSpecification.resource().uri()) @@ -492,7 +373,7 @@ private McpRequestHandler resourcesReadRequestHand .findFirst() .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - return specification.readHandler().apply(exchange, resourceRequest); + return specification.readHandler().apply(ctx, resourceRequest); }; } @@ -505,7 +386,7 @@ private McpRequestHandler resourcesReadRequestHand * @param promptSpecification The prompt handler to add * @return Mono that completes when clients have been notified of the change */ - public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpecification) { + public Mono addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification promptSpecification) { if (promptSpecification == null) { return Mono.error(new McpError("Prompt specification must not be null")); } @@ -514,7 +395,7 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification specification = this.prompts + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts .putIfAbsent(promptSpecification.prompt().name(), promptSpecification); if (specification != null) { return Mono.error( @@ -523,12 +404,6 @@ public Mono addPrompt(McpServerFeatures.AsyncPromptSpecification promptSpe logger.debug("Added prompt handler: {}", promptSpecification.prompt().name()); - // Servers that declared the listChanged capability SHOULD send a - // notification, - // when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return notifyPromptsListChanged(); - } return Mono.empty(); }); } @@ -547,31 +422,18 @@ public Mono removePrompt(String promptName) { } return Mono.defer(() -> { - McpServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); + McpStatelessServerFeatures.AsyncPromptSpecification removed = this.prompts.remove(promptName); if (removed != null) { logger.debug("Removed prompt handler: {}", promptName); - // Servers that declared the listChanged capability SHOULD send a - // notification, when the list of available prompts changes - if (this.serverCapabilities.prompts().listChanged()) { - return this.notifyPromptsListChanged(); - } return Mono.empty(); } return Mono.error(new McpError("Prompt with name '" + promptName + "' not found")); }); } - /** - * Notifies clients that the list of available prompts has changed. - * @return A Mono that completes when all clients have been notified - */ - public Mono notifyPromptsListChanged() { - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_PROMPTS_LIST_CHANGED, null); - } - - private McpRequestHandler promptsListRequestHandler() { - return (exchange, params) -> { + private RequestHandler promptsListRequestHandler() { + return (ctx, params) -> { // TODO: Implement pagination // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, // new TypeReference() { @@ -579,80 +441,31 @@ private McpRequestHandler promptsListRequestHandler var promptList = this.prompts.values() .stream() - .map(McpServerFeatures.AsyncPromptSpecification::prompt) + .map(McpStatelessServerFeatures.AsyncPromptSpecification::prompt) .toList(); return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); }; } - private McpRequestHandler promptsGetRequestHandler() { - return (exchange, params) -> { + private RequestHandler promptsGetRequestHandler() { + return (ctx, params) -> { McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params, new TypeReference() { }); // Implement prompt retrieval logic here - McpServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); + McpStatelessServerFeatures.AsyncPromptSpecification specification = this.prompts.get(promptRequest.name()); if (specification == null) { return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); } - return specification.promptHandler().apply(exchange, promptRequest); - }; - } - - // --------------------------------------- - // Logging Management - // --------------------------------------- - - /** - * This implementation would, incorrectly, broadcast the logging message to all - * connected clients, using a single minLoggingLevel for all of them. Similar to the - * sampling and roots, the logging level should be set per client session and use the - * ServerExchange to send the logging message to the right client. - * @param loggingMessageNotification The logging message to send - * @return A Mono that completes when the notification has been sent - * @deprecated Use - * {@link McpAsyncServerExchange#loggingNotification(LoggingMessageNotification)} - * instead. - */ - @Deprecated - public Mono loggingNotification(LoggingMessageNotification loggingMessageNotification) { - - if (loggingMessageNotification == null) { - return Mono.error(new McpError("Logging message must not be null")); - } - - if (loggingMessageNotification.level().level() < minLoggingLevel.level()) { - return Mono.empty(); - } - - return this.mcpTransportProvider.notifyClients(McpSchema.METHOD_NOTIFICATION_MESSAGE, - loggingMessageNotification); - } - - private McpRequestHandler setLoggerRequestHandler() { - return (exchange, params) -> { - return Mono.defer(() -> { - - SetLevelRequest newMinLoggingLevel = objectMapper.convertValue(params, - new TypeReference() { - }); - - exchange.setMinLoggingLevel(newMinLoggingLevel.level()); - - // FIXME: this field is deprecated and should be removed together - // with the broadcasting loggingNotification. - this.minLoggingLevel = newMinLoggingLevel.level(); - - return Mono.just(Map.of()); - }); + return specification.promptHandler().apply(ctx, promptRequest); }; } - private McpRequestHandler completionCompleteRequestHandler() { - return (exchange, params) -> { + private RequestHandler completionCompleteRequestHandler() { + return (ctx, params) -> { McpSchema.CompleteRequest request = parseCompletionParams(params); if (request.ref() == null) { @@ -669,23 +482,21 @@ private McpRequestHandler completionCompleteRequestHan // check if the referenced resource exists if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); if (promptSpec == null) { return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } - if (!promptSpec.prompt() + if (promptSpec.prompt() .arguments() .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { + .noneMatch(arg -> arg.name().equals(argumentName))) { return Mono.error(new McpError("Argument not found: " + argumentName)); } } if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); if (resourceSpec == null) { return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } @@ -697,13 +508,13 @@ private McpRequestHandler completionCompleteRequestHan } - McpServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); + McpStatelessServerFeatures.AsyncCompletionSpecification specification = this.completions.get(request.ref()); if (specification == null) { return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); } - return specification.completionHandler().apply(exchange, request); + return specification.completionHandler().apply(ctx, request); }; } @@ -752,7 +563,8 @@ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } - static interface RequestHandler extends Function> { + @FunctionalInterface + interface RequestHandler extends BiFunction> { } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java new file mode 100644 index 000000000..6813c658c --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -0,0 +1,377 @@ +/* + * Copyright 2024-2025 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpTransportContext; +import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiFunction; + +/** + * MCP stateless server features specification that a particular server can choose to support. + * + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessServerFeatures { + + /** + * Asynchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : List.of(); + this.resources = (resources != null) ? resources : Map.of(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : List.of(); + this.prompts = (prompts != null) ? prompts : Map.of(); + this.completions = (completions != null) ? completions : Map.of(); + this.instructions = instructions; + } + + /** + * Convert a synchronous specification into an asynchronous one and provide + * blocking code offloading to prevent accidental blocking of the non-blocking + * transport. + * @param syncSpec a potentially blocking, synchronous specification. + * @param immediateExecution when true, do not offload. Do NOT set to true when + * using a non-blocking transport. + * @return a specification which is protected from blocking calls specified by the + * user. + */ + static Async fromSync(Sync syncSpec, boolean immediateExecution) { + List tools = new ArrayList<>(); + for (var tool : syncSpec.tools()) { + tools.add(AsyncToolSpecification.fromSync(tool, immediateExecution)); + } + + Map resources = new HashMap<>(); + syncSpec.resources().forEach((key, resource) -> { + resources.put(key, AsyncResourceSpecification.fromSync(resource, immediateExecution)); + }); + + Map prompts = new HashMap<>(); + syncSpec.prompts().forEach((key, prompt) -> { + prompts.put(key, AsyncPromptSpecification.fromSync(prompt, immediateExecution)); + }); + + Map completions = new HashMap<>(); + syncSpec.completions().forEach((key, completion) -> { + completions.put(key, AsyncCompletionSpecification.fromSync(completion, immediateExecution)); + }); + + return new Async(syncSpec.serverInfo(), syncSpec.serverCapabilities(), tools, resources, + syncSpec.resourceTemplates(), prompts, completions, syncSpec.instructions()); + } + } + + /** + * Synchronous server features specification. + * + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + /** + * Create an instance and validate the arguments. + * @param serverInfo The server implementation details + * @param serverCapabilities The server capabilities + * @param tools The list of tool specifications + * @param resources The map of resource specifications + * @param resourceTemplates The list of resource templates + * @param prompts The map of prompt specifications + * @param instructions The server instructions text + */ + Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, + List tools, + Map resources, + List resourceTemplates, + Map prompts, + Map completions, + String instructions) { + + Assert.notNull(serverInfo, "Server info must not be null"); + + this.serverInfo = serverInfo; + this.serverCapabilities = (serverCapabilities != null) ? serverCapabilities + : new McpSchema.ServerCapabilities(null, // completions + null, // experimental + new McpSchema.ServerCapabilities.LoggingCapabilities(), // Enable + // logging + // by + // default + !Utils.isEmpty(prompts) ? new McpSchema.ServerCapabilities.PromptCapabilities(false) : null, + !Utils.isEmpty(resources) + ? new McpSchema.ServerCapabilities.ResourceCapabilities(false, false) : null, + !Utils.isEmpty(tools) ? new McpSchema.ServerCapabilities.ToolCapabilities(false) : null); + + this.tools = (tools != null) ? tools : new ArrayList<>(); + this.resources = (resources != null) ? resources : new HashMap<>(); + this.resourceTemplates = (resourceTemplates != null) ? resourceTemplates : new ArrayList<>(); + this.prompts = (prompts != null) ? prompts : new HashMap<>(); + this.completions = (completions != null) ? completions : new HashMap<>(); + this.instructions = instructions; + } + + } + + /** + * Specification of a tool with its asynchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. Each tool + * represents a specific capability. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning the result. + */ + public record AsyncToolSpecification(McpSchema.Tool tool, + BiFunction> callHandler) { + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec) { + return fromSync(syncToolSpec, false); + } + + static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boolean immediate) { + + // FIXME: This is temporary, proper validation should be implemented + if (syncToolSpec == null) { + return null; + } + + BiFunction> callHandler = (ctx, req) -> { + var toolResult = Mono + .fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); + return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); + }; + + return new AsyncToolSpecification(syncToolSpec.tool(), callHandler); + } + } + + /** + * Specification of a resource with its asynchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record AsyncResourceSpecification(McpSchema.Resource resource, + BiFunction> readHandler) { + + static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (resource == null) { + return null; + } + return new AsyncResourceSpecification(resource.resource(), (ctx, req) -> { + var resourceResult = Mono.fromCallable(() -> resource.readHandler().apply(ctx, req)); + return immediateExecution ? resourceResult : resourceResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a prompt template with its asynchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a {@link McpSchema.GetPromptRequest}. + */ + public record AsyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction> promptHandler) { + + static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean immediateExecution) { + // FIXME: This is temporary, proper validation should be implemented + if (prompt == null) { + return null; + } + return new AsyncPromptSpecification(prompt.prompt(), (ctx, req) -> { + var promptResult = Mono.fromCallable(() -> prompt.promptHandler().apply(ctx, req)); + return immediateExecution ? promptResult : promptResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a completion handler function with asynchronous execution support. + * Completions generate AI model outputs based on prompt or resource references and + * user-provided arguments. This abstraction enables: + *
    + *
  • Customizable response generation logic + *
  • Parameter-driven template expansion + *
  • Dynamic interaction with connected clients + *
+ * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The asynchronous function that processes completion + * requests and returns results. The function's argument is a {@link McpSchema.CompleteRequest}. + */ + public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction> completionHandler) { + + /** + * Converts a synchronous {@link SyncCompletionSpecification} into an + * {@link AsyncCompletionSpecification} by wrapping the handler in a bounded + * elastic scheduler for safe non-blocking execution. + * @param completion the synchronous completion specification + * @return an asynchronous wrapper of the provided sync specification, or + * {@code null} if input is null + */ + static AsyncCompletionSpecification fromSync(SyncCompletionSpecification completion, + boolean immediateExecution) { + if (completion == null) { + return null; + } + return new AsyncCompletionSpecification(completion.referenceKey(), (ctx, req) -> { + var completionResult = Mono.fromCallable( + () -> completion.completionHandler().apply(ctx, req)); + return immediateExecution ? completionResult + : completionResult.subscribeOn(Schedulers.boundedElastic()); + }); + } + } + + /** + * Specification of a tool with its synchronous handler function. Tools are the + * primary way for MCP servers to expose functionality to AI models. + * + * @param tool The tool definition including name, description, and parameter schema + * @param callHandler The function that implements the tool's logic, receiving a + * {@link CallToolRequest} and returning results. + */ + public record SyncToolSpecification(McpSchema.Tool tool, + BiFunction callHandler) { + } + + /** + * Specification of a resource with its synchronous handler function. Resources + * provide context to AI models by exposing data such as: + *
    + *
  • File contents + *
  • Database records + *
  • API responses + *
  • System information + *
  • Application state + *
+ * + * @param resource The resource definition including name, description, and MIME type + * @param readHandler The function that handles resource read requests. The function's + * argument is a {@link McpSchema.ReadResourceRequest}. + */ + public record SyncResourceSpecification(McpSchema.Resource resource, + BiFunction readHandler) { + } + + /** + * Specification of a prompt template with its synchronous handler function. Prompts + * provide structured templates for AI model interactions, supporting: + *
    + *
  • Consistent message formatting + *
  • Parameter substitution + *
  • Context injection + *
  • Response formatting + *
  • Instruction templating + *
+ * + * @param prompt The prompt definition including name and description + * @param promptHandler The function that processes prompt requests and returns + * formatted templates. The function's argument is a {@link McpSchema.GetPromptRequest}. + */ + public record SyncPromptSpecification(McpSchema.Prompt prompt, + BiFunction promptHandler) { + } + + /** + * Specification of a completion handler function with synchronous execution support. + * + * @param referenceKey The unique key representing the completion reference. + * @param completionHandler The synchronous function that processes completion + * requests and returns results. The argument is a {@link McpSchema.CompleteRequest}. + */ + public record SyncCompletionSpecification(McpSchema.CompleteReference referenceKey, + BiFunction completionHandler) { + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java new file mode 100644 index 000000000..35eb4bff8 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -0,0 +1,152 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import io.modelcontextprotocol.spec.McpStatelessServerTransport; +import io.modelcontextprotocol.spec.McpTransportContext; +import io.modelcontextprotocol.util.DeafaultMcpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + +/** + * @author Dariusz Jędrzejczyk + */ +public class McpStatelessSyncServer { + + private static final Logger logger = LoggerFactory.getLogger(McpStatelessSyncServer.class); + + private final McpStatelessAsyncServer asyncServer; + + private final boolean immediateExecution; + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer) { + this(asyncServer, false); + } + + McpStatelessSyncServer(McpStatelessAsyncServer asyncServer, boolean immediateExecution) { + this.asyncServer = asyncServer; + this.immediateExecution = immediateExecution; + } + + /** + * Get the server capabilities that define the supported features and functionality. + * @return The server capabilities + */ + public McpSchema.ServerCapabilities getServerCapabilities() { + return this.asyncServer.getServerCapabilities(); + } + + /** + * Get the server implementation information. + * @return The server implementation details + */ + public McpSchema.Implementation getServerInfo() { + return this.asyncServer.getServerInfo(); + } + + /** + * Gracefully closes the server, allowing any in-progress operations to complete. + * @return A Mono that completes when the server has been closed + */ + public Mono closeGracefully() { + return this.asyncServer.closeGracefully(); + } + + /** + * Close the server immediately. + */ + public void close() { + this.asyncServer.close(); + } + + /** + * Add a new tool specification at runtime. + * @param toolSpecification The tool specification to add + * @return Mono that completes when clients have been notified of the change + */ + public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecification) { + this.asyncServer + .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, this.immediateExecution)) + .block(); + } + + /** + * Remove a tool handler at runtime. + * @param toolName The name of the tool handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public void removeTool(String toolName) { + this.asyncServer.removeTool(toolName).block(); + } + + /** + * Add a new resource handler at runtime. + * @param resourceSpecification The resource handler to add + * @return Mono that completes when clients have been notified of the change + */ + public void addResource(McpStatelessServerFeatures.SyncResourceSpecification resourceSpecification) { + this.asyncServer + .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, this.immediateExecution)) + .block(); + } + + /** + * Remove a resource handler at runtime. + * @param resourceUri The URI of the resource handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public void removeResource(String resourceUri) { + this.asyncServer.removeResource(resourceUri).block(); + } + + /** + * Add a new prompt handler at runtime. + * @param promptSpecification The prompt handler to add + * @return Mono that completes when clients have been notified of the change + */ + public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptSpecification) { + this.asyncServer + .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, this.immediateExecution)) + .block(); + } + + /** + * Remove a prompt handler at runtime. + * @param promptName The name of the prompt handler to remove + * @return Mono that completes when clients have been notified of the change + */ + public void removePrompt(String promptName) { + this.asyncServer.removePrompt(promptName).block(); + } + + /** + * This method is package-private and used for test only. Should not be called by user + * code. + * @param protocolVersions the Client supported protocol versions. + */ + void setProtocolVersions(List protocolVersions) { + this.asyncServer.setProtocolVersions(protocolVersions); + } +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index dad1e4c19..0963a3761 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -6,6 +6,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpTransportContext; /** * Represents a synchronous exchange with a Model Context Protocol (MCP) client. The @@ -43,6 +44,9 @@ public McpSchema.Implementation getClientInfo() { return this.exchange.getClientInfo(); } + public McpTransportContext transportContext() { + return this.exchange.transportContext(); + } /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java index 552ef7f17..afdbff472 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java @@ -18,7 +18,6 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import jakarta.servlet.AsyncContext; import jakarta.servlet.ServletException; @@ -61,7 +60,7 @@ */ @WebServlet(asyncSupported = true) -public class HttpServletSseServerTransportProvider extends HttpServlet implements McpSingleSessionServerTransportProvider { +public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider { /** Logger for this class */ private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java new file mode 100644 index 000000000..cf40b9c5e --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java @@ -0,0 +1,30 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class DefaultMcpTransportContext implements McpTransportContext { + + private final Map storage; + + public DefaultMcpTransportContext() { + this.storage = new ConcurrentHashMap<>(); + } + DefaultMcpTransportContext(Map storage) { + this.storage = storage; + } + + @Override + public Object get(String key) { + return this.storage.get(key); + } + + @Override + public void put(String key, Object value) { + this.storage.put(key, value); + } + + public McpTransportContext copy() { + return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); + } +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index a90c615f7..1c830d1b8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -1,58 +1,12 @@ package io.modelcontextprotocol.spec; -import java.util.Map; - -import reactor.core.publisher.Mono; - -/** - * The core building block providing the server-side MCP transport. Implement this - * interface to bridge between a particular server-side technology and the MCP server - * transport layer. - * - *

- * The lifecycle of the provider dictates that it be created first, upon application - * startup, and then passed into either - * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or - * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As - * a result of the MCP server creation, the provider will be notified of a - * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication - * between a newly connected client and the server. The provider's responsibility is to - * create instances of {@link McpServerTransport} that the session will utilise during the - * session lifetime. - * - *

- * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or - * {@link #closeGracefully()} are called as part of the normal application shutdown event. - * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where - * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} - * closes the provided transport. - * - * @author Dariusz Jędrzejczyk - */ -public interface McpServerTransportProvider { - - /** - * Sends a notification to all connected clients. - * @param method the name of the notification method to be called on the clients - * @param params parameters to be sent with the notification - * @return a Mono that completes when the notification has been broadcast - * @see McpSession#sendNotification(String, Map) - */ - Mono notifyClients(String method, Object params); - - /** - * Immediately closes all the transports with connected clients and releases any - * associated resources. - */ - default void close() { - this.closeGracefully().subscribe(); - } - - /** - * Gracefully closes all the transports with connected clients and releases any - * associated resources asynchronously. - * @return a {@link Mono} that completes when the connections have been closed. - */ - Mono closeGracefully(); - +public interface McpServerTransportProvider extends McpServerTransportProviderBase { + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java new file mode 100644 index 000000000..87e7d6441 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProviderBase.java @@ -0,0 +1,58 @@ +package io.modelcontextprotocol.spec; + +import java.util.Map; + +import reactor.core.publisher.Mono; + +/** + * The core building block providing the server-side MCP transport. Implement this + * interface to bridge between a particular server-side technology and the MCP server + * transport layer. + * + *

+ * The lifecycle of the provider dictates that it be created first, upon application + * startup, and then passed into either + * {@link io.modelcontextprotocol.server.McpServer#sync(McpServerTransportProvider)} or + * {@link io.modelcontextprotocol.server.McpServer#async(McpServerTransportProvider)}. As + * a result of the MCP server creation, the provider will be notified of a + * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication + * between a newly connected client and the server. The provider's responsibility is to + * create instances of {@link McpServerTransport} that the session will utilise during the + * session lifetime. + * + *

+ * Finally, the {@link McpServerTransport}s can be closed in bulk when {@link #close()} or + * {@link #closeGracefully()} are called as part of the normal application shutdown event. + * Individual {@link McpServerTransport}s can also be closed on a per-session basis, where + * the {@link McpServerSession#close()} or {@link McpServerSession#closeGracefully()} + * closes the provided transport. + * + * @author Dariusz Jędrzejczyk + */ +public interface McpServerTransportProviderBase { + + /** + * Sends a notification to all connected clients. + * @param method the name of the notification method to be called on the clients + * @param params parameters to be sent with the notification + * @return a Mono that completes when the notification has been broadcast + * @see McpSession#sendNotification(String, Map) + */ + Mono notifyClients(String method, Object params); + + /** + * Immediately closes all the transports with connected clients and releases any + * associated resources. + */ + default void close() { + this.closeGracefully().subscribe(); + } + + /** + * Gracefully closes all the transports with connected clients and releases any + * associated resources asynchronously. + * @return a {@link Mono} that completes when the connections have been closed. + */ + Mono closeGracefully(); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java deleted file mode 100644 index 762968dc9..000000000 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSingleSessionServerTransportProvider.java +++ /dev/null @@ -1,12 +0,0 @@ -package io.modelcontextprotocol.spec; - -public interface McpSingleSessionServerTransportProvider extends McpServerTransportProvider { - /** - * Sets the session factory that will be used to create sessions for new clients. An - * implementation of the MCP server MUST call this method before any MCP interactions - * take place. - * - * @param sessionFactory the session factory to be used for initiating client sessions - */ - void setSessionFactory(McpServerSession.Factory sessionFactory); -} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java index 6c184b724..7668e00b3 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -2,11 +2,11 @@ import reactor.core.publisher.Mono; -import java.util.function.Function; +import java.util.function.BiFunction; public interface McpStatelessServerTransport { - void setHandler(Function> message); + void setRequestHandler(BiFunction> message); /** * Immediately closes all the transports with connected clients and releases any diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 53f242955..b61928544 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -47,7 +47,7 @@ public class McpStreamableServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); - private final AtomicReference genericStreamRef = new AtomicReference<>(); + private final AtomicReference listeningStreamRef = new AtomicReference<>(); public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, Duration requestTimeout, @@ -73,23 +73,23 @@ private String generateRequestId() { @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { return Mono.defer(() -> { - McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); - return genericStream != null ? genericStream.sendRequest(method, requestParams, typeRef) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); + return listeningStream != null ? listeningStream.sendRequest(method, requestParams, typeRef) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); }); } @Override public Mono sendNotification(String method, Object params) { return Mono.defer(() -> { - McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); - return genericStream != null ? genericStream.sendNotification(method, params) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); + return listeningStream != null ? listeningStream.sendNotification(method, params) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); }); } - public McpStreamableServerSessionStream newStream(McpServerTransport transport) { - McpStreamableServerSessionStream genericStream = new McpStreamableServerSessionStream(transport); - this.genericStreamRef.set(genericStream); - return genericStream; + public McpStreamableServerSessionStream listeningStream(McpServerTransport transport) { + McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); + this.listeningStreamRef.set(listeningStream); + return listeningStream; } // TODO: keep track of history by keeping a map from eventId to stream and then iterate over the events using the lastEventId @@ -97,29 +97,34 @@ public Flux replay(Object lastEventId) { return Flux.empty(); } - public Mono handleStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpServerTransport transport) { - McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); - McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers.get(jsonrpcRequest.method()); - // TODO: delegate to stream, which upon successful response should close remove itself from the registry and also close the underlying transport (sink) - if (requestHandler == null) { - MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); - return transport.sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, - error.message(), error.data()))); - } - return requestHandler.handle(new McpAsyncServerExchange(stream, clientCapabilities.get(), clientInfo.get()), jsonrpcRequest.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) - .flatMap(transport::sendMessage).then(transport.closeGracefully()); + public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpServerTransport transport) { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); + + McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers.get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close remove itself from the registry and also close the underlying transport (sink) + if (requestHandler == null) { + MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); + return transport.sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); + } + return requestHandler.handle(new McpAsyncServerExchange(stream, clientCapabilities.get(), clientInfo.get(), transportContext), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) + .flatMap(transport::sendMessage).then(transport.closeGracefully()); + }); } public Mono accept(McpSchema.JSONRPCNotification notification) { - return Mono.defer(() -> { + return Mono.deferContextual(ctx -> { + McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); McpNotificationHandler notificationHandler = this.notificationHandlers.get(notification.method()); if (notificationHandler == null) { logger.error("No handler registered for notification method: {}", notification.method()); return Mono.empty(); } - McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); - return notificationHandler.handle(new McpAsyncServerExchange(genericStream != null ? genericStream : DisconnectedMcpSession.INSTANCE, this.clientCapabilities.get(), this.clientInfo.get()), notification.params()); + McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(listeningStream != null ? listeningStream : MissingMcpTransportSession.INSTANCE, this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); }); } @@ -149,16 +154,16 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { return Mono.defer(() -> { - McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); - return genericStream != null ? genericStream.closeGracefully() : Mono.empty(); // TODO: Also close all the open streams + McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); + return listeningStream != null ? listeningStream.closeGracefully() : Mono.empty(); // TODO: Also close all the open streams }); } @Override public void close() { - McpStreamableServerSessionStream genericStream = this.genericStreamRef.get(); - if (genericStream != null) { - genericStream.close(); + McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); + if (listeningStream != null) { + listeningStream.close(); } // TODO: Also close all open streams } @@ -251,7 +256,7 @@ public Mono closeGracefully() { this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); this.pendingResponses.clear(); // If this was the generic stream, reset it - McpStreamableServerSession.this.genericStreamRef.compareAndExchange(this, null); + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, null); McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); return this.transport.closeGracefully(); }); @@ -262,7 +267,7 @@ public void close() { this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); this.pendingResponses.clear(); // If this was the generic stream, reset it - McpStreamableServerSession.this.genericStreamRef.compareAndExchange(this, null); + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, null); McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); this.transport.close(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java index 22b618440..088cd7547 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -29,7 +29,7 @@ * * @author Dariusz Jędrzejczyk */ -public interface McpStreamableServerTransportProvider extends McpServerTransportProvider { +public interface McpStreamableServerTransportProvider extends McpServerTransportProviderBase { /** * Sets the session factory that will be used to create sessions for new clients. An diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java new file mode 100644 index 000000000..45e7b3502 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java @@ -0,0 +1,12 @@ +package io.modelcontextprotocol.spec; + +public interface McpTransportContext { + + String KEY = "MCP_TRANSPORT_CONTEXT"; + + McpTransportContext EMPTY = new DefaultMcpTransportContext(); + + Object get(String key); + void put(String key, Object value); + McpTransportContext copy(); +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java similarity index 79% rename from mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java rename to mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java index 998dc4a65..46e588eaf 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DisconnectedMcpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -3,9 +3,9 @@ import com.fasterxml.jackson.core.type.TypeReference; import reactor.core.publisher.Mono; -public class DisconnectedMcpSession implements McpSession { +public class MissingMcpTransportSession implements McpSession { - public static final DisconnectedMcpSession INSTANCE = new DisconnectedMcpSession(); + public static final MissingMcpTransportSession INSTANCE = new MissingMcpTransportSession(); @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { From 1645ecb73f34581a7031f26d2cc7d09edae51488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 22 Jul 2025 15:36:32 +0200 Subject: [PATCH 3/8] Formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../server/McpAsyncServer.java | 11 +- .../server/McpAsyncServerExchange.java | 7 +- .../server/McpInitRequestHandler.java | 13 +- .../server/McpNotificationHandler.java | 17 +-- .../server/McpRequestHandler.java | 19 ++- .../server/McpServer.java | 130 ++++++++++------- .../server/McpStatelessAsyncServer.java | 39 ++--- .../server/McpStatelessServerFeatures.java | 41 +++--- .../server/McpStatelessSyncServer.java | 16 +- .../server/McpSyncServerExchange.java | 1 + ...aultMcpStreamableServerSessionFactory.java | 45 +++--- .../spec/DefaultMcpTransportContext.java | 46 +++--- .../spec/McpServerSession.java | 7 +- .../spec/McpServerTransportProvider.java | 17 ++- .../spec/McpStatelessServerTransport.java | 3 +- .../spec/McpStreamableServerSession.java | 137 +++++++++++------- .../McpStreamableServerTransportProvider.java | 7 +- .../spec/McpTransportContext.java | 13 +- .../spec/MissingMcpTransportSession.java | 41 +++--- 19 files changed, 348 insertions(+), 262 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index b4efad266..b12200ed5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -127,8 +127,8 @@ public class McpAsyncServer { * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -151,8 +151,8 @@ public class McpAsyncServer { } McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, ObjectMapper objectMapper, - McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransportProvider; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -169,7 +169,8 @@ public class McpAsyncServer { Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); - mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(new DefaultMcpStreamableServerSessionFactory(requestTimeout, + this::asyncInitializeRequestHandler, requestHandlers, notificationHandlers)); } private Map prepareNotificationHandlers(McpServerFeatures.Async features) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 0d2a0a37e..ec5713261 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -57,7 +57,7 @@ public class McpAsyncServerExchange { * @param clientInfo The client implementation information. */ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo) { + McpSchema.Implementation clientInfo) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; @@ -69,11 +69,12 @@ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities c * @param session The server session representing a 1-1 interaction. * @param clientCapabilities The client capabilities that define the supported * features and functionality. - * @param transportContext context associated with the client as extracted from the transport + * @param transportContext context associated with the client as extracted from the + * transport * @param clientInfo The client implementation information. */ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo, McpTransportContext transportContext) { + McpSchema.Implementation clientInfo, McpTransportContext transportContext) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java index a6063a8b2..609744637 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpInitRequestHandler.java @@ -8,12 +8,11 @@ */ public interface McpInitRequestHandler { - /** - * Handles the initialization request. - * - * @param initializeRequest the initialization request by the client - * @return a Mono that will emit the result of the initialization - */ - Mono handle(McpSchema.InitializeRequest initializeRequest); + /** + * Handles the initialization request. + * @param initializeRequest the initialization request by the client + * @return a Mono that will emit the result of the initialization + */ + Mono handle(McpSchema.InitializeRequest initializeRequest); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java index 492454908..6b1061c03 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpNotificationHandler.java @@ -7,14 +7,13 @@ */ public interface McpNotificationHandler { - /** - * Handles a notification from the client. - * - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the notification. - * @return a Mono that completes once the notification is handled. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); + /** + * Handles a notification from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the notification. + * @return a Mono that completes once the notification is handled. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java index c95af472a..c9d70ad04 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpRequestHandler.java @@ -6,18 +6,17 @@ * A handler for client-initiated requests. * * @param the type of the response that is expected as a result of handling the - * request. + * request. */ public interface McpRequestHandler { - /** - * Handles a request from the client. - * - * @param exchange the exchange associated with the client that allows calling - * back to the connected client or inspecting its capabilities. - * @param params the parameters of the request. - * @return a Mono that will emit the response to the request. - */ - Mono handle(McpAsyncServerExchange exchange, Object params); + /** + * Handles a request from the client. + * @param exchange the exchange associated with the client that allows calling back to + * the connected client or inspecting its capabilities. + * @param params the parameters of the request. + * @return a Mono that will emit the response to the request. + */ + Mono handle(McpAsyncServerExchange exchange, Object params); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index af2b8d92e..d5ee64758 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -193,6 +193,7 @@ static StatelessSyncSpecification sync(McpStatelessServerTransport transport) { } class SingleSessionAsyncSpecification extends AsyncSpecification { + private final McpServerTransportProvider transportProvider; private SingleSessionAsyncSpecification(McpServerTransportProvider transportProvider) { @@ -215,9 +216,11 @@ public McpAsyncServer build() { return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); } + } class StreamableServerAsyncSpecification extends AsyncSpecification { + private final McpStreamableServerTransportProvider transportProvider; public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider transportProvider) { @@ -239,6 +242,7 @@ public McpAsyncServer build() { return new McpAsyncServer(this.transportProvider, mapper, features, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); } + } /** @@ -416,7 +420,7 @@ public S capabilities(McpSchema.ServerCapabilities serverCapabilities) { */ @Deprecated public S tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { + BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); assertNoDuplicateTool(tool.name()); @@ -440,7 +444,7 @@ public S tool(McpSchema.Tool tool, * @throws IllegalArgumentException if tool or handler is null */ public S toolCall(McpSchema.Tool tool, - BiFunction> callHandler) { + BiFunction> callHandler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); @@ -515,8 +519,7 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public S resources( - Map resourceSpecifications) { + public S resources(Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); return self(); @@ -706,8 +709,7 @@ public S completions(McpServerFeatures.AsyncCompletionSpecification... completio * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public S rootsChangeHandler( - BiFunction, Mono> handler) { + public S rootsChangeHandler(BiFunction, Mono> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); return self(); @@ -773,6 +775,7 @@ public S jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { } class SingleSessionSyncSpecification extends SyncSpecification { + private final McpServerTransportProvider transportProvider; private SingleSessionSyncSpecification(McpServerTransportProvider transportProvider) { @@ -780,7 +783,6 @@ private SingleSessionSyncSpecification(McpServerTransportProvider transportProvi this.transportProvider = transportProvider; } - /** * Builds a synchronous MCP server that provides blocking operations. * @return A new instance of {@link McpSyncServer} configured with this builder's @@ -801,9 +803,11 @@ public McpSyncServer build() { return new McpSyncServer(asyncServer, this.immediateExecution); } + } class StreamableSyncSpecification extends SyncSpecification { + private final McpStreamableServerTransportProvider transportProvider; private StreamableSyncSpecification(McpStreamableServerTransportProvider transportProvider) { @@ -811,7 +815,6 @@ private StreamableSyncSpecification(McpStreamableServerTransportProvider transpo this.transportProvider = transportProvider; } - /** * Builds a synchronous MCP server that provides blocking operations. * @return A new instance of {@link McpSyncServer} configured with this builder's @@ -832,6 +835,7 @@ public McpSyncServer build() { return new McpSyncServer(asyncServer, this.immediateExecution); } + } /** @@ -1109,8 +1113,7 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.SyncResourceSpecification...) */ - public S resources( - Map resourceSpecifications) { + public S resources(Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); return self(); @@ -1317,8 +1320,7 @@ public S rootsChangeHandler(BiConsumer>> handlers) { + public S rootsChangeHandlers(List>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); return self(); @@ -1333,8 +1335,7 @@ public S rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public S rootsChangeHandlers( - BiConsumer>... handlers) { + public S rootsChangeHandlers(BiConsumer>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(List.of(handlers)); } @@ -1372,6 +1373,7 @@ public S immediateExecution(boolean immediateExecution) { this.immediateExecution = immediateExecution; return self(); } + } class StatelessAsyncSpecification { @@ -1434,7 +1436,8 @@ public StatelessAsyncSpecification(McpStatelessServerTransport transport) { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public StatelessAsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public StatelessAsyncSpecification uriTemplateManagerFactory( + McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; return this; @@ -1533,7 +1536,7 @@ public StatelessAsyncSpecification capabilities(McpSchema.ServerCapabilities ser * @throws IllegalArgumentException if tool or handler is null */ public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, - BiFunction> callHandler) { + BiFunction> callHandler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); @@ -1554,7 +1557,8 @@ public StatelessAsyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpStatelessServerFeatures.AsyncToolSpecification...) */ - public StatelessAsyncSpecification tools(List toolSpecifications) { + public StatelessAsyncSpecification tools( + List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -1581,7 +1585,8 @@ public StatelessAsyncSpecification tools(List resourceSpecifications) { + public StatelessAsyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (var resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -1648,7 +1654,8 @@ public StatelessAsyncSpecification resources(List prompts) { + public StatelessAsyncSpecification prompts( + Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; @@ -1765,7 +1773,8 @@ public StatelessAsyncSpecification prompts(McpStatelessServerFeatures.AsyncPromp * @return This builder instance for method chaining * @throws IllegalArgumentException if completions is null */ - public StatelessAsyncSpecification completions(List completions) { + public StatelessAsyncSpecification completions( + List completions) { Assert.notNull(completions, "Completions list must not be null"); for (var completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -1780,7 +1789,8 @@ public StatelessAsyncSpecification completions(List callHandler) { + BiFunction callHandler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(callHandler, "Handler must not be null"); @@ -2005,7 +2019,8 @@ public StatelessSyncSpecification toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpStatelessServerFeatures.SyncToolSpecification...) */ - public StatelessSyncSpecification tools(List toolSpecifications) { + public StatelessSyncSpecification tools( + List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -2032,7 +2047,8 @@ public StatelessSyncSpecification tools(List resourceSpecifications) { + public StatelessSyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (var resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); @@ -2099,7 +2116,8 @@ public StatelessSyncSpecification resources(List prompts) { + public StatelessSyncSpecification prompts( + Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); return this; @@ -2216,7 +2235,8 @@ public StatelessSyncSpecification prompts(McpStatelessServerFeatures.SyncPromptS * @return This builder instance for method chaining * @throws IllegalArgumentException if completions is null */ - public StatelessSyncSpecification completions(List completions) { + public StatelessSyncSpecification completions( + List completions) { Assert.notNull(completions, "Completions list must not be null"); for (var completion : completions) { this.completions.put(completion.referenceKey(), completion); @@ -2231,7 +2251,8 @@ public StatelessSyncSpecification completions(List * Do NOT set to true if the underlying transport is a non-blocking * implementation. @@ -2283,28 +2304,33 @@ public StatelessSyncSpecification immediateExecution(boolean immediateExecution) public McpStatelessSyncServer build() { /* - McpServerFeatures.Sync syncFeatures = new McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, - this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, - this.rootsChangeHandlers, this.instructions); - McpServerFeatures.Async asyncFeatures = McpServerFeatures.Async.fromSync(syncFeatures, - this.immediateExecution); - var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); - var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator - : new DefaultJsonSchemaValidator(mapper); - - var asyncServer = new McpAsyncServer(this.transportProvider, mapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); - - return new McpSyncServer(asyncServer, this.immediateExecution); - */ - var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); + * McpServerFeatures.Sync syncFeatures = new + * McpServerFeatures.Sync(this.serverInfo, this.serverCapabilities, + * this.tools, this.resources, this.resourceTemplates, this.prompts, + * this.completions, this.rootsChangeHandlers, this.instructions); + * McpServerFeatures.Async asyncFeatures = + * McpServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); + * var mapper = this.objectMapper != null ? this.objectMapper : new + * ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null + * ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); + * + * var asyncServer = new McpAsyncServer(this.transportProvider, mapper, + * asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, + * jsonSchemaValidator); + * + * return new McpSyncServer(asyncServer, this.immediateExecution); + */ + var syncFeatures = new McpStatelessServerFeatures.Sync(this.serverInfo, this.serverCapabilities, this.tools, + this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); var asyncFeatures = McpStatelessServerFeatures.Async.fromSync(syncFeatures, this.immediateExecution); var mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper(); var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : new DefaultJsonSchemaValidator(mapper); - var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, this.uriTemplateManagerFactory, jsonSchemaValidator); + var asyncServer = new McpStatelessAsyncServer(this.transport, mapper, asyncFeatures, this.requestTimeout, + this.uriTemplateManagerFactory, jsonSchemaValidator); return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index fdeb517f2..fa768d7f8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -65,9 +65,8 @@ public class McpStatelessAsyncServer { private final JsonSchemaValidator jsonSchemaValidator; McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, ObjectMapper objectMapper, - McpStatelessServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, - JsonSchemaValidator jsonSchemaValidator) { + McpStatelessServerFeatures.Async features, Duration requestTimeout, + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { this.mcpTransportProvider = mcpTransport; this.objectMapper = objectMapper; this.serverInfo = features.serverInfo(); @@ -114,12 +113,12 @@ public class McpStatelessAsyncServer { requestHandlers.put(McpSchema.METHOD_COMPLETION_COMPLETE, completionCompleteRequestHandler()); } - mcpTransport.setRequestHandler((context, request) -> - requestHandlers.get(request.method()).apply(context, request.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(t -> - Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, t.getMessage(), null))) - )); + mcpTransport.setRequestHandler((context, request) -> requestHandlers.get(request.method()) + .apply(context, request.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) + .onErrorResume(t -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, t.getMessage(), + null))))); } // --------------------------------------- @@ -127,7 +126,8 @@ public class McpStatelessAsyncServer { // --------------------------------------- private RequestHandler asyncInitializeRequestHandler() { return (ctx, req) -> Mono.defer(() -> { - McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, McpSchema.InitializeRequest.class); + McpSchema.InitializeRequest initializeRequest = this.objectMapper.convertValue(req, + McpSchema.InitializeRequest.class); logger.info("Client initialize request - Protocol: {}, Capabilities: {}, Info: {}", initializeRequest.protocolVersion(), initializeRequest.capabilities(), @@ -249,7 +249,9 @@ public Mono removeTool(String toolName) { private RequestHandler toolsListRequestHandler() { return (ctx, params) -> { - List tools = this.tools.stream().map(McpStatelessServerFeatures.AsyncToolSpecification::tool).toList(); + List tools = this.tools.stream() + .map(McpStatelessServerFeatures.AsyncToolSpecification::tool) + .toList(); return Mono.just(new McpSchema.ListToolsResult(tools, null)); }; } @@ -335,8 +337,7 @@ private RequestHandler resourcesListRequestHandle } private RequestHandler resourceTemplateListRequestHandler() { - return (ctx, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + return (ctx, params) -> Mono.just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); } @@ -482,21 +483,20 @@ private RequestHandler completionCompleteRequestHandle // check if the referenced resource exists if (type.equals("ref/prompt") && request.ref() instanceof McpSchema.PromptReference promptReference) { - McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts.get(promptReference.name()); + McpStatelessServerFeatures.AsyncPromptSpecification promptSpec = this.prompts + .get(promptReference.name()); if (promptSpec == null) { return Mono.error(new McpError("Prompt not found: " + promptReference.name())); } - if (promptSpec.prompt() - .arguments() - .stream() - .noneMatch(arg -> arg.name().equals(argumentName))) { + if (promptSpec.prompt().arguments().stream().noneMatch(arg -> arg.name().equals(argumentName))) { return Mono.error(new McpError("Argument not found: " + argumentName)); } } if (type.equals("ref/resource") && request.ref() instanceof McpSchema.ResourceReference resourceReference) { - McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources.get(resourceReference.uri()); + McpStatelessServerFeatures.AsyncResourceSpecification resourceSpec = this.resources + .get(resourceReference.uri()); if (resourceSpec == null) { return Mono.error(new McpError("Resource not found: " + resourceReference.uri())); } @@ -567,4 +567,5 @@ void setProtocolVersions(List protocolVersions) { interface RequestHandler extends BiFunction> { } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java index 6813c658c..c64ec7cf7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessServerFeatures.java @@ -19,7 +19,8 @@ import java.util.function.BiFunction; /** - * MCP stateless server features specification that a particular server can choose to support. + * MCP stateless server features specification that a particular server can choose to + * support. * * @author Dariusz Jędrzejczyk */ @@ -37,11 +38,11 @@ public class McpStatelessServerFeatures { * @param instructions The server instructions text */ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, - List resourceTemplates, - Map prompts, - Map completions, - String instructions) { + List tools, + Map resources, List resourceTemplates, + Map prompts, + Map completions, + String instructions) { /** * Create an instance and validate the arguments. @@ -54,11 +55,11 @@ record Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities s * @param instructions The server instructions text */ Async(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities serverCapabilities, - List tools, Map resources, - List resourceTemplates, - Map prompts, - Map completions, - String instructions) { + List tools, + Map resources, List resourceTemplates, + Map prompts, + Map completions, + String instructions) { Assert.notNull(serverInfo, "Server info must not be null"); @@ -204,9 +205,9 @@ static AsyncToolSpecification fromSync(SyncToolSpecification syncToolSpec, boole return null; } - BiFunction> callHandler = (ctx, req) -> { - var toolResult = Mono - .fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); + BiFunction> callHandler = (ctx, + req) -> { + var toolResult = Mono.fromCallable(() -> syncToolSpec.callHandler().apply(ctx, req)); return immediate ? toolResult : toolResult.subscribeOn(Schedulers.boundedElastic()); }; @@ -257,7 +258,8 @@ static AsyncResourceSpecification fromSync(SyncResourceSpecification resource, b * * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns - * formatted templates. The function's argument is a {@link McpSchema.GetPromptRequest}. + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. */ public record AsyncPromptSpecification(McpSchema.Prompt prompt, BiFunction> promptHandler) { @@ -286,7 +288,8 @@ static AsyncPromptSpecification fromSync(SyncPromptSpecification prompt, boolean * * @param referenceKey The unique key representing the completion reference. * @param completionHandler The asynchronous function that processes completion - * requests and returns results. The function's argument is a {@link McpSchema.CompleteRequest}. + * requests and returns results. The function's argument is a + * {@link McpSchema.CompleteRequest}. */ public record AsyncCompletionSpecification(McpSchema.CompleteReference referenceKey, BiFunction> completionHandler) { @@ -305,8 +308,7 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet return null; } return new AsyncCompletionSpecification(completion.referenceKey(), (ctx, req) -> { - var completionResult = Mono.fromCallable( - () -> completion.completionHandler().apply(ctx, req)); + var completionResult = Mono.fromCallable(() -> completion.completionHandler().apply(ctx, req)); return immediateExecution ? completionResult : completionResult.subscribeOn(Schedulers.boundedElastic()); }); @@ -357,7 +359,8 @@ public record SyncResourceSpecification(McpSchema.Resource resource, * * @param prompt The prompt definition including name and description * @param promptHandler The function that processes prompt requests and returns - * formatted templates. The function's argument is a {@link McpSchema.GetPromptRequest}. + * formatted templates. The function's argument is a + * {@link McpSchema.GetPromptRequest}. */ public record SyncPromptSpecification(McpSchema.Prompt prompt, BiFunction promptHandler) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java index 35eb4bff8..2f9715776 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpStatelessSyncServer.java @@ -88,8 +88,9 @@ public void close() { */ public void addTool(McpStatelessServerFeatures.SyncToolSpecification toolSpecification) { this.asyncServer - .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, this.immediateExecution)) - .block(); + .addTool(McpStatelessServerFeatures.AsyncToolSpecification.fromSync(toolSpecification, + this.immediateExecution)) + .block(); } /** @@ -108,8 +109,9 @@ public void removeTool(String toolName) { */ public void addResource(McpStatelessServerFeatures.SyncResourceSpecification resourceSpecification) { this.asyncServer - .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, this.immediateExecution)) - .block(); + .addResource(McpStatelessServerFeatures.AsyncResourceSpecification.fromSync(resourceSpecification, + this.immediateExecution)) + .block(); } /** @@ -128,8 +130,9 @@ public void removeResource(String resourceUri) { */ public void addPrompt(McpStatelessServerFeatures.SyncPromptSpecification promptSpecification) { this.asyncServer - .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, this.immediateExecution)) - .block(); + .addPrompt(McpStatelessServerFeatures.AsyncPromptSpecification.fromSync(promptSpecification, + this.immediateExecution)) + .block(); } /** @@ -149,4 +152,5 @@ public void removePrompt(String promptName) { void setProtocolVersions(List protocolVersions) { this.asyncServer.setProtocolVersions(protocolVersions); } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 0963a3761..c5ab8d733 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -47,6 +47,7 @@ public McpSchema.Implementation getClientInfo() { public McpTransportContext transportContext() { return this.exchange.transportContext(); } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java index c6f1f219c..48d244a58 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpStreamableServerSessionFactory.java @@ -9,22 +9,33 @@ import java.util.UUID; public class DefaultMcpStreamableServerSessionFactory implements McpStreamableServerSession.Factory { - Duration requestTimeout; - McpStreamableServerSession.InitRequestHandler initRequestHandler; - Map> requestHandlers; - Map notificationHandlers; - - public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, McpStreamableServerSession.InitRequestHandler initRequestHandler, Map> requestHandlers, Map notificationHandlers) { - this.requestTimeout = requestTimeout; - this.initRequestHandler = initRequestHandler; - this.requestHandlers = requestHandlers; - this.notificationHandlers = notificationHandlers; - } - - @Override - public McpStreamableServerSession.McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest) { - return new McpStreamableServerSession.McpStreamableServerSessionInit(new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), initializeRequest.clientInfo(), requestTimeout, - Mono::empty, requestHandlers, notificationHandlers), this.initRequestHandler.handle(initializeRequest)); - } + + Duration requestTimeout; + + McpStreamableServerSession.InitRequestHandler initRequestHandler; + + Map> requestHandlers; + + Map notificationHandlers; + + public DefaultMcpStreamableServerSessionFactory(Duration requestTimeout, + McpStreamableServerSession.InitRequestHandler initRequestHandler, + Map> requestHandlers, + Map notificationHandlers) { + this.requestTimeout = requestTimeout; + this.initRequestHandler = initRequestHandler; + this.requestHandlers = requestHandlers; + this.notificationHandlers = notificationHandlers; + } + + @Override + public McpStreamableServerSession.McpStreamableServerSessionInit startSession( + McpSchema.InitializeRequest initializeRequest) { + return new McpStreamableServerSession.McpStreamableServerSessionInit( + new McpStreamableServerSession(UUID.randomUUID().toString(), initializeRequest.capabilities(), + initializeRequest.clientInfo(), requestTimeout, Mono::empty, requestHandlers, + notificationHandlers), + this.initRequestHandler.handle(initializeRequest)); + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java index cf40b9c5e..ba5f3ed29 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportContext.java @@ -5,26 +5,28 @@ public class DefaultMcpTransportContext implements McpTransportContext { - private final Map storage; - - public DefaultMcpTransportContext() { - this.storage = new ConcurrentHashMap<>(); - } - DefaultMcpTransportContext(Map storage) { - this.storage = storage; - } - - @Override - public Object get(String key) { - return this.storage.get(key); - } - - @Override - public void put(String key, Object value) { - this.storage.put(key, value); - } - - public McpTransportContext copy() { - return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); - } + private final Map storage; + + public DefaultMcpTransportContext() { + this.storage = new ConcurrentHashMap<>(); + } + + DefaultMcpTransportContext(Map storage) { + this.storage = storage; + } + + @Override + public Object get(String key) { + return this.storage.get(key); + } + + @Override + public void put(String key, Object value) { + this.storage.put(key, value); + } + + public McpTransportContext copy() { + return new DefaultMcpTransportContext(new ConcurrentHashMap<>(this.storage)); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 381bd3675..17e6ea362 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -73,8 +73,9 @@ public class McpServerSession implements McpSession { * @param notificationHandlers map of notification handlers to use */ public McpServerSession(String id, Duration requestTimeout, McpServerTransport transport, - McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + McpInitRequestHandler initHandler, InitNotificationHandler initNotificationHandler, + Map> requestHandlers, + Map notificationHandlers) { this.id = id; this.requestTimeout = requestTimeout; this.transport = transport; @@ -279,6 +280,7 @@ public void close() { /** * Request handler for the initialization request. + * * @deprecated Use {@link McpInitRequestHandler} */ @Deprecated @@ -308,6 +310,7 @@ public interface InitNotificationHandler { /** * A handler for client-initiated notifications. + * * @deprecated Use {@link McpNotificationHandler} */ @Deprecated diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java index 1c830d1b8..c04a4283d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransportProvider.java @@ -1,12 +1,13 @@ package io.modelcontextprotocol.spec; public interface McpServerTransportProvider extends McpServerTransportProviderBase { - /** - * Sets the session factory that will be used to create sessions for new clients. An - * implementation of the MCP server MUST call this method before any MCP interactions - * take place. - * - * @param sessionFactory the session factory to be used for initiating client sessions - */ - void setSessionFactory(McpServerSession.Factory sessionFactory); + + /** + * Sets the session factory that will be used to create sessions for new clients. An + * implementation of the MCP server MUST call this method before any MCP interactions + * take place. + * @param sessionFactory the session factory to be used for initiating client sessions + */ + void setSessionFactory(McpServerSession.Factory sessionFactory); + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java index 7668e00b3..513d551cc 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStatelessServerTransport.java @@ -6,7 +6,8 @@ public interface McpStatelessServerTransport { - void setRequestHandler(BiFunction> message); + void setRequestHandler( + BiFunction> message); /** * Immediately closes all the transports with connected clients and releases any diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index b61928544..73379d614 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -49,10 +49,10 @@ public class McpStreamableServerSession implements McpSession { private final AtomicReference listeningStreamRef = new AtomicReference<>(); - public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, - Duration requestTimeout, - InitNotificationHandler initNotificationHandler, - Map> requestHandlers, Map notificationHandlers) { + public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, + McpSchema.Implementation clientInfo, Duration requestTimeout, + InitNotificationHandler initNotificationHandler, Map> requestHandlers, + Map notificationHandlers) { this.id = id; this.clientCapabilities.lazySet(clientCapabilities); this.clientInfo.lazySet(clientInfo); @@ -74,7 +74,8 @@ private String generateRequestId() { public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { return Mono.defer(() -> { McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return listeningStream != null ? listeningStream.sendRequest(method, requestParams, typeRef) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + return listeningStream != null ? listeningStream.sendRequest(method, requestParams, typeRef) + : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); }); } @@ -82,7 +83,8 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc public Mono sendNotification(String method, Object params) { return Mono.defer(() -> { McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return listeningStream != null ? listeningStream.sendNotification(method, params) : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + return listeningStream != null ? listeningStream.sendNotification(method, params) + : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); }); } @@ -92,7 +94,8 @@ public McpStreamableServerSessionStream listeningStream(McpServerTransport trans return listeningStream; } - // TODO: keep track of history by keeping a map from eventId to stream and then iterate over the events using the lastEventId + // TODO: keep track of history by keeping a map from eventId to stream and then + // iterate over the events using the lastEventId public Flux replay(Object lastEventId) { return Flux.empty(); } @@ -102,19 +105,28 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpSer McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); McpStreamableServerSessionStream stream = new McpStreamableServerSessionStream(transport); - McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers.get(jsonrpcRequest.method()); - // TODO: delegate to stream, which upon successful response should close remove itself from the registry and also close the underlying transport (sink) + McpRequestHandler requestHandler = McpStreamableServerSession.this.requestHandlers + .get(jsonrpcRequest.method()); + // TODO: delegate to stream, which upon successful response should close + // remove itself from the registry and also close the underlying transport + // (sink) if (requestHandler == null) { MethodNotFoundError error = getMethodNotFoundError(jsonrpcRequest.method()); - return transport.sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, - new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, - error.message(), error.data()))); + return transport + .sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.METHOD_NOT_FOUND, + error.message(), error.data()))); } - return requestHandler.handle(new McpAsyncServerExchange(stream, clientCapabilities.get(), clientInfo.get(), transportContext), jsonrpcRequest.params()) - .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) - .flatMap(transport::sendMessage).then(transport.closeGracefully()); + return requestHandler + .handle(new McpAsyncServerExchange(stream, clientCapabilities.get(), clientInfo.get(), + transportContext), jsonrpcRequest.params()) + .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, + null)) + .flatMap(transport::sendMessage) + .then(transport.closeGracefully()); }); } + public Mono accept(McpSchema.JSONRPCNotification notification) { return Mono.deferContextual(ctx -> { McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); @@ -124,24 +136,32 @@ public Mono accept(McpSchema.JSONRPCNotification notification) { return Mono.empty(); } McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return notificationHandler.handle(new McpAsyncServerExchange(listeningStream != null ? listeningStream : MissingMcpTransportSession.INSTANCE, this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); + return notificationHandler.handle( + new McpAsyncServerExchange( + listeningStream != null ? listeningStream : MissingMcpTransportSession.INSTANCE, + this.clientCapabilities.get(), this.clientInfo.get(), transportContext), + notification.params()); }); } + public Mono accept(McpSchema.JSONRPCResponse response) { return Mono.defer(() -> { - var stream = this.requestIdToStream.get(response.id()); - if (stream == null) { - return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO JSONize - } - var sink = stream.pendingResponses.remove(response.id()); - if (sink == null) { - return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO JSONize - } else { - sink.success(response); - } - return Mono.empty(); - }); + var stream = this.requestIdToStream.get(response.id()); + if (stream == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO + // JSONize + } + var sink = stream.pendingResponses.remove(response.id()); + if (sink == null) { + return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO + // JSONize + } + else { + sink.success(response); + } + return Mono.empty(); + }); } record MethodNotFoundError(String method, String message, Object data) { @@ -155,7 +175,13 @@ private MethodNotFoundError getMethodNotFoundError(String method) { public Mono closeGracefully() { return Mono.defer(() -> { McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return listeningStream != null ? listeningStream.closeGracefully() : Mono.empty(); // TODO: Also close all the open streams + return listeningStream != null ? listeningStream.closeGracefully() : Mono.empty(); // TODO: + // Also + // close + // all + // the + // open + // streams }); } @@ -196,10 +222,14 @@ public interface InitNotificationHandler { } public interface Factory { + McpStreamableServerSessionInit startSession(McpSchema.InitializeRequest initializeRequest); + } - public record McpStreamableServerSessionInit(McpStreamableServerSession session, Mono initResult) {} + public record McpStreamableServerSessionInit(McpStreamableServerSession session, + Mono initResult) { + } public final class McpStreamableServerSessionStream implements McpSession { @@ -219,34 +249,32 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc return Mono.create(sink -> { this.pendingResponses.put(requestId, sink); - McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, - requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest).subscribe(v -> {}, sink::error); - }) - .timeout(requestTimeout) - .doOnError(e -> { - this.pendingResponses.remove(requestId); - McpStreamableServerSession.this.requestIdToStream.remove(requestId); - }) - .handle((jsonRpcResponse, sink) -> { - if (jsonRpcResponse.error() != null) { - sink.error(new McpError(jsonRpcResponse.error())); - } - else { - if (typeRef.getType().equals(Void.class)) { - sink.complete(); - } - else { - sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); - } - } - }); + McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, + method, requestId, requestParams); + this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + }, sink::error); + }).timeout(requestTimeout).doOnError(e -> { + this.pendingResponses.remove(requestId); + McpStreamableServerSession.this.requestIdToStream.remove(requestId); + }).handle((jsonRpcResponse, sink) -> { + if (jsonRpcResponse.error() != null) { + sink.error(new McpError(jsonRpcResponse.error())); + } + else { + if (typeRef.getType().equals(Void.class)) { + sink.complete(); + } + else { + sink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef)); + } + } + }); } @Override public Mono sendNotification(String method, Object params) { - McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, - method, params); + McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( + McpSchema.JSONRPC_VERSION, method, params); return this.transport.sendMessage(jsonrpcNotification); } @@ -271,6 +299,7 @@ public void close() { McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); this.transport.close(); } + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java index 088cd7547..48b9cd75e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransportProvider.java @@ -12,9 +12,10 @@ *

* The lifecycle of the provider dictates that it be created first, upon application * startup, and then passed into either - * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} or - * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. As - * a result of the MCP server creation, the provider will be notified of a + * {@link io.modelcontextprotocol.server.McpServer#sync(McpStreamableServerTransportProvider)} + * or + * {@link io.modelcontextprotocol.server.McpServer#async(McpStreamableServerTransportProvider)}. + * As a result of the MCP server creation, the provider will be notified of a * {@link McpServerSession.Factory} which will be used to handle a 1:1 communication * between a newly connected client and the server. The provider's responsibility is to * create instances of {@link McpServerTransport} that the session will utilise during the diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java index 45e7b3502..bfffeccd6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpTransportContext.java @@ -2,11 +2,14 @@ public interface McpTransportContext { - String KEY = "MCP_TRANSPORT_CONTEXT"; + String KEY = "MCP_TRANSPORT_CONTEXT"; - McpTransportContext EMPTY = new DefaultMcpTransportContext(); + McpTransportContext EMPTY = new DefaultMcpTransportContext(); + + Object get(String key); + + void put(String key, Object value); + + McpTransportContext copy(); - Object get(String key); - void put(String key, Object value); - McpTransportContext copy(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java index 46e588eaf..79ca44d2c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -5,24 +5,25 @@ public class MissingMcpTransportSession implements McpSession { - public static final MissingMcpTransportSession INSTANCE = new MissingMcpTransportSession(); - - @Override - public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - return Mono.error(new IllegalStateException("Stream unavailable")); - } - - @Override - public Mono sendNotification(String method, Object params) { - return Mono.error(new IllegalStateException("Stream unavailable")); - } - - @Override - public Mono closeGracefully() { - return Mono.empty(); - } - - @Override - public void close() { - } + public static final MissingMcpTransportSession INSTANCE = new MissingMcpTransportSession(); + + @Override + public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { + return Mono.error(new IllegalStateException("Stream unavailable")); + } + + @Override + public Mono sendNotification(String method, Object params) { + return Mono.error(new IllegalStateException("Stream unavailable")); + } + + @Override + public Mono closeGracefully() { + return Mono.empty(); + } + + @Override + public void close() { + } + } From a3ddb75f36ff480b163df3e0b85698523eea700a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 22 Jul 2025 15:53:31 +0200 Subject: [PATCH 4/8] WIP: make it compile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebFluxSseServerTransportProvider.java | 3 +- ...FluxStreamableServerTransportProvider.java | 75 ++++++++++++------- .../WebMvcSseServerTransportProvider.java | 3 +- .../StdioServerTransportProvider.java | 3 +- .../MockMcpServerTransportProvider.java | 3 +- 5 files changed, 50 insertions(+), 37 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index f15621837..fde067f03 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -10,7 +10,6 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,7 +62,7 @@ * @see McpServerTransport * @see ServerSentEvent */ -public class WebFluxSseServerTransportProvider implements McpSingleSessionServerTransportProvider { +public class WebFluxSseServerTransportProvider implements McpServerTransportProvider { private static final Logger logger = LoggerFactory.getLogger(WebFluxSseServerTransportProvider.class); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index ffeb25e2b..cc84e05b1 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -193,7 +193,8 @@ private Mono handleGet(ServerRequest request) { return Mono.defer(() -> { if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { - return ServerResponse.badRequest().build(); // TODO: say we need a session id + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id } String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); @@ -206,15 +207,20 @@ private Mono handleGet(ServerRequest request) { if (request.headers().asHttpHeaders().containsKey("mcp-last-id")) { String lastId = request.headers().asHttpHeaders().getFirst("mcp-last-id"); - return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(session.replay(lastId), ServerSentEvent.class); + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(session.replay(lastId), ServerSentEvent.class); } - return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport(sink); - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session.listeningStream(sessionTransport); - sink.onDispose(listeningStream::close); - }), ServerSentEvent.class); + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( + sink); + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + sink.onDispose(listeningStream::close); + }), ServerSentEvent.class); }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } @@ -244,11 +250,18 @@ private Mono handlePost(ServerRequest request) { return request.bodyToMono(String.class).flatMap(body -> { try { McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); - if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { - McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), new TypeReference() {}); - McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory.startSession(initializeRequest); + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); sessions.put(init.session().getId(), init.session()); - return init.initResult().flatMap(initResult -> ServerResponse.ok().header("mcp-session-id", init.session().getId()).bodyValue(initResult)); + return init.initResult() + .flatMap(initResult -> ServerResponse.ok() + .header("mcp-session-id", init.session().getId()) + .bodyValue(initResult)); } if (!request.headers().asHttpHeaders().containsKey("sessionId")) { @@ -260,26 +273,30 @@ private Mono handlePost(ServerRequest request) { if (session == null) { return ServerResponse.status(HttpStatus.NOT_FOUND) - .bodyValue(new McpError("Session not found: " + sessionId)); + .bodyValue(new McpError("Session not found: " + sessionId)); } if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { return session.accept(jsonrpcResponse).then(ServerResponse.accepted().build()); - } else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { return session.accept(jsonrpcNotification).then(ServerResponse.accepted().build()); - } else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { - return ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM) - .body(Flux.>create(sink -> { - WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); - Mono stream = session.responseStream(jsonrpcRequest, st); - Disposable streamSubscription = stream - .doOnError(err -> sink.error(err)) - .contextWrite(sink.contextView()) - .subscribe(); - sink.onCancel(streamSubscription); - }), ServerSentEvent.class); - } else { - return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).bodyValue(new McpError("Unknown message type")); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + return ServerResponse.ok() + .contentType(MediaType.TEXT_EVENT_STREAM) + .body(Flux.>create(sink -> { + WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); + Mono stream = session.responseStream(jsonrpcRequest, st); + Disposable streamSubscription = stream.doOnError(err -> sink.error(err)) + .contextWrite(sink.contextView()) + .subscribe(); + sink.onCancel(streamSubscription); + }), ServerSentEvent.class); + } + else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .bodyValue(new McpError("Unknown message type")); } } catch (IllegalArgumentException | IOException e) { @@ -393,8 +410,8 @@ public Builder messageEndpoint(String messageEndpoint) { } /** - * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with the - * configured settings. + * Builds a new instance of {@link WebFluxStreamableServerTransportProvider} with + * the configured settings. * @return A new WebFluxSseServerTransportProvider instance * @throws IllegalStateException if required parameters are not set */ diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index c753f836c..114eff607 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -16,7 +16,6 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -66,7 +65,7 @@ * @see McpServerTransportProvider * @see RouterFunction */ -public class WebMvcSseServerTransportProvider implements McpSingleSessionServerTransportProvider { +public class WebMvcSseServerTransportProvider implements McpServerTransportProvider { private static final Logger logger = LoggerFactory.getLogger(WebMvcSseServerTransportProvider.class); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index b607dbf6f..9ef9c7829 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -22,7 +22,6 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -39,7 +38,7 @@ * * @author Christian Tzolov */ -public class StdioServerTransportProvider implements McpSingleSessionServerTransportProvider { +public class StdioServerTransportProvider implements McpServerTransportProvider { private static final Logger logger = LoggerFactory.getLogger(StdioServerTransportProvider.class); diff --git a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java index 71e090890..7ba35bbf0 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java +++ b/mcp/src/test/java/io/modelcontextprotocol/MockMcpServerTransportProvider.java @@ -19,13 +19,12 @@ import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerSession.Factory; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpSingleSessionServerTransportProvider; import reactor.core.publisher.Mono; /** * @author Christian Tzolov */ -public class MockMcpServerTransportProvider implements McpSingleSessionServerTransportProvider { +public class MockMcpServerTransportProvider implements McpServerTransportProvider { private McpServerSession session; From cef7aac81c1868aa93a5c328934d6a8fef5a0a78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 22 Jul 2025 18:24:55 +0200 Subject: [PATCH 5/8] WIP: Add session ID to exchange MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../server/McpAsyncServerExchange.java | 14 ++++++++++++-- .../server/McpSyncServerExchange.java | 4 ++++ .../spec/McpServerSession.java | 3 ++- .../spec/McpStreamableServerSession.java | 4 ++-- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index ec5713261..fe5536db6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -27,6 +27,8 @@ */ public class McpAsyncServerExchange { + private final String sessionId; + private final McpSession session; private final McpSchema.ClientCapabilities clientCapabilities; @@ -55,13 +57,16 @@ public class McpAsyncServerExchange { * @param clientCapabilities The client capabilities that define the supported * features and functionality. * @param clientInfo The client implementation information. + * @deprecated Use + * {@link #McpAsyncServerExchange(String, McpSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} */ + @Deprecated public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; - this.transportContext = new DefaultMcpTransportContext(); + this.transportContext = McpTransportContext.EMPTY; } /** @@ -73,8 +78,9 @@ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities c * transport * @param clientInfo The client implementation information. */ - public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, + public McpAsyncServerExchange(String sessionId, McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, McpTransportContext transportContext) { + this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; @@ -101,6 +107,10 @@ public McpTransportContext transportContext() { return this.transportContext; } + public String sessionId() { + return this.sessionId; + } + /** * Create a new message using the sampling capabilities of the client. The Model * Context Protocol (MCP) provides a standardized way for servers to request LLM diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index c5ab8d733..d5fc317fe 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -28,6 +28,10 @@ public McpSyncServerExchange(McpAsyncServerExchange exchange) { this.exchange = exchange; } + public String sessionId() { + return this.exchange.sessionId(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 17e6ea362..e1911e6bb 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -246,7 +246,8 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); - exchangeSink.tryEmitValue(new McpAsyncServerExchange(this, clientCapabilities.get(), clientInfo.get())); + exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), + clientInfo.get(), McpTransportContext.EMPTY)); return this.initNotificationHandler.handle(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 73379d614..59aad6dc4 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -118,7 +118,7 @@ public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpSer error.message(), error.data()))); } return requestHandler - .handle(new McpAsyncServerExchange(stream, clientCapabilities.get(), clientInfo.get(), + .handle(new McpAsyncServerExchange(this.id, stream, clientCapabilities.get(), clientInfo.get(), transportContext), jsonrpcRequest.params()) .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), result, null)) @@ -137,7 +137,7 @@ public Mono accept(McpSchema.JSONRPCNotification notification) { } McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); return notificationHandler.handle( - new McpAsyncServerExchange( + new McpAsyncServerExchange(this.id, listeningStream != null ? listeningStream : MissingMcpTransportSession.INSTANCE, this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); From 6423fcd24d0b28f4d3a8f90039c37b4ddd9a5beb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Tue, 22 Jul 2025 18:41:23 +0200 Subject: [PATCH 6/8] WIP: make it compile MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../io/modelcontextprotocol/server/McpAsyncServerExchange.java | 2 ++ .../java/io/modelcontextprotocol/spec/McpServerSession.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index fe5536db6..cbf13f73d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -13,6 +13,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; +import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpSession; import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; @@ -63,6 +64,7 @@ public class McpAsyncServerExchange { @Deprecated public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { + this.sessionId = null; this.session = session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index e1911e6bb..2dad7174e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -246,6 +246,8 @@ private Mono handleIncomingNotification(McpSchema.JSONRPCNotification noti return Mono.defer(() -> { if (McpSchema.METHOD_NOTIFICATION_INITIALIZED.equals(notification.method())) { this.state.lazySet(STATE_INITIALIZED); + // FIXME: The session ID passed here is not the same as the one in the + // legacy SSE transport. exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, clientCapabilities.get(), clientInfo.get(), McpTransportContext.EMPTY)); return this.initNotificationHandler.handle(); From 8e9ab52061bc482139da9a5256d23ef2862322ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Wed, 23 Jul 2025 18:02:45 +0200 Subject: [PATCH 7/8] WIP: add SSE id generation and DELETE handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- ...FluxStreamableServerTransportProvider.java | 51 +++++++++++++++++-- .../spec/McpStreamableServerSession.java | 32 +++++++++--- .../spec/McpStreamableServerTransport.java | 15 ++++++ 3 files changed, 87 insertions(+), 11 deletions(-) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index cc84e05b1..db1148efa 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -5,8 +5,8 @@ import io.modelcontextprotocol.spec.DefaultMcpTransportContext; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.McpTransportContext; import io.modelcontextprotocol.util.Assert; @@ -45,6 +45,8 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe private final String mcpEndpoint; + private final boolean disallowDelete; + private final RouterFunction routerFunction; private McpStreamableServerSession.Factory sessionFactory; @@ -70,7 +72,7 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe * @throws IllegalArgumentException if either parameter is null */ public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, mcpEndpoint); + this(objectMapper, DEFAULT_BASE_URL, mcpEndpoint, false); } /** @@ -83,7 +85,8 @@ public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Strin * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null */ - public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String mcpEndpoint) { + public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String mcpEndpoint, + boolean disallowDelete) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(mcpEndpoint, "Message endpoint must not be null"); @@ -91,9 +94,11 @@ public WebFluxStreamableServerTransportProvider(ObjectMapper objectMapper, Strin this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; this.routerFunction = RouterFunctions.route() .GET(this.mcpEndpoint, this::handleGet) .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) .build(); } @@ -306,7 +311,37 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } - private class WebFluxStreamableMcpSessionTransport implements McpServerTransport { + private Mono handleDelete(ServerRequest request) { + if (isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).bodyValue("Server is shutting down"); + } + + McpTransportContext transportContext = this.contextExtractor.apply(request); + + return Mono.defer(() -> { + if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { + return ServerResponse.badRequest().build(); // TODO: say we need a session + // id + } + + // TODO: The user can configure whether deletions are permitted + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); + + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + return session.delete().then(ServerResponse.ok().build()); + }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + } + + private class WebFluxStreamableMcpSessionTransport implements McpStreamableServerTransport { private final FluxSink> sink; @@ -316,6 +351,11 @@ public WebFluxStreamableMcpSessionTransport(FluxSink> sink) { @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return this.sendMessage(message, null); + } + + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { return Mono.fromSupplier(() -> { try { return objectMapper.writeValueAsString(message); @@ -325,6 +365,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { } }).doOnNext(jsonText -> { ServerSentEvent event = ServerSentEvent.builder() + .id(messageId) .event(MESSAGE_EVENT_TYPE) .data(jsonText) .build(); @@ -419,7 +460,7 @@ public WebFluxStreamableServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(mcpEndpoint, "Message endpoint must be set"); - return new WebFluxStreamableServerTransportProvider(objectMapper, baseUrl, mcpEndpoint); + return new WebFluxStreamableServerTransportProvider(objectMapper, baseUrl, mcpEndpoint, false); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 59aad6dc4..0532704b5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -12,10 +12,12 @@ import java.time.Duration; import java.util.Map; +import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; public class McpStreamableServerSession implements McpSession { @@ -88,7 +90,13 @@ public Mono sendNotification(String method, Object params) { }); } - public McpStreamableServerSessionStream listeningStream(McpServerTransport transport) { + public Mono delete() { + return this.closeGracefully().then(Mono.fromRunnable(() -> { + // delete history, etc. + })); + } + + public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); this.listeningStreamRef.set(listeningStream); return listeningStream; @@ -100,7 +108,7 @@ public Flux replay(Object lastEventId) { return Flux.empty(); } - public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpServerTransport transport) { + public Mono responseStream(McpSchema.JSONRPCRequest jsonrpcRequest, McpStreamableServerTransport transport) { return Mono.deferContextual(ctx -> { McpTransportContext transportContext = ctx.getOrDefault(McpTransportContext.KEY, McpTransportContext.EMPTY); @@ -235,10 +243,18 @@ public final class McpStreamableServerSessionStream implements McpSession { private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); - private final McpServerTransport transport; + private final McpStreamableServerTransport transport; + + private final String transportId; + + private final Supplier uuidGenerator; - public McpStreamableServerSessionStream(McpServerTransport transport) { + public McpStreamableServerSessionStream(McpStreamableServerTransport transport) { this.transport = transport; + this.transportId = UUID.randomUUID().toString(); + // This ID design allows for a constant-time extraction of the history by + // precisely identifying the SSE stream using the first component + this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID(); } @Override @@ -251,7 +267,9 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc this.pendingResponses.put(requestId, sink); McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, method, requestId, requestParams); - this.transport.sendMessage(jsonrpcRequest).subscribe(v -> { + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + this.transport.sendMessage(jsonrpcRequest, messageId).subscribe(v -> { }, sink::error); }).timeout(requestTimeout).doOnError(e -> { this.pendingResponses.remove(requestId); @@ -275,7 +293,9 @@ public Mono sendRequest(String method, Object requestParams, TypeReferenc public Mono sendNotification(String method, Object params) { McpSchema.JSONRPCNotification jsonrpcNotification = new McpSchema.JSONRPCNotification( McpSchema.JSONRPC_VERSION, method, params); - return this.transport.sendMessage(jsonrpcNotification); + String messageId = this.uuidGenerator.get(); + // TODO: store message in history + return this.transport.sendMessage(jsonrpcNotification, messageId); } @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java new file mode 100644 index 000000000..e49ba9c84 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerTransport.java @@ -0,0 +1,15 @@ +package io.modelcontextprotocol.spec; + +import reactor.core.publisher.Mono; + +/** + * Marker interface for the server-side MCP transport. + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + */ +public interface McpStreamableServerTransport extends McpServerTransport { + + Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId); + +} From 92ca81fc71ab504047190bd580a7a4ae4e8d78e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dariusz=20J=C4=99drzejczyk?= Date: Fri, 25 Jul 2025 17:22:52 +0200 Subject: [PATCH 8/8] WIP: Handle logging notifications, start with integration tests and async server tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Dariusz Jędrzejczyk --- .../WebClientStreamableHttpTransport.java | 3 +- ...FluxStreamableServerTransportProvider.java | 23 +- .../WebFluxStreamableIntegrationTests.java | 1496 +++++++++++++++++ .../server/WebFluxSseMcpAsyncServerTests.java | 8 +- .../WebFluxStreamableMcpAsyncServerTests.java | 58 + .../WebMvcSseAsyncServerTransportTests.java | 8 +- .../server/AbstractMcpAsyncServerTests.java | 98 +- .../server/McpAsyncServerExchange.java | 25 +- .../server/McpServer.java | 121 +- .../spec/McpLoggableSession.java | 14 + .../modelcontextprotocol/spec/McpSchema.java | 2 +- .../spec/McpServerSession.java | 16 +- .../modelcontextprotocol/spec/McpSession.java | 1 + .../spec/McpStreamableServerSession.java | 74 +- .../spec/MissingMcpTransportSession.java | 26 +- 15 files changed, 1799 insertions(+), 174 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java create mode 100644 mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java index 53b59cb30..e8cb26144 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebClientStreamableHttpTransport.java @@ -340,7 +340,8 @@ private Flux extractError(ClientResponse response, Str McpSchema.JSONRPCResponse jsonRpcResponse = objectMapper.readValue(body, McpSchema.JSONRPCResponse.class); jsonRpcError = jsonRpcResponse.error(); - toPropagate = new McpError(jsonRpcError); + toPropagate = jsonRpcError != null ? new McpError(jsonRpcError) + : new McpError("Can't parse the jsonResponse " + jsonRpcResponse); } catch (IOException ex) { toPropagate = new RuntimeException("Sending request failed", e); diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index db1148efa..172882007 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -35,8 +35,6 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe public static final String MESSAGE_EVENT_TYPE = "message"; - public static final String ENDPOINT_EVENT_TYPE = "endpoint"; - public static final String DEFAULT_BASE_URL = ""; private final ObjectMapper objectMapper; @@ -263,17 +261,28 @@ private Mono handlePost(ServerRequest request) { McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); sessions.put(init.session().getId(), init.session()); - return init.initResult() + return init.initResult().map(initializeResult -> { + McpSchema.JSONRPCResponse jsonrpcResponse = new McpSchema.JSONRPCResponse( + McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initializeResult, null); + try { + return this.objectMapper.writeValueAsString(jsonrpcResponse); + } + catch (IOException e) { + logger.warn("Failed to serialize initResponse", e); + throw Exceptions.propagate(e); + } + }) .flatMap(initResult -> ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) .header("mcp-session-id", init.session().getId()) .bodyValue(initResult)); } - if (!request.headers().asHttpHeaders().containsKey("sessionId")) { + if (!request.headers().asHttpHeaders().containsKey("mcp-session-id")) { return ServerResponse.badRequest().bodyValue(new McpError("Session ID missing")); } - String sessionId = request.headers().asHttpHeaders().getFirst("sessionId"); + String sessionId = request.headers().asHttpHeaders().getFirst("mcp-session-id"); McpStreamableServerSession session = sessions.get(sessionId); if (session == null) { @@ -308,7 +317,9 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { logger.error("Failed to deserialize message: {}", e.getMessage()); return ServerResponse.badRequest().bodyValue(new McpError("Invalid message format")); } - }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); + }) + .switchIfEmpty(ServerResponse.badRequest().build()) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); } private Mono handleDelete(ServerRequest request) { diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java new file mode 100644 index 000000000..dfd1f5565 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxStreamableIntegrationTests.java @@ -0,0 +1,1496 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.server.TestUtil; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CompleteRequest; +import io.modelcontextprotocol.spec.McpSchema.CompleteResult; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Prompt; +import io.modelcontextprotocol.spec.McpSchema.PromptArgument; +import io.modelcontextprotocol.spec.McpSchema.PromptReference; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertWith; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +class WebFluxStreamableIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String CUSTOM_MESSAGE_ENDPOINT = "/otherPath/mcp/message"; + + private DisposableServer httpServer; + + private WebFluxStreamableServerTransportProvider mcpStreamableServerTransportProvider; + + ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + @BeforeEach + public void before() { + + this.mcpStreamableServerTransportProvider = new WebFluxStreamableServerTransportProvider(new ObjectMapper(), + CUSTOM_MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions + .toHttpHandler(mcpStreamableServerTransportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + clientBuilders + .put("webflux", McpClient + .sync(WebClientStreamableHttpTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(CUSTOM_MESSAGE_ENDPOINT) + .build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> exchange.createMessage(mock(CreateMessageRequest.class)) + .thenReturn(mock(CallToolResult.class))) + .build(); + + var server = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build();) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + server.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var createMessageRequest = CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(createMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference samplingResult = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var craeteMessageRequest = CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + return exchange.createMessage(craeteMessageRequest) + .doOnNext(samplingResult::set) + .thenReturn(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .requestTimeout(Duration.ofSeconds(4)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + assertWith(samplingResult.get(), result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }); + } + + mcpServer.closeGracefully().block(); + } + + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var craeteMessageRequest = CreateMessageRequest.builder() + .messages(List + .of(new McpSchema.SamplingMessage(Role.USER, new McpSchema.TextContent("Test message")))) + .build(); + + return exchange.createMessage(craeteMessageRequest).thenReturn(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .requestTimeout(Duration.ofSeconds(1)) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + } + + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + exchange.createElicitation(mock(ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }) + .build(); + + var server = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + // Client + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + // Server + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + + mcpServer.close(); + } + + // @ParameterizedTest(name = "{0} : {displayName} ") + // @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> { + }) + .tools(tool) + .build(); + + // Create client without roots capability + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsNotificationWithEmptyRootsList(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }) + .build(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool1", "tool1 description", emptyJsonSchema)) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + String response = RestClient.create() + .get() + .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") + .retrieve() + .body(String.class); + assertThat(response).isNotBlank(); + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(new Tool("tool2", "tool2 description", emptyJsonSchema)) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider).build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Logging Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testLoggingNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 3; + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("logging-test", "Test logging notifications", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + // Create and send notifications with different levels + + //@formatter:off + return exchange // This should be filtered out (DEBUG < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); + //@formatter:on + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with logging notification handler + var mcpClient = clientBuilder.loggingConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Set minimum logging level to NOTICE + mcpClient.setLoggingLevel(McpSchema.LoggingLevel.NOTICE); + + // Call the tool that sends logging notifications + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("logging-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Logging test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications (1 NOTICE and 2 ERROR) + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.data(), n -> n)); + + // First notification should be NOTICE level + assertThat(notificationMap.get("Notice message").level()).isEqualTo(McpSchema.LoggingLevel.NOTICE); + assertThat(notificationMap.get("Notice message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Notice message").data()).isEqualTo("Notice message"); + + // Second notification should be ERROR level + assertThat(notificationMap.get("Error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Error message").data()).isEqualTo("Error message"); + + // Third notification should be ERROR level + assertThat(notificationMap.get("Another error message").level()).isEqualTo(McpSchema.LoggingLevel.ERROR); + assertThat(notificationMap.get("Another error message").logger()).isEqualTo("test-logger"); + assertThat(notificationMap.get("Another error message").data()).isEqualTo("Another error message"); + } + mcpServer.close(); + } + + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testProgressNotification(String clientType) throws InterruptedException { + int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress + // token + CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); + // Create a list to store received logging notifications + List receivedNotifications = new CopyOnWriteArrayList<>(); + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("progress-test") + .description("Test progress notifications") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + // Create and send notifications + var progressToken = (String) request.meta().get("progressToken"); + + return exchange + .progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.0, 1.0, "Processing started")) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Processing data"))) + .then(// Send a progress notification with another progress value + // should + exchange.progressNotification(new McpSchema.ProgressNotification("another-progress-token", + 0.0, 1.0, "Another processing started"))) + .then(exchange.progressNotification( + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Processing completed"))) + .thenReturn(new CallToolResult(("Progress test completed"), false)); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(notification -> { + receivedNotifications.add(notification); + latch.countDown(); + }).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + McpSchema.CallToolRequest callToolRequest = McpSchema.CallToolRequest.builder() + .name("progress-test") + .meta(Map.of("progressToken", "test-progress-token")) + .build(); + CallToolResult result = mcpClient.callTool(callToolRequest); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + assertThat(latch.await(5, TimeUnit.SECONDS)).as("Should receive notifications in reasonable time").isTrue(); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(expectedNotificationsCount); + + Map notificationMap = receivedNotifications.stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("Processing started").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing started").message()).isEqualTo("Processing started"); + + // Second notification should be 0.5/1.0 progress + assertThat(notificationMap.get("Processing data").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing data").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("Processing data").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing data").message()).isEqualTo("Processing data"); + + // Third notification should be another progress token with 0.0/1.0 progress + assertThat(notificationMap.get("Another processing started").progressToken()) + .isEqualTo("another-progress-token"); + assertThat(notificationMap.get("Another processing started").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("Another processing started").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Another processing started").message()) + .isEqualTo("Another processing started"); + + // Fourth notification should be 1.0/1.0 progress + assertThat(notificationMap.get("Processing completed").progressToken()).isEqualTo("test-progress-token"); + assertThat(notificationMap.get("Processing completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("Processing completed").message()).isEqualTo("Processing completed"); + } + finally { + mcpServer.close(); + } + } + + // --------------------------------------- + // Completion Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : Completion call") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCompletionShouldReturnExpectedSuggestions(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + var expectedValues = List.of("python", "pytorch", "pyside"); + var completionResponse = new CompleteResult(new CompleteResult.CompleteCompletion(expectedValues, 10, // total + true // hasMore + )); + + AtomicReference samplingRequest = new AtomicReference<>(); + BiFunction completionHandler = (mcpSyncServerExchange, + request) -> { + samplingRequest.set(request); + return completionResponse; + }; + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification( + new Prompt("code_review", "Code review", "this is code review prompt", + List.of(new PromptArgument("language", "Language", "string", false))), + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new PromptReference("ref/prompt", "code_review", "Code review"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CompleteRequest request = new CompleteRequest( + new PromptReference("ref/prompt", "code_review", "Code review"), + new CompleteRequest.CompleteArgument("language", "py")); + + CompleteResult result = mcpClient.completeCompletion(request); + + assertThat(result).isNotNull(); + + assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); + assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); + assertThat(samplingRequest.get().ref().type()).isEqualTo("ref/prompt"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Ping Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testPingSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(new Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema)) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return new CallToolResult("Async ping test completed", false); + })); + }) + .build(); + + var mcpServer = McpServer.async(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, + (exchange, request) -> { + String expression = (String) request.getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + assertThatJson(((McpSchema.TextContent) response.content().get(0)).text()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationFailure(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, + (exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, + (exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }); + + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .instructions("bla") + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = McpServer.sync(mcpStreamableServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, + (exchange, request) -> { + int count = (Integer) request.getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java index cc33e7b94..a3bdf10b0 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseMcpAsyncServerTests.java @@ -29,8 +29,7 @@ class WebFluxSseMcpAsyncServerTests extends AbstractMcpAsyncServerTests { private DisposableServer httpServer; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { var transportProvider = new WebFluxSseServerTransportProvider.Builder().objectMapper(new ObjectMapper()) .messageEndpoint(MESSAGE_ENDPOINT) .build(); @@ -41,6 +40,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java new file mode 100644 index 000000000..0d7429a0c --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxStreamableMcpAsyncServerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import org.junit.jupiter.api.Timeout; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.server.RouterFunctions; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebFluxStreamableMcpAsyncServerTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private DisposableServer httpServer; + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + var transportProvider = new WebFluxStreamableServerTransportProvider(new ObjectMapper(), MESSAGE_ENDPOINT); + + HttpHandler httpHandler = RouterFunctions.toHttpHandler(transportProvider.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java index 6a6ad17e9..bb4c2bf37 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseAsyncServerTransportTests.java @@ -49,8 +49,7 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro private AnnotationConfigWebApplicationContext appContext; - @Override - protected McpServerTransportProvider createMcpTransportProvider() { + private McpServerTransportProvider createMcpTransportProvider() { // Set up Tomcat first tomcat = new Tomcat(); tomcat.setPort(PORT); @@ -90,6 +89,11 @@ protected McpServerTransportProvider createMcpTransportProvider() { return transportProvider; } + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + @Override protected void onStart() { } diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index eb08bdcde..68a60a17c 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -18,9 +18,12 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; @@ -42,7 +45,7 @@ public abstract class AbstractMcpAsyncServerTests { private static final String TEST_PROMPT_NAME = "test-prompt"; - abstract protected McpServerTransportProvider createMcpTransportProvider(); + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); protected void onStart() { } @@ -63,28 +66,29 @@ void tearDown() { // Server Lifecycle Tests // --------------------------------------- - @Test - void testConstructorWithInvalidArguments() { + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "sse", "streamable" }) + void testConstructorWithInvalidArguments(String serverType) { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null"); - assertThatThrownBy( - () -> McpServer.async(createMcpTransportProvider()).serverInfo((McpSchema.Implementation) null)) + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo((McpSchema.Implementation) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Server info must not be null"); } @Test void testGracefulShutdown() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + McpServer.AsyncSpecification builder = prepareAsyncServerBuilder(); + var mcpAsyncServer = builder.serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.closeGracefully()).verifyComplete(); } @Test void testImmediateClose() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThatCode(() -> mcpAsyncServer.close()).doesNotThrowAnyException(); } @@ -104,8 +108,7 @@ void testImmediateClose() { @Deprecated void testAddTool() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -119,8 +122,7 @@ void testAddTool() { @Test void testAddToolCall() { Tool newTool = new McpSchema.Tool("new-tool", "New test tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -137,8 +139,7 @@ void testAddToolCall() { void testAddDuplicateTool() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tool(duplicateTool, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -158,8 +159,7 @@ void testAddDuplicateTool() { void testAddDuplicateToolCall() { Tool duplicateTool = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -180,8 +180,7 @@ void testDuplicateToolCallDuringBuilding() { Tool duplicateTool = new Tool("duplicate-build-toolcall", "Duplicate toolcall during building", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .toolCall(duplicateTool, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) // Duplicate! @@ -203,8 +202,7 @@ void testDuplicateToolsInBatchListRegistration() { .build() // Duplicate! ); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(specs) .build()).isInstanceOf(IllegalArgumentException.class) @@ -215,8 +213,7 @@ void testDuplicateToolsInBatchListRegistration() { void testDuplicateToolsInBatchVarargsRegistration() { Tool duplicateTool = new Tool("batch-varargs-tool", "Duplicate tool in batch varargs", emptyJsonSchema); - assertThatThrownBy(() -> McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + assertThatThrownBy(() -> prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .tools(McpServerFeatures.AsyncToolSpecification.builder() .tool(duplicateTool) @@ -235,8 +232,7 @@ void testDuplicateToolsInBatchVarargsRegistration() { void testRemoveTool() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, request) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -248,8 +244,7 @@ void testRemoveTool() { @Test void testRemoveNonexistentTool() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .build(); @@ -264,8 +259,7 @@ void testRemoveNonexistentTool() { void testNotifyToolsListChanged() { Tool too = new McpSchema.Tool(TEST_TOOL_NAME, "Duplicate tool", emptyJsonSchema); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().tools(true).build()) .toolCall(too, (exchange, args) -> Mono.just(new CallToolResult(List.of(), false))) .build(); @@ -281,7 +275,7 @@ void testNotifyToolsListChanged() { @Test void testNotifyResourcesListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyResourcesListChanged()).verifyComplete(); @@ -290,7 +284,7 @@ void testNotifyResourcesListChanged() { @Test void testNotifyResourcesUpdated() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier .create(mcpAsyncServer @@ -302,8 +296,7 @@ void testNotifyResourcesUpdated() { @Test void testAddResource() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -319,8 +312,7 @@ void testAddResource() { @Test void testAddResourceWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().resources(true, false).build()) .build(); @@ -335,9 +327,7 @@ void testAddResourceWithNullSpecification() { @Test void testAddResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Resource resource = new Resource(TEST_RESOURCE_URI, "Test Resource", "text/plain", "Test resource description", null); @@ -353,9 +343,7 @@ void testAddResourceWithoutCapability() { @Test void testRemoveResourceWithoutCapability() { // Create a server without resource capabilities - McpAsyncServer serverWithoutResources = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutResources = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutResources.removeResource(TEST_RESOURCE_URI)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -369,7 +357,7 @@ void testRemoveResourceWithoutCapability() { @Test void testNotifyPromptsListChanged() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()).serverInfo("test-server", "1.0.0").build(); + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(mcpAsyncServer.notifyPromptsListChanged()).verifyComplete(); @@ -378,8 +366,7 @@ void testNotifyPromptsListChanged() { @Test void testAddPromptWithNullSpecification() { - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(false).build()) .build(); @@ -392,9 +379,7 @@ void testAddPromptWithNullSpecification() { @Test void testAddPromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); Prompt prompt = new Prompt(TEST_PROMPT_NAME, "Test Prompt", "Test Prompt", List.of()); McpServerFeatures.AsyncPromptSpecification specification = new McpServerFeatures.AsyncPromptSpecification( @@ -410,9 +395,7 @@ void testAddPromptWithoutCapability() { @Test void testRemovePromptWithoutCapability() { // Create a server without prompt capabilities - McpAsyncServer serverWithoutPrompts = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + McpAsyncServer serverWithoutPrompts = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); StepVerifier.create(serverWithoutPrompts.removePrompt(TEST_PROMPT_NAME)).verifyErrorSatisfies(error -> { assertThat(error).isInstanceOf(McpError.class) @@ -429,8 +412,7 @@ void testRemovePrompt() { prompt, (exchange, req) -> Mono.just(new GetPromptResult("Test prompt description", List .of(new PromptMessage(McpSchema.Role.ASSISTANT, new McpSchema.TextContent("Test content")))))); - var mcpAsyncServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .prompts(specification) .build(); @@ -442,8 +424,7 @@ void testRemovePrompt() { @Test void testRemoveNonexistentPrompt() { - var mcpAsyncServer2 = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var mcpAsyncServer2 = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .capabilities(ServerCapabilities.builder().prompts(true).build()) .build(); @@ -466,8 +447,7 @@ void testRootsChangeHandlers() { var rootsReceived = new McpSchema.Root[1]; var consumerCalled = new boolean[1]; - var singleConsumerServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var singleConsumerServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumerCalled[0] = true; if (!roots.isEmpty()) { @@ -486,8 +466,7 @@ void testRootsChangeHandlers() { var consumer2Called = new boolean[1]; var rootsContent = new List[1]; - var multipleConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var multipleConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> Mono.fromRunnable(() -> { consumer1Called[0] = true; rootsContent[0] = roots; @@ -500,8 +479,7 @@ void testRootsChangeHandlers() { onClose(); // Test error handling - var errorHandlingServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") + var errorHandlingServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") .rootsChangeHandlers(List.of((exchange, roots) -> { throw new RuntimeException("Test error"); })) @@ -513,9 +491,7 @@ void testRootsChangeHandlers() { onClose(); // Test without consumers - var noConsumersServer = McpServer.async(createMcpTransportProvider()) - .serverInfo("test-server", "1.0.0") - .build(); + var noConsumersServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").build(); assertThat(noConsumersServer).isNotNull(); assertThatCode(() -> noConsumersServer.closeGracefully().block(Duration.ofSeconds(10))) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index cbf13f73d..6ebbbe23e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import io.modelcontextprotocol.spec.DefaultMcpTransportContext; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; @@ -30,7 +31,7 @@ public class McpAsyncServerExchange { private final String sessionId; - private final McpSession session; + private final McpLoggableSession session; private final McpSchema.ClientCapabilities clientCapabilities; @@ -38,8 +39,6 @@ public class McpAsyncServerExchange { private final McpTransportContext transportContext; - private volatile LoggingLevel minLoggingLevel = LoggingLevel.INFO; - private static final TypeReference CREATE_MESSAGE_RESULT_TYPE_REF = new TypeReference<>() { }; @@ -59,13 +58,16 @@ public class McpAsyncServerExchange { * features and functionality. * @param clientInfo The client implementation information. * @deprecated Use - * {@link #McpAsyncServerExchange(String, McpSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} + * {@link #McpAsyncServerExchange(String, McpLoggableSession, McpSchema.ClientCapabilities, McpSchema.Implementation, McpTransportContext)} */ @Deprecated public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) { this.sessionId = null; - this.session = session; + if (!(session instanceof McpLoggableSession)) { + throw new IllegalArgumentException("Expecting session to be a McpLoggableSession instance"); + } + this.session = (McpLoggableSession) session; this.clientCapabilities = clientCapabilities; this.clientInfo = clientInfo; this.transportContext = McpTransportContext.EMPTY; @@ -80,8 +82,9 @@ public McpAsyncServerExchange(McpSession session, McpSchema.ClientCapabilities c * transport * @param clientInfo The client implementation information. */ - public McpAsyncServerExchange(String sessionId, McpSession session, McpSchema.ClientCapabilities clientCapabilities, - McpSchema.Implementation clientInfo, McpTransportContext transportContext) { + public McpAsyncServerExchange(String sessionId, McpLoggableSession session, + McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, + McpTransportContext transportContext) { this.sessionId = sessionId; this.session = session; this.clientCapabilities = clientCapabilities; @@ -208,7 +211,7 @@ public Mono loggingNotification(LoggingMessageNotification loggingMessageN } return Mono.defer(() -> { - if (this.isNotificationForLevelAllowed(loggingMessageNotification.level())) { + if (this.session.isNotificationForLevelAllowed(loggingMessageNotification.level())) { return this.session.sendNotification(McpSchema.METHOD_NOTIFICATION_MESSAGE, loggingMessageNotification); } return Mono.empty(); @@ -243,11 +246,7 @@ public Mono ping() { */ void setMinLoggingLevel(LoggingLevel minLoggingLevel) { Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); - this.minLoggingLevel = minLoggingLevel; - } - - private boolean isNotificationForLevelAllowed(LoggingLevel loggingLevel) { - return loggingLevel.level() >= this.minLoggingLevel.level(); + this.session.setMinLoggingLevel(minLoggingLevel); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d5ee64758..1892eb243 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -156,7 +156,7 @@ static SingleSessionSyncSpecification sync(McpServerTransportProvider transportP * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - static SingleSessionAsyncSpecification async(McpServerTransportProvider transportProvider) { + static AsyncSpecification async(McpServerTransportProvider transportProvider) { return new SingleSessionAsyncSpecification(transportProvider); } @@ -180,7 +180,7 @@ static StreamableSyncSpecification sync(McpStreamableServerTransportProvider tra * @param transportProvider The transport layer implementation for MCP communication. * @return A new instance of {@link AsyncSpecification} for configuring the server. */ - static StreamableServerAsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { + static AsyncSpecification async(McpStreamableServerTransportProvider transportProvider) { return new StreamableServerAsyncSpecification(transportProvider); } @@ -206,6 +206,7 @@ private SingleSessionAsyncSpecification(McpServerTransportProvider transportProv * @return A new instance of {@link McpAsyncServer} configured with this builder's * settings. */ + @Override public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, @@ -232,6 +233,7 @@ public StreamableServerAsyncSpecification(McpStreamableServerTransportProvider t * @return A new instance of {@link McpAsyncServer} configured with this builder's * settings. */ + @Override public McpAsyncServer build() { var features = new McpServerFeatures.Async(this.serverInfo, this.serverCapabilities, this.tools, this.resources, this.resourceTemplates, this.prompts, this.completions, this.rootsChangeHandlers, @@ -248,7 +250,7 @@ public McpAsyncServer build() { /** * Asynchronous server specification. */ - class AsyncSpecification> { + abstract class AsyncSpecification> { McpUriTemplateManagerFactory uriTemplateManagerFactory = new DeafaultMcpUriTemplateManagerFactory(); @@ -295,12 +297,14 @@ class AsyncSpecification> { final List, Mono>> rootsChangeHandlers = new ArrayList<>(); - Duration requestTimeout = Duration.ofSeconds(10); // Default timeout + Duration requestTimeout = Duration.ofHours(10); // Default timeout - @SuppressWarnings("unchecked") - S self() { - return (S) this; - } + public abstract McpAsyncServer build(); + + // @SuppressWarnings("unchecked") + // S self() { + // return (S) this; + // } /** * Sets the URI template manager factory to use for creating URI templates. This @@ -309,10 +313,10 @@ S self() { * @return This builder instance for method chaining * @throws IllegalArgumentException if uriTemplateManagerFactory is null */ - public S uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { + public AsyncSpecification uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManagerFactory) { Assert.notNull(uriTemplateManagerFactory, "URI template manager factory must not be null"); this.uriTemplateManagerFactory = uriTemplateManagerFactory; - return self(); + return this; } /** @@ -324,10 +328,10 @@ public S uriTemplateManagerFactory(McpUriTemplateManagerFactory uriTemplateManag * @return This builder instance for method chaining * @throws IllegalArgumentException if requestTimeout is null */ - public S requestTimeout(Duration requestTimeout) { + public AsyncSpecification requestTimeout(Duration requestTimeout) { Assert.notNull(requestTimeout, "Request timeout must not be null"); this.requestTimeout = requestTimeout; - return self(); + return this; } /** @@ -339,10 +343,10 @@ public S requestTimeout(Duration requestTimeout) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverInfo is null */ - public S serverInfo(McpSchema.Implementation serverInfo) { + public AsyncSpecification serverInfo(McpSchema.Implementation serverInfo) { Assert.notNull(serverInfo, "Server info must not be null"); this.serverInfo = serverInfo; - return self(); + return this; } /** @@ -355,11 +359,11 @@ public S serverInfo(McpSchema.Implementation serverInfo) { * @throws IllegalArgumentException if name or version is null or empty * @see #serverInfo(McpSchema.Implementation) */ - public S serverInfo(String name, String version) { + public AsyncSpecification serverInfo(String name, String version) { Assert.hasText(name, "Name must not be null or empty"); Assert.hasText(version, "Version must not be null or empty"); this.serverInfo = new McpSchema.Implementation(name, version); - return self(); + return this; } /** @@ -369,9 +373,9 @@ public S serverInfo(String name, String version) { * @param instructions The instructions text. Can be null or empty. * @return This builder instance for method chaining */ - public S instructions(String instructions) { + public AsyncSpecification instructions(String instructions) { this.instructions = instructions; - return self(); + return this; } /** @@ -388,10 +392,10 @@ public S instructions(String instructions) { * @return This builder instance for method chaining * @throws IllegalArgumentException if serverCapabilities is null */ - public S capabilities(McpSchema.ServerCapabilities serverCapabilities) { + public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabilities) { Assert.notNull(serverCapabilities, "Server capabilities must not be null"); this.serverCapabilities = serverCapabilities; - return self(); + return this; } /** @@ -419,7 +423,7 @@ public S capabilities(McpSchema.ServerCapabilities serverCapabilities) { * calls that require a request object. */ @Deprecated - public S tool(McpSchema.Tool tool, + public AsyncSpecification tool(McpSchema.Tool tool, BiFunction, Mono> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -427,7 +431,7 @@ public S tool(McpSchema.Tool tool, this.tools.add(new McpServerFeatures.AsyncToolSpecification(tool, handler)); - return self(); + return this; } /** @@ -443,7 +447,7 @@ public S tool(McpSchema.Tool tool, * @return This builder instance for method chaining * @throws IllegalArgumentException if tool or handler is null */ - public S toolCall(McpSchema.Tool tool, + public AsyncSpecification toolCall(McpSchema.Tool tool, BiFunction> callHandler) { Assert.notNull(tool, "Tool must not be null"); @@ -453,7 +457,7 @@ public S toolCall(McpSchema.Tool tool, this.tools .add(McpServerFeatures.AsyncToolSpecification.builder().tool(tool).callHandler(callHandler).build()); - return self(); + return this; } /** @@ -466,7 +470,7 @@ public S toolCall(McpSchema.Tool tool, * @throws IllegalArgumentException if toolSpecifications is null * @see #tools(McpServerFeatures.AsyncToolSpecification...) */ - public S tools(List toolSpecifications) { + public AsyncSpecification tools(List toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (var tool : toolSpecifications) { @@ -474,7 +478,7 @@ public S tools(List toolSpecifications this.tools.add(tool); } - return self(); + return this; } /** @@ -493,14 +497,14 @@ public S tools(List toolSpecifications * @return This builder instance for method chaining * @throws IllegalArgumentException if toolSpecifications is null */ - public S tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { + public AsyncSpecification tools(McpServerFeatures.AsyncToolSpecification... toolSpecifications) { Assert.notNull(toolSpecifications, "Tool handlers list must not be null"); for (McpServerFeatures.AsyncToolSpecification tool : toolSpecifications) { assertNoDuplicateTool(tool.tool().name()); this.tools.add(tool); } - return self(); + return this; } private void assertNoDuplicateTool(String toolName) { @@ -519,10 +523,11 @@ private void assertNoDuplicateTool(String toolName) { * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public S resources(Map resourceSpecifications) { + public AsyncSpecification resources( + Map resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers map must not be null"); this.resources.putAll(resourceSpecifications); - return self(); + return this; } /** @@ -534,12 +539,13 @@ public S resources(Map res * @throws IllegalArgumentException if resourceSpecifications is null * @see #resources(McpServerFeatures.AsyncResourceSpecification...) */ - public S resources(List resourceSpecifications) { + public AsyncSpecification resources( + List resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } - return self(); + return this; } /** @@ -559,12 +565,12 @@ public S resources(List resourceSp * @return This builder instance for method chaining * @throws IllegalArgumentException if resourceSpecifications is null */ - public S resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { + public AsyncSpecification resources(McpServerFeatures.AsyncResourceSpecification... resourceSpecifications) { Assert.notNull(resourceSpecifications, "Resource handlers list must not be null"); for (McpServerFeatures.AsyncResourceSpecification resource : resourceSpecifications) { this.resources.put(resource.resource().uri(), resource); } - return self(); + return this; } /** @@ -584,10 +590,10 @@ public S resources(McpServerFeatures.AsyncResourceSpecification... resourceSpeci * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(ResourceTemplate...) */ - public S resourceTemplates(List resourceTemplates) { + public AsyncSpecification resourceTemplates(List resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); this.resourceTemplates.addAll(resourceTemplates); - return self(); + return this; } /** @@ -598,12 +604,12 @@ public S resourceTemplates(List resourceTemplates) { * @throws IllegalArgumentException if resourceTemplates is null. * @see #resourceTemplates(List) */ - public S resourceTemplates(ResourceTemplate... resourceTemplates) { + public AsyncSpecification resourceTemplates(ResourceTemplate... resourceTemplates) { Assert.notNull(resourceTemplates, "Resource templates must not be null"); for (ResourceTemplate resourceTemplate : resourceTemplates) { this.resourceTemplates.add(resourceTemplate); } - return self(); + return this; } /** @@ -623,10 +629,10 @@ public S resourceTemplates(ResourceTemplate... resourceTemplates) { * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public S prompts(Map prompts) { + public AsyncSpecification prompts(Map prompts) { Assert.notNull(prompts, "Prompts map must not be null"); this.prompts.putAll(prompts); - return self(); + return this; } /** @@ -637,12 +643,12 @@ public S prompts(Map prompts * @throws IllegalArgumentException if prompts is null * @see #prompts(McpServerFeatures.AsyncPromptSpecification...) */ - public S prompts(List prompts) { + public AsyncSpecification prompts(List prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } - return self(); + return this; } /** @@ -661,12 +667,12 @@ public S prompts(List prompts) { * @return This builder instance for method chaining * @throws IllegalArgumentException if prompts is null */ - public S prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { + public AsyncSpecification prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { Assert.notNull(prompts, "Prompts list must not be null"); for (McpServerFeatures.AsyncPromptSpecification prompt : prompts) { this.prompts.put(prompt.prompt().name(), prompt); } - return self(); + return this; } /** @@ -676,12 +682,12 @@ public S prompts(McpServerFeatures.AsyncPromptSpecification... prompts) { * @return This builder instance for method chaining * @throws IllegalArgumentException if completions is null */ - public S completions(List completions) { + public AsyncSpecification completions(List completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); } - return self(); + return this; } /** @@ -691,12 +697,12 @@ public S completions(List comple * @return This builder instance for method chaining * @throws IllegalArgumentException if completions is null */ - public S completions(McpServerFeatures.AsyncCompletionSpecification... completions) { + public AsyncSpecification completions(McpServerFeatures.AsyncCompletionSpecification... completions) { Assert.notNull(completions, "Completions list must not be null"); for (McpServerFeatures.AsyncCompletionSpecification completion : completions) { this.completions.put(completion.referenceKey(), completion); } - return self(); + return this; } /** @@ -709,10 +715,11 @@ public S completions(McpServerFeatures.AsyncCompletionSpecification... completio * @return This builder instance for method chaining * @throws IllegalArgumentException if consumer is null */ - public S rootsChangeHandler(BiFunction, Mono> handler) { + public AsyncSpecification rootsChangeHandler( + BiFunction, Mono> handler) { Assert.notNull(handler, "Consumer must not be null"); this.rootsChangeHandlers.add(handler); - return self(); + return this; } /** @@ -724,11 +731,11 @@ public S rootsChangeHandler(BiFunction rootsChangeHandlers( List, Mono>> handlers) { Assert.notNull(handlers, "Handlers list must not be null"); this.rootsChangeHandlers.addAll(handlers); - return self(); + return this; } /** @@ -740,7 +747,7 @@ public S rootsChangeHandlers( * @throws IllegalArgumentException if consumers is null * @see #rootsChangeHandlers(List) */ - public S rootsChangeHandlers( + public AsyncSpecification rootsChangeHandlers( @SuppressWarnings("unchecked") BiFunction, Mono>... handlers) { Assert.notNull(handlers, "Handlers list must not be null"); return this.rootsChangeHandlers(Arrays.asList(handlers)); @@ -752,10 +759,10 @@ public S rootsChangeHandlers( * @return This builder instance for method chaining. * @throws IllegalArgumentException if objectMapper is null */ - public S objectMapper(ObjectMapper objectMapper) { + public AsyncSpecification objectMapper(ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); this.objectMapper = objectMapper; - return self(); + return this; } /** @@ -766,10 +773,10 @@ public S objectMapper(ObjectMapper objectMapper) { * @return This builder instance for method chaining * @throws IllegalArgumentException if jsonSchemaValidator is null */ - public S jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { + public AsyncSpecification jsonSchemaValidator(JsonSchemaValidator jsonSchemaValidator) { Assert.notNull(jsonSchemaValidator, "JsonSchemaValidator must not be null"); this.jsonSchemaValidator = jsonSchemaValidator; - return self(); + return this; } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java new file mode 100644 index 000000000..ac6d01d91 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpLoggableSession.java @@ -0,0 +1,14 @@ +package io.modelcontextprotocol.spec; + +public interface McpLoggableSession extends McpSession { + + /** + * Set the minimum logging level for the client. Messages below this level will be + * filtered out. + * @param minLoggingLevel The minimum logging level + */ + void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel); + + boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index e9c23db6a..a3812dbc2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -44,7 +44,7 @@ public final class McpSchema { private McpSchema() { } - public static final String LATEST_PROTOCOL_VERSION = "2024-11-05"; + public static final String LATEST_PROTOCOL_VERSION = "2025-03-26"; public static final String JSONRPC_VERSION = "2.0"; diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 2dad7174e..8a569110b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -12,6 +12,7 @@ import io.modelcontextprotocol.server.McpInitRequestHandler; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -22,7 +23,7 @@ * Represents a Model Control Protocol (MCP) session on the server side. It manages * bidirectional JSON-RPC communication with the client. */ -public class McpServerSession implements McpSession { +public class McpServerSession implements McpLoggableSession { private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class); @@ -59,6 +60,8 @@ public class McpServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + /** * Creates a new server session with the given parameters and the transport to use. * @param id session id @@ -112,6 +115,17 @@ private String generateRequestId() { return this.id + "-" + this.requestCounter.getAndIncrement(); } + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = this.generateRequestId(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java index 42d170db5..7b29ca651 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSession.java @@ -5,6 +5,7 @@ package io.modelcontextprotocol.spec; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 0532704b5..42313015a 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -4,6 +4,7 @@ import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -19,7 +20,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; -public class McpStreamableServerSession implements McpSession { +public class McpStreamableServerSession implements McpLoggableSession { private static final Logger logger = LoggerFactory.getLogger(McpStreamableServerSession.class); @@ -49,13 +50,19 @@ public class McpStreamableServerSession implements McpSession { private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED); - private final AtomicReference listeningStreamRef = new AtomicReference<>(); + private final AtomicReference listeningStreamRef; + + private final MissingMcpTransportSession missingMcpTransportSession; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; public McpStreamableServerSession(String id, McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo, Duration requestTimeout, InitNotificationHandler initNotificationHandler, Map> requestHandlers, Map notificationHandlers) { this.id = id; + this.missingMcpTransportSession = new MissingMcpTransportSession(id); + this.listeningStreamRef = new AtomicReference<>(this.missingMcpTransportSession); this.clientCapabilities.lazySet(clientCapabilities); this.clientInfo.lazySet(clientInfo); this.requestTimeout = requestTimeout; @@ -64,6 +71,17 @@ public McpStreamableServerSession(String id, McpSchema.ClientCapabilities client this.notificationHandlers = notificationHandlers; } + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + public String getId() { return this.id; } @@ -75,18 +93,16 @@ private String generateRequestId() { @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { return Mono.defer(() -> { - McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return listeningStream != null ? listeningStream.sendRequest(method, requestParams, typeRef) - : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendRequest(method, requestParams, typeRef); }); } @Override public Mono sendNotification(String method, Object params) { return Mono.defer(() -> { - McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return listeningStream != null ? listeningStream.sendNotification(method, params) - : Mono.error(new RuntimeException("Generic stream is unavailable for session " + this.id)); + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return listeningStream.sendNotification(method, params); }); } @@ -143,12 +159,9 @@ public Mono accept(McpSchema.JSONRPCNotification notification) { logger.error("No handler registered for notification method: {}", notification.method()); return Mono.empty(); } - McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return notificationHandler.handle( - new McpAsyncServerExchange(this.id, - listeningStream != null ? listeningStream : MissingMcpTransportSession.INSTANCE, - this.clientCapabilities.get(), this.clientInfo.get(), transportContext), - notification.params()); + McpLoggableSession listeningStream = this.listeningStreamRef.get(); + return notificationHandler.handle(new McpAsyncServerExchange(this.id, listeningStream, + this.clientCapabilities.get(), this.clientInfo.get(), transportContext), notification.params()); }); } @@ -160,6 +173,7 @@ public Mono accept(McpSchema.JSONRPCResponse response) { return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO // JSONize } + // TODO: encapsulate this inside the stream itself var sink = stream.pendingResponses.remove(response.id()); if (sink == null) { return Mono.error(new McpError("Unexpected response for unknown id " + response.id())); // TODO @@ -182,20 +196,15 @@ private MethodNotFoundError getMethodNotFoundError(String method) { @Override public Mono closeGracefully() { return Mono.defer(() -> { - McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); - return listeningStream != null ? listeningStream.closeGracefully() : Mono.empty(); // TODO: - // Also - // close - // all - // the - // open - // streams + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); + return listeningStream.closeGracefully(); + // TODO: Also close all the open streams }); } @Override public void close() { - McpStreamableServerSessionStream listeningStream = this.listeningStreamRef.get(); + McpLoggableSession listeningStream = this.listeningStreamRef.getAndSet(missingMcpTransportSession); if (listeningStream != null) { listeningStream.close(); } @@ -239,7 +248,7 @@ public record McpStreamableServerSessionInit(McpStreamableServerSession session, Mono initResult) { } - public final class McpStreamableServerSessionStream implements McpSession { + public final class McpStreamableServerSessionStream implements McpLoggableSession { private final ConcurrentHashMap> pendingResponses = new ConcurrentHashMap<>(); @@ -257,6 +266,17 @@ public McpStreamableServerSessionStream(McpStreamableServerTransport transport) this.uuidGenerator = () -> this.transportId + "_" + UUID.randomUUID(); } + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + McpStreamableServerSession.this.setMinLoggingLevel(minLoggingLevel); + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return McpStreamableServerSession.this.isNotificationForLevelAllowed(loggingLevel); + } + @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { String requestId = McpStreamableServerSession.this.generateRequestId(); @@ -304,7 +324,8 @@ public Mono closeGracefully() { this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); this.pendingResponses.clear(); // If this was the generic stream, reset it - McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, null); + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); return this.transport.closeGracefully(); }); @@ -315,7 +336,8 @@ public void close() { this.pendingResponses.values().forEach(s -> s.error(new RuntimeException("Stream closed"))); this.pendingResponses.clear(); // If this was the generic stream, reset it - McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, null); + McpStreamableServerSession.this.listeningStreamRef.compareAndExchange(this, + McpStreamableServerSession.this.missingMcpTransportSession); McpStreamableServerSession.this.requestIdToStream.values().removeIf(this::equals); this.transport.close(); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java index 79ca44d2c..f41c8768e 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/MissingMcpTransportSession.java @@ -1,20 +1,27 @@ package io.modelcontextprotocol.spec; import com.fasterxml.jackson.core.type.TypeReference; +import io.modelcontextprotocol.util.Assert; import reactor.core.publisher.Mono; -public class MissingMcpTransportSession implements McpSession { +public class MissingMcpTransportSession implements McpLoggableSession { - public static final MissingMcpTransportSession INSTANCE = new MissingMcpTransportSession(); + private final String sessionId; + + private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO; + + public MissingMcpTransportSession(String sessionId) { + this.sessionId = sessionId; + } @Override public Mono sendRequest(String method, Object requestParams, TypeReference typeRef) { - return Mono.error(new IllegalStateException("Stream unavailable")); + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); } @Override public Mono sendNotification(String method, Object params) { - return Mono.error(new IllegalStateException("Stream unavailable")); + return Mono.error(new IllegalStateException("Stream unavailable for session " + this.sessionId)); } @Override @@ -26,4 +33,15 @@ public Mono closeGracefully() { public void close() { } + @Override + public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) { + Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null"); + this.minLoggingLevel = minLoggingLevel; + } + + @Override + public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) { + return loggingLevel.level() >= this.minLoggingLevel.level(); + } + }