From 2d2ad750f62d2481cb5867016a065a852df2d954 Mon Sep 17 00:00:00 2001 From: E550448 Date: Sun, 13 Apr 2025 12:54:48 +0200 Subject: [PATCH 1/2] fix: resolve absolute and relative message endpoint uri fixes java mcp message endpoint error #103 --- mcp/pom.xml | 6 ++ .../HttpClientSseClientTransport.java | 4 +- .../io/modelcontextprotocol/util/Utils.java | 66 +++++++++++++++++++ .../HttpClientSseClientTransportTests.java | 27 +++++++- .../modelcontextprotocol/util/UtilsTests.java | 34 ++++++++++ 5 files changed, 135 insertions(+), 2 deletions(-) diff --git a/mcp/pom.xml b/mcp/pom.xml index edb1c8f0..c55cd465 100644 --- a/mcp/pom.xml +++ b/mcp/pom.xml @@ -126,6 +126,12 @@ ${junit.version} test + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.mockito mockito-core diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 632d3844..8c77fc84 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -24,6 +24,7 @@ import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; @@ -340,7 +341,8 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() { + String clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + sseClient.subscribe(clientUri, new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (isClosing) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 0f799ca0..cac6b74c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,6 +4,8 @@ package io.modelcontextprotocol.util; +import java.net.URI; +import java.net.URISyntaxException; import java.util.Collection; import java.util.Map; @@ -52,4 +54,68 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } + /** + * Resolves the given endpoint URL against the base URL. + *
    + *
  • If the endpoint URL is relative, it will be resolved against the base URL.
  • + *
  • If the endpoint URL is absolute, it will be validated to ensure it matches the + * base URL's scheme, authority, and path prefix.
  • + *
  • If validation fails for an absolute URL, an {@link IllegalArgumentException} is + * thrown.
  • + *
