Skip to content

Emit WebClientResponseException for malformed HTTP response #27262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@
import org.springframework.http.client.reactive.ClientHttpResponse;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.lang.Nullable;
import org.springframework.util.MimeType;
import org.springframework.util.MultiValueMap;
import org.springframework.web.reactive.function.BodyExtractor;
import org.springframework.web.reactive.function.BodyExtractors;


/**
* Default implementation of {@link ClientResponse}.
*
Expand Down Expand Up @@ -130,13 +132,21 @@ public MultiValueMap<String, ResponseCookie> cookies() {
public <T> T body(BodyExtractor<T, ? super ClientHttpResponse> extractor) {
T result = extractor.extract(this.response, this.bodyExtractorContext);
String description = "Body from " + this.requestDescription + " [DefaultClientResponse]";

if (result instanceof Mono) {
return (T) ((Mono<?>) result).checkpoint(description);
return (T) ((Mono<?>) result)
.checkpoint(description)
.onErrorMap(WebClientUtils.WRAP_EXCEPTION_PREDICATE, t -> toException(t, null, null));
}

else if (result instanceof Flux) {
return (T) ((Flux<?>) result).checkpoint(description);
return (T) ((Flux<?>) result)
.checkpoint(description)
.onErrorMap(WebClientUtils.WRAP_EXCEPTION_PREDICATE, t -> toException(t, null, null));
}

else {
// XXX: is there a way to preemptively handle uncaught exceptions here?
return result;
}
}
Expand Down Expand Up @@ -205,26 +215,21 @@ public Mono<WebClientResponseException> createException() {
.defaultIfEmpty(EMPTY)
.onErrorReturn(IllegalStateException.class::isInstance, EMPTY)
.map(bodyBytes -> {
HttpRequest request = this.requestSupplier.get();
Charset charset = headers().contentType().map(MimeType::getCharset).orElse(null);
int statusCode = rawStatusCode();
HttpStatus httpStatus = HttpStatus.resolve(statusCode);
Charset charset = headers().contentType().map(MimeType::getCharset).orElse(null);

if (httpStatus != null) {
return WebClientResponseException.create(
statusCode,
httpStatus.getReasonPhrase(),
headers().asHttpHeaders(),
bodyBytes,
charset,
request);
return toException(null, bodyBytes, charset);
}
else {
return new UnknownHttpStatusCodeException(
statusCode,
headers().asHttpHeaders(),
bodyBytes,
charset,
request);
this.requestSupplier.get()
);
}
});
}
Expand All @@ -239,6 +244,23 @@ HttpRequest request() {
return this.requestSupplier.get();
}

private WebClientResponseException toException(
@Nullable Throwable cause, @Nullable byte[] bodyBytes, @Nullable Charset charset) {

WebClientResponseException ex = new WebClientResponseException(
this.response.getRawStatusCode(),
this.response.getStatusCode().getReasonPhrase(),
headers().asHttpHeaders(),
bodyBytes,
charset,
this.requestSupplier.get()
);

ex.initCause(cause);

return ex;
}

private class DefaultHeaders implements Headers {

private final HttpHeaders httpHeaders =
Expand Down Expand Up @@ -269,5 +291,4 @@ private OptionalLong toOptionalLong(long value) {
return (value != -1 ? OptionalLong.of(value) : OptionalLong.empty());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
Expand All @@ -31,15 +35,18 @@
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import okhttp3.mockwebserver.SocketPolicy;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -64,7 +71,9 @@
import org.springframework.http.client.reactive.HttpComponentsClientHttpConnector;
import org.springframework.http.client.reactive.JettyClientHttpConnector;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.util.SocketUtils;
import org.springframework.web.reactive.function.BodyExtractors;
import org.springframework.web.reactive.function.client.WebClient.ResponseSpec;
import org.springframework.web.testfixture.xml.Pojo;

import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -83,7 +92,7 @@ class WebClientIntegrationTests {

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@ParameterizedTest(name = "[{index}] webClient [{0}]")
@ParameterizedTest(name = "[{index}] {displayName} [{0}]")
@MethodSource("arguments")
@interface ParameterizedWebClientTest {
}
Expand Down Expand Up @@ -113,7 +122,9 @@ private void startServer(ClientHttpConnector connector) {

@AfterEach
void shutdown() throws IOException {
this.server.shutdown();
if (server != null) {
this.server.shutdown();
}
}


Expand Down Expand Up @@ -1209,6 +1220,136 @@ void invalidDomain(ClientHttpConnector connector) {
.verify();
}

static Stream<Arguments> socketFaultArguments() {
Stream.Builder<Arguments> argumentsBuilder = Stream.builder();
arguments().forEach(arg -> {
argumentsBuilder.accept(Arguments.of(arg, SocketPolicy.DISCONNECT_AT_START));
argumentsBuilder.accept(Arguments.of(arg, SocketPolicy.DISCONNECT_DURING_REQUEST_BODY));
argumentsBuilder.accept(Arguments.of(arg, SocketPolicy.DISCONNECT_AFTER_REQUEST));
});
return argumentsBuilder.build();
}

@ParameterizedTest(name = "[{index}] {displayName} [{0}, {1}]")
@MethodSource("socketFaultArguments")
void prematureClosureFault(ClientHttpConnector connector, SocketPolicy socketPolicy) {
startServer(connector);

prepareResponse(response -> response
.setSocketPolicy(socketPolicy)
.setStatus("HTTP/1.1 200 OK")
.setHeader("Response-Header-1", "value 1")
.setHeader("Response-Header-2", "value 2")
.setBody("{\"message\": \"Hello, World!\"}"));

String uri = "/test";
Mono<String> result = this.webClient
.post()
.uri(uri)
// Random non-empty body to allow us to interrupt.
.bodyValue("{\"action\": \"Say hello!\"}")
.retrieve()
.bodyToMono(String.class);

StepVerifier.create(result)
.expectErrorSatisfies(throwable -> {
assertThat(throwable).isInstanceOf(WebClientRequestException.class);
WebClientRequestException ex = (WebClientRequestException) throwable;
// Varies between connector providers.
assertThat(ex.getCause()).isInstanceOf(IOException.class);
})
.verify();
}

static Stream<Arguments> malformedResponseChunkArguments() {
return Stream.of(
Arguments.of(new ReactorClientHttpConnector(), true),
Arguments.of(new JettyClientHttpConnector(), true),
// Apache injects the Transfer-Encoding header for us, and complains with an exception if we also
// add it. The other two connectors do not add the header at all. We need this header for the test
// case to work correctly.
Arguments.of(new HttpComponentsClientHttpConnector(), false)
);
}

@ParameterizedTest(name = "[{index}] {displayName} [{0}, {1}]")
@MethodSource("malformedResponseChunkArguments")
void malformedResponseChunksOnBodilessEntity(ClientHttpConnector connector, boolean addTransferEncodingHeader) {
Mono<?> result = doMalformedResponseChunks(connector, addTransferEncodingHeader, ResponseSpec::toBodilessEntity);

StepVerifier.create(result)
.expectErrorSatisfies(throwable -> {
assertThat(throwable).isInstanceOf(WebClientException.class);
WebClientException ex = (WebClientException) throwable;
assertThat(ex.getCause()).isInstanceOf(IOException.class);
})
.verify();
}

@ParameterizedTest(name = "[{index}] {displayName} [{0}, {1}]")
@MethodSource("malformedResponseChunkArguments")
void malformedResponseChunksOnEntityWithBody(ClientHttpConnector connector, boolean addTransferEncodingHeader) {
Mono<?> result = doMalformedResponseChunks(connector, addTransferEncodingHeader, spec -> spec.toEntity(String.class));

StepVerifier.create(result)
.expectErrorSatisfies(throwable -> {
assertThat(throwable).isInstanceOf(WebClientException.class);
WebClientException ex = (WebClientException) throwable;
assertThat(ex.getCause()).isInstanceOf(IOException.class);
})
.verify();
}

private <T> Mono<T> doMalformedResponseChunks(
ClientHttpConnector connector,
boolean addTransferEncodingHeader,
Function<ResponseSpec, Mono<T>> responseHandler
) {
int port = SocketUtils.findAvailableTcpPort();

Thread serverThread = new Thread(() -> {
// This exists separately to the main mock server, as I had a really hard time getting that to send the
// chunked responses correctly, flushing the socket each time. This was the only way I was able to replicate
// the issue of the client not handling malformed response chunks correctly.
try (ServerSocket serverSocket = new ServerSocket(port)) {
Socket socket = serverSocket.accept();
InputStream is = socket.getInputStream();

//noinspection ResultOfMethodCallIgnored
is.read(new byte[4096]);

OutputStream os = socket.getOutputStream();
os.write("HTTP/1.1 200 OK\r\n".getBytes(StandardCharsets.UTF_8));
os.write("Transfer-Encoding: chunked\r\n".getBytes(StandardCharsets.UTF_8));
os.write("\r\n".getBytes(StandardCharsets.UTF_8));
os.write("lskdu018973t09sylgasjkfg1][]'./.sdlv".getBytes(StandardCharsets.UTF_8));
socket.close();
}
catch (IOException ex) {
throw new RuntimeException(ex);
}
});

serverThread.setDaemon(true);
serverThread.start();

ResponseSpec spec = WebClient
.builder()
.clientConnector(connector)
.baseUrl("http://localhost:" + port)
.build()
.post()
.headers(headers -> {
if (addTransferEncodingHeader) {
headers.add(HttpHeaders.TRANSFER_ENCODING, "chunked");
}
})
.retrieve();

return responseHandler
.apply(spec)
.doFinally(signal -> serverThread.stop());
}

private void prepareResponse(Consumer<MockResponse> consumer) {
MockResponse response = new MockResponse();
Expand Down Expand Up @@ -1252,5 +1393,4 @@ public void setContainerValue(T containerValue) {
this.containerValue = containerValue;
}
}

}