Skip to content

fix: resolve absolute and relative message endpoint uri #150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mcp/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand Down Expand Up @@ -69,7 +70,7 @@ public class HttpClientSseClientTransport implements McpClientTransport {
private static final String DEFAULT_SSE_ENDPOINT = "/sse";

/** Base URI for the MCP server */
private final String baseUri;
private final URI baseUri;

/** SSE endpoint path */
private final String sseEndpoint;
Expand Down Expand Up @@ -178,7 +179,7 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques
Assert.hasText(sseEndpoint, "sseEndpoint must not be empty");
Assert.notNull(httpClient, "httpClient must not be null");
Assert.notNull(requestBuilder, "requestBuilder must not be null");
this.baseUri = baseUri;
this.baseUri = URI.create(baseUri);
this.sseEndpoint = sseEndpoint;
this.objectMapper = objectMapper;
this.httpClient = httpClient;
Expand Down Expand Up @@ -340,7 +341,8 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
CompletableFuture<Void> future = new CompletableFuture<>();
connectionFuture.set(future);

sseClient.subscribe(this.baseUri + this.sseEndpoint, new FlowSseClient.SseEventHandler() {
URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint);
sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() {
@Override
public void onEvent(SseEvent event) {
if (isClosing) {
Expand Down Expand Up @@ -412,7 +414,8 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {

try {
String jsonText = this.objectMapper.writeValueAsString(message);
HttpRequest request = this.requestBuilder.uri(URI.create(this.baseUri + endpoint))
URI requestUri = Utils.resolveUri(baseUri, endpoint);
HttpRequest request = this.requestBuilder.uri(requestUri)
.POST(HttpRequest.BodyPublishers.ofString(jsonText))
.build();

Expand Down
56 changes: 54 additions & 2 deletions mcp/src/main/java/io/modelcontextprotocol/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

package io.modelcontextprotocol.util;

import reactor.util.annotation.Nullable;

import java.net.URI;
import java.util.Collection;
import java.util.Map;

import reactor.util.annotation.Nullable;

/**
* Miscellaneous utility methods.
*
Expand Down Expand Up @@ -52,4 +53,55 @@ public static boolean isEmpty(@Nullable Map<?, ?> map) {
return (map == null || map.isEmpty());
}

/**
* Resolves the given endpoint URL against the base URL.
* <ul>
* <li>If the endpoint URL is relative, it will be resolved against the base URL.</li>
* <li>If the endpoint URL is absolute, it will be validated to ensure it matches the
* base URL's scheme, authority, and path prefix.</li>
* <li>If validation fails for an absolute URL, an {@link IllegalArgumentException} is
* thrown.</li>
* </ul>
* @param baseUrl The base URL (must be absolute)
* @param endpointUrl The endpoint URL (can be relative or absolute)
* @return The resolved endpoint URI
* @throws IllegalArgumentException If the absolute endpoint URL does not match the
* base URL or URI is malformed
*/
public static URI resolveUri(URI baseUrl, String endpointUrl) {
URI endpointUri = URI.create(endpointUrl);
if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) {
throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL.");
}
else {
return baseUrl.resolve(endpointUri);
}
}

/**
* Checks if the given absolute endpoint URI falls under the base URI. It validates
* the scheme, authority (host and port), and ensures that the base path is a prefix
* of the endpoint path.
* @param baseUri The base URI
* @param endpointUri The endpoint URI to check
* @return true if endpointUri is within baseUri's hierarchy, false otherwise
*/
private static boolean isUnderBaseUri(URI baseUri, URI endpointUri) {
if (!baseUri.getScheme().equals(endpointUri.getScheme())
|| !baseUri.getAuthority().equals(endpointUri.getAuthority())) {
return false;
}

URI normalizedBase = baseUri.normalize();
URI normalizedEndpoint = endpointUri.normalize();

String basePath = normalizedBase.getPath();
String endpointPath = normalizedEndpoint.getPath();

if (basePath.endsWith("/")) {
basePath = basePath.substring(0, basePath.length() - 1);
}
return endpointPath.startsWith(basePath);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;

import io.modelcontextprotocol.spec.McpSchema;
Expand All @@ -21,6 +22,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.containers.wait.strategy.Wait;
import reactor.core.publisher.Mono;
Expand All @@ -31,6 +34,9 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatCode;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.ObjectMapper;

Expand Down Expand Up @@ -364,4 +370,23 @@ void testChainedCustomizations() {
customizedTransport.closeGracefully().block();
}

@Test
@SuppressWarnings("unchecked")
void testResolvingClientEndpoint() {
HttpClient httpClient = Mockito.mock(HttpClient.class);
HttpResponse<Void> httpResponse = Mockito.mock(HttpResponse.class);
CompletableFuture<HttpResponse<Void>> future = new CompletableFuture<>();
future.complete(httpResponse);
when(httpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))).thenReturn(future);

HttpClientSseClientTransport transport = new HttpClientSseClientTransport(httpClient, HttpRequest.newBuilder(),
"http://example.com", "http://example.com/sse", new ObjectMapper());

transport.connect(Function.identity());

ArgumentCaptor<HttpRequest> httpRequestCaptor = ArgumentCaptor.forClass(HttpRequest.class);
verify(httpClient).sendAsync(httpRequestCaptor.capture(), any(HttpResponse.BodyHandler.class));
assertThat(httpRequestCaptor.getValue().uri()).isEqualTo(URI.create("http://example.com/sse"));
}

}
29 changes: 29 additions & 0 deletions mcp/src/test/java/io/modelcontextprotocol/util/UtilsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

import org.junit.jupiter.api.Test;

import java.net.URI;
import java.util.Collection;
import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;

class UtilsTests {

Expand All @@ -37,4 +42,28 @@ void testMapIsEmpty() {
assertFalse(Utils.isEmpty(Map.of("key", "value")));
}

@ParameterizedTest
@CsvSource({
// relative endpoints
"http://localhost:8080/root, /api/v1, http://localhost:8080/api/v1",
"http://localhost:8080/root/, api, http://localhost:8080/root/api",
"http://localhost:8080, /api, http://localhost:8080/api",
// absolute endpoints matching base
"http://localhost:8080/root, http://localhost:8080/root/api/v1, http://localhost:8080/root/api/v1",
"http://localhost:8080/root, http://localhost:8080/root, http://localhost:8080/root" })
void testValidUriResolution(String baseUrl, String endpoint, String expectedResult) {
URI result = Utils.resolveUri(URI.create(baseUrl), endpoint);
assertThat(result.toString()).isEqualTo(expectedResult);
}

@ParameterizedTest
@CsvSource({ "http://localhost:8080/root, http://localhost:8080/other/api",
"http://localhost:8080/root, http://otherhost/api",
"http://localhost:8080/root, http://localhost:9090/root/api" })
void testAbsoluteUriNotMatchingBase(String baseUrl, String endpoint) {
assertThatThrownBy(() -> Utils.resolveUri(URI.create(baseUrl), endpoint))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("does not match the base URL");
}

}