+ * @param baseUrl The base URL (must be absolute) + * @param endpointUrl The endpoint URL (can be relative or absolute) + * @return The resolved endpoint URL as a string + * @throws IllegalArgumentException If the absolute endpoint URL does not match the + * base URL or URI is malformed + */ + public static String resolveUri(String baseUrl, String endpointUrl) { + try { + URI baseUri = new URI(baseUrl); + URI endpointUri = new URI(endpointUrl); + if (!endpointUri.isAbsolute()) { + URI resolvedUri = baseUri.resolve(endpointUri); + return resolvedUri.toString(); + } + else { + if (isUnderBaseUri(baseUri, endpointUri)) { + return endpointUri.toString(); + } + else { + throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); + } + } + } + catch (URISyntaxException e) { + throw new IllegalArgumentException("Cannot resolve URI: " + e.getMessage(), e); + } + } + + /** + * Checks if the given absolute endpoint URI falls under the base URI. It validates + * the scheme, authority (host and port), and ensures that the base path is a prefix + * of the endpoint path. + * @param baseUri The base URI + * @param endpointUri The endpoint URI to check + * @return true if endpointUri is within baseUri's hierarchy, false otherwise + */ + private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) { + if (!baseUri.getScheme().equals(endpointUri.getScheme()) + || !baseUri.getAuthority().equals(endpointUri.getAuthority())) { + return false; + } + + URI normalizedBase = baseUri.normalize(); + URI normalizedEndpoint = endpointUri.normalize(); + + String basePath = normalizedBase.getPath(); + String endpointPath = normalizedEndpoint.getPath(); + + if (basePath.endsWith("/")) { + basePath = basePath.substring(0, basePath.length() - 1); + } + + return endpointPath.startsWith(basePath); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java index e5178c0e..a75f7675 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java @@ -7,12 +7,13 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; +import java.net.http.HttpResponse; import java.time.Duration; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import java.util.function.Function; import io.modelcontextprotocol.spec.McpSchema; @@ -21,6 +22,8 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import org.testcontainers.containers.GenericContainer; import org.testcontainers.containers.wait.strategy.Wait; import reactor.core.publisher.Mono; @@ -31,6 +34,9 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; @@ -364,4 +370,23 @@ void testChainedCustomizations() { customizedTransport.closeGracefully().block(); } + @Test + @SuppressWarnings("unchecked") + void testResolvingClientEndpoint() { + HttpClient httpClient = Mockito.mock(HttpClient.class); + HttpResponse httpResponse = Mockito.mock(HttpResponse.class); + CompletableFuture> future = new CompletableFuture<>(); + future.complete(httpResponse); + when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future); + + HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(), + "http://example.com", "http://example.com/sse", new ObjectMapper()); + + transport.connect(Function.identity()); + + ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class)); + assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse")); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java index aced20cb..0e8bf074 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java @@ -10,8 +10,12 @@ import java.util.List; import java.util.Map; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; class UtilsTests { @@ -37,4 +41,34 @@ void testMapIsEmpty() { assertFalse(Utils.isEmpty(Map.of("key", "value"))); } + @ParameterizedTest + @CsvSource({ + // relative endpoints + "http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1", + "http://localhost:8080/root/, api, http://localhost:8080/root/api", + "http://localhost:8080, /api, http://localhost:8080/api", + // absolute endpoints matching base + "http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1", + "http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" }) + void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) { + String result = Utils.resolveUri(baseUrl, endpoint); + assertThat(result).isEqualTo(expectedResult); + } + + @ParameterizedTest + @CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api", + "http://localhost:8080/root, http://otherhost/api", + "http://localhost:8080/root, http://localhost:9090/root/api" }) + void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) { + assertThatThrownBy(() -> Utils.resolveUri(baseUrl, endpoint)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("does not match the base URL"); + } + + @ParameterizedTest + @CsvSource({ "http://localhost:8080/<>root", "http://localhost:8080/ root", "http://localhost:8080/root}" }) + void testInvalidUri(String baseUrl) { + assertThatThrownBy(() -> Utils.resolveUri(baseUrl, "")).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Cannot resolve URI"); + } + } \ No newline at end of file From 11fbb11d94e9831c4d3579ef8ea13f72da73eec3 Mon Sep 17 00:00:00 2001 From: ashakirin Date: Thu, 17 Apr 2025 14:53:27 +0200 Subject: [PATCH 2/2] fix: improvements for absolute and relative URI resolving: #103 --- .../HttpClientSseClientTransport.java | 11 ++++--- .../io/modelcontextprotocol/util/Utils.java | 32 ++++++------------- .../modelcontextprotocol/util/UtilsTests.java | 15 +++------ 3 files changed, 20 insertions(+), 38 deletions(-) diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 8c77fc84..99cf2a62 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -70,7 +70,7 @@ public class HttpClientSseClientTransport implements McpClientTransport { private static final String DEFAULT_SSE_ENDPOINT = "/sse"; /** Base URI for the MCP server */ - private final String baseUri; + private final URI baseUri; /** SSE endpoint path */ private final String sseEndpoint; @@ -179,7 +179,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); Assert.notNull(httpClient, "httpClient must not be null"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); - this.baseUri = baseUri; + this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; this.objectMapper = objectMapper; this.httpClient = httpClient; @@ -341,8 +341,8 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - String clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); - sseClient.subscribe(clientUri, new FlowSseClient.SseEventHandler() { + URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { if (isClosing) { @@ -414,7 +414,8 @@ public Mono sendMessage(JSONRPCMessage message) { try { String jsonText = this.objectMapper.writeValueAsString(message); - HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint)) + URI requestUri = Utils.resolveUri(baseUri, endpoint); + HttpRequest request = this.requestBuilder.uri(requestUri) .POST(HttpRequest.BodyPublishers.ofString(jsonText)) .build(); diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index cac6b74c..8e654e59 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,13 +4,12 @@ package io.modelcontextprotocol.util; +import reactor.util.annotation.Nullable; + import java.net.URI; -import java.net.URISyntaxException; import java.util.Collection; import java.util.Map; -import reactor.util.annotation.Nullable; - /** * Miscellaneous utility methods. * @@ -65,29 +64,17 @@ public static boolean isEmpty(@Nullable Map map) { * * @param baseUrl The base URL (must be absolute) * @param endpointUrl The endpoint URL (can be relative or absolute) - * @return The resolved endpoint URL as a string + * @return The resolved endpoint URI * @throws IllegalArgumentException If the absolute endpoint URL does not match the * base URL or URI is malformed */ - public static String resolveUri(String baseUrl, String endpointUrl) { - try { - URI baseUri = new URI(baseUrl); - URI endpointUri = new URI(endpointUrl); - if (!endpointUri.isAbsolute()) { - URI resolvedUri = baseUri.resolve(endpointUri); - return resolvedUri.toString(); - } - else { - if (isUnderBaseUri(baseUri, endpointUri)) { - return endpointUri.toString(); - } - else { - throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); - } - } + public static URI resolveUri(URI baseUrl, String endpointUrl) { + URI endpointUri = URI.create(endpointUrl); + if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { + throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); } - catch (URISyntaxException e) { - throw new IllegalArgumentException("Cannot resolve URI: " + e.getMessage(), e); + else { + return baseUrl.resolve(endpointUri); } } @@ -114,7 +101,6 @@ private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) { if (basePath.endsWith("/")) { basePath = basePath.substring(0, basePath.length() - 1); } - return endpointPath.startsWith(basePath); } diff --git a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java index 0e8bf074..0f2e689b 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java @@ -6,6 +6,7 @@ import org.junit.jupiter.api.Test; +import java.net.URI; import java.util.Collection; import java.util.List; import java.util.Map; @@ -51,8 +52,8 @@ void testMapIsEmpty() { "http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1", "http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" }) void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) { - String result = Utils.resolveUri(baseUrl, endpoint); - assertThat(result).isEqualTo(expectedResult); + URI result = Utils.resolveUri(URI.create(baseUrl), endpoint); + assertThat(result.toString()).isEqualTo(expectedResult); } @ParameterizedTest @@ -60,15 +61,9 @@ void testValidUriResolution(String baseUrl, String endpoint, String expectedResu "http://localhost:8080/root, http://otherhost/api", "http://localhost:8080/root, http://localhost:9090/root/api" }) void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) { - assertThatThrownBy(() -> Utils.resolveUri(baseUrl, endpoint)).isInstanceOf(IllegalArgumentException.class) + assertThatThrownBy(() -> Utils.resolveUri(URI.create(baseUrl), endpoint)) + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("does not match the base URL"); } - @ParameterizedTest - @CsvSource({ "http://localhost:8080/<>root", "http://localhost:8080/ root", "http://localhost:8080/root}" }) - void testInvalidUri(String baseUrl) { - assertThatThrownBy(() -> Utils.resolveUri(baseUrl, "")).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Cannot resolve URI"); - } - } \ No newline at end of file