Skip to content

Commit 2d2ad75

Browse files
committed
fix: resolve absolute and relative message endpoint uri
fixes java mcp message endpoint error modelcontextprotocol#103
1 parent fab434c commit 2d2ad75

File tree

5 files changed

+135
-2
lines changed

5 files changed

+135
-2
lines changed

mcp/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,12 @@
126126
<version>${junit.version}</version>
127127
<scope>test</scope>
128128
</dependency>
129+
<dependency>
130+
<groupId>org.junit.jupiter</groupId>
131+
<artifactId>junit-jupiter-params</artifactId>
132+
<version>${junit.version}</version>
133+
<scope>test</scope>
134+
</dependency>
129135
<dependency>
130136
<groupId>org.mockito</groupId>
131137
<artifactId>mockito-core</artifactId>

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import io.modelcontextprotocol.spec.McpSchema;
2525
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
2626
import io.modelcontextprotocol.util.Assert;
27+
import io.modelcontextprotocol.util.Utils;
2728
import org.slf4j.Logger;
2829
import org.slf4j.LoggerFactory;
2930
import reactor.core.publisher.Mono;
@@ -340,7 +341,8 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
340341
CompletableFuture<Void> future = new CompletableFuture<>();
341342
connectionFuture.set(future);
342343

343-
sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() {
344+
String clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
345+
sseClient.subscribe(clientUri, new FlowSseClient.SseEventHandler() {
344346
@Override
345347
public void onEvent(SseEvent event) {
346348
if (isClosing) {

mcp/src/main/java/io/modelcontextprotocol/util/Utils.java

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
package io.modelcontextprotocol.util;
66

7+
import java.net.URI;
8+
import java.net.URISyntaxException;
79
import java.util.Collection;
810
import java.util.Map;
911

@@ -52,4 +54,68 @@ public static boolean isEmpty(@Nullable Map<?, ?> map) {
5254
return (map == null || map.isEmpty());
5355
}
5456

57+
/**
58+
* Resolves the given endpoint URL against the base URL.
59+
* <ul>
60+
* <li>If the endpoint URL is relative, it will be resolved against the base URL.</li>
61+
* <li>If the endpoint URL is absolute, it will be validated to ensure it matches the
62+
* base URL's scheme, authority, and path prefix.</li>
63+
* <li>If validation fails for an absolute URL, an {@link IllegalArgumentException} is
64+
* thrown.</li>
65+
* </ul>
66+
* @param baseUrl The base URL (must be absolute)
67+
* @param endpointUrl The endpoint URL (can be relative or absolute)
68+
* @return The resolved endpoint URL as a string
69+
* @throws IllegalArgumentException If the absolute endpoint URL does not match the
70+
* base URL or URI is malformed
71+
*/
72+
public static String resolveUri(String baseUrl, String endpointUrl) {
73+
try {
74+
URI baseUri = new URI(baseUrl);
75+
URI endpointUri = new URI(endpointUrl);
76+
if (!endpointUri.isAbsolute()) {
77+
URI resolvedUri = baseUri.resolve(endpointUri);
78+
return resolvedUri.toString();
79+
}
80+
else {
81+
if (isUnderBaseUri(baseUri, endpointUri)) {
82+
return endpointUri.toString();
83+
}
84+
else {
85+
throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL.");
86+
}
87+
}
88+
}
89+
catch (URISyntaxException e) {
90+
throw new IllegalArgumentException("Cannot resolve URI: " + e.getMessage(), e);
91+
}
92+
}
93+
94+
/**
95+
* Checks if the given absolute endpoint URI falls under the base URI. It validates
96+
* the scheme, authority (host and port), and ensures that the base path is a prefix
97+
* of the endpoint path.
98+
* @param baseUri The base URI
99+
* @param endpointUri The endpoint URI to check
100+
* @return true if endpointUri is within baseUri's hierarchy, false otherwise
101+
*/
102+
private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) {
103+
if (!baseUri.getScheme().equals(endpointUri.getScheme())
104+
|| !baseUri.getAuthority().equals(endpointUri.getAuthority())) {
105+
return false;
106+
}
107+
108+
URI normalizedBase = baseUri.normalize();
109+
URI normalizedEndpoint = endpointUri.normalize();
110+
111+
String basePath = normalizedBase.getPath();
112+
String endpointPath = normalizedEndpoint.getPath();
113+
114+
if (basePath.endsWith("/")) {
115+
basePath = basePath.substring(0, basePath.length() - 1);
116+
}
117+
118+
return endpointPath.startsWith(basePath);
119+
}
120+
55121
}

mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import java.net.URI;
88
import java.net.http.HttpClient;
99
import java.net.http.HttpRequest;
10+
import java.net.http.HttpResponse;
1011
import java.time.Duration;
1112
import java.util.Map;
13+
import java.util.concurrent.CompletableFuture;
1214
import java.util.concurrent.atomic.AtomicBoolean;
1315
import java.util.concurrent.atomic.AtomicInteger;
1416
import java.util.concurrent.atomic.AtomicReference;
15-
import java.util.function.Consumer;
1617
import java.util.function.Function;
1718

1819
import io.modelcontextprotocol.spec.McpSchema;
@@ -21,6 +22,8 @@
2122
import org.junit.jupiter.api.BeforeEach;
2223
import org.junit.jupiter.api.Test;
2324
import org.junit.jupiter.api.Timeout;
25+
import org.mockito.ArgumentCaptor;
26+
import org.mockito.Mockito;
2427
import org.testcontainers.containers.GenericContainer;
2528
import org.testcontainers.containers.wait.strategy.Wait;
2629
import reactor.core.publisher.Mono;
@@ -31,6 +34,9 @@
3134

3235
import static org.assertj.core.api.Assertions.assertThat;
3336
import static org.assertj.core.api.Assertions.assertThatCode;
37+
import static org.mockito.ArgumentMatchers.any;
38+
import static org.mockito.Mockito.verify;
39+
import static org.mockito.Mockito.when;
3440

3541
import com.fasterxml.jackson.databind.ObjectMapper;
3642

@@ -364,4 +370,23 @@ void testChainedCustomizations() {
364370
customizedTransport.closeGracefully().block();
365371
}
366372

373+
@Test
374+
@SuppressWarnings("unchecked")
375+
void testResolvingClientEndpoint() {
376+
HttpClient httpClient = Mockito.mock(HttpClient.class);
377+
HttpResponse<Void> httpResponse = Mockito.mock(HttpResponse.class);
378+
CompletableFuture<HttpResponse<Void>> future = new CompletableFuture<>();
379+
future.complete(httpResponse);
380+
when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future);
381+
382+
HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(),
383+
"http://example.com", "http://example.com/sse", new ObjectMapper());
384+
385+
transport.connect(Function.identity());
386+
387+
ArgumentCaptor<HttpRequest> httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
388+
verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class));
389+
assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse"));
390+
}
391+
367392
}

mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
import java.util.List;
1111
import java.util.Map;
1212

13+
import static org.assertj.core.api.Assertions.assertThat;
14+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
1315
import static org.junit.jupiter.api.Assertions.assertFalse;
1416
import static org.junit.jupiter.api.Assertions.assertTrue;
17+
import org.junit.jupiter.params.ParameterizedTest;
18+
import org.junit.jupiter.params.provider.CsvSource;
1519

1620
class UtilsTests {
1721

@@ -37,4 +41,34 @@ void testMapIsEmpty() {
3741
assertFalse(Utils.isEmpty(Map.of("key", "value")));
3842
}
3943

44+
@ParameterizedTest
45+
@CsvSource({
46+
// relative endpoints
47+
"http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1",
48+
"http://localhost:8080/root/, api, http://localhost:8080/root/api",
49+
"http://localhost:8080, /api, http://localhost:8080/api",
50+
// absolute endpoints matching base
51+
"http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1",
52+
"http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" })
53+
void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) {
54+
String result = Utils.resolveUri(baseUrl, endpoint);
55+
assertThat(result).isEqualTo(expectedResult);
56+
}
57+
58+
@ParameterizedTest
59+
@CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api",
60+
"http://localhost:8080/root, http://otherhost/api",
61+
"http://localhost:8080/root, http://localhost:9090/root/api" })
62+
void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) {
63+
assertThatThrownBy(() -> Utils.resolveUri(baseUrl, endpoint)).isInstanceOf(IllegalArgumentException.class)
64+
.hasMessageContaining("does not match the base URL");
65+
}
66+
67+
@ParameterizedTest
68+
@CsvSource({ "http://localhost:8080/<>root", "http://localhost:8080/ root", "http://localhost:8080/root}" })
69+
void testInvalidUri(String baseUrl) {
70+
assertThatThrownBy(() -> Utils.resolveUri(baseUrl, "")).isInstanceOf(IllegalArgumentException.class)
71+
.hasMessageContaining("Cannot resolve URI");
72+
}
73+
4074
}

0 commit comments

Comments
 (0)