diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java index ae302bbf3d84..1b0c3d4c8a38 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpRequest.java @@ -16,6 +16,7 @@ package org.springframework.http.client; +import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; @@ -23,17 +24,18 @@ import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; +import java.net.http.HttpTimeoutException; import java.nio.ByteBuffer; import java.time.Duration; import java.util.Collections; import java.util.Set; import java.util.TreeSet; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.lang.Nullable; @@ -92,28 +94,46 @@ public URI getURI() { @Override @SuppressWarnings("NullAway") protected ClientHttpResponse executeInternal(HttpHeaders headers, @Nullable Body body) throws IOException { + HttpRequest request = buildRequest(headers, body); + CompletableFuture> responsefuture = + this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()); try { - HttpRequest request = buildRequest(headers, body); - HttpResponse response; if (this.timeout != null) { - response = this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofInputStream()) - .get(this.timeout.toMillis(), TimeUnit.MILLISECONDS); - } - else { - response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofInputStream()); + CompletableFuture timeoutFuture = new CompletableFuture() + .completeOnTimeout(null, this.timeout.toMillis(), TimeUnit.MILLISECONDS); + timeoutFuture.thenRun(() -> { + if (!responsefuture.cancel(true) && !responsefuture.isCompletedExceptionally()) { + try { + responsefuture.resultNow().body().close(); + } catch (IOException ignored) {} + } + }); + var response = responsefuture.get(); + return new JdkClientHttpResponse(response.statusCode(), response.headers(), new FilterInputStream(response.body()) { + + @Override + public void close() throws IOException { + timeoutFuture.cancel(false); + super.close(); + } + }); + + } else { + var response = responsefuture.get(); + return new JdkClientHttpResponse(response.statusCode(), response.headers(), response.body()); } - return new JdkClientHttpResponse(response); - } - catch (UncheckedIOException ex) { - throw ex.getCause(); } catch (InterruptedException ex) { Thread.currentThread().interrupt(); + responsefuture.cancel(true); throw new IOException("Request was interrupted: " + ex.getMessage(), ex); } catch (ExecutionException ex) { Throwable cause = ex.getCause(); + if (cause instanceof CancellationException caEx) { + throw new HttpTimeoutException("Request timed out"); + } if (cause instanceof UncheckedIOException uioEx) { throw uioEx.getCause(); } @@ -127,17 +147,11 @@ else if (cause instanceof IOException ioEx) { throw new IOException(cause.getMessage(), cause); } } - catch (TimeoutException ex) { - throw new IOException("Request timed out: " + ex.getMessage(), ex); - } } private HttpRequest buildRequest(HttpHeaders headers, @Nullable Body body) { HttpRequest.Builder builder = HttpRequest.newBuilder().uri(this.uri); - if (this.timeout != null) { - builder.timeout(this.timeout); - } headers.forEach((headerName, headerValues) -> { if (!DISALLOWED_HEADERS.contains(headerName.toLowerCase())) { diff --git a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java index 17208e2e910b..a26b5531d9be 100644 --- a/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/client/JdkClientHttpResponse.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.io.InputStream; import java.net.http.HttpClient; -import java.net.http.HttpResponse; import java.util.List; import java.util.Locale; import java.util.Map; @@ -41,22 +40,21 @@ */ class JdkClientHttpResponse implements ClientHttpResponse { - private final HttpResponse response; + private final int statusCode; private final HttpHeaders headers; private final InputStream body; - public JdkClientHttpResponse(HttpResponse response) { - this.response = response; - this.headers = adaptHeaders(response); - InputStream inputStream = response.body(); - this.body = (inputStream != null ? inputStream : InputStream.nullInputStream()); + public JdkClientHttpResponse(int statusCode, java.net.http.HttpHeaders headers, InputStream body) { + this.statusCode = statusCode; + this.headers = adaptHeaders(headers); + this.body = body != null ? body : InputStream.nullInputStream(); } - private static HttpHeaders adaptHeaders(HttpResponse response) { - Map> rawHeaders = response.headers().map(); + private static HttpHeaders adaptHeaders(java.net.http.HttpHeaders headers) { + Map> rawHeaders = headers.map(); Map> map = new LinkedCaseInsensitiveMap<>(rawHeaders.size(), Locale.ENGLISH); MultiValueMap multiValueMap = CollectionUtils.toMultiValueMap(map); multiValueMap.putAll(rawHeaders); @@ -66,7 +64,7 @@ private static HttpHeaders adaptHeaders(HttpResponse response) { @Override public HttpStatusCode getStatusCode() { - return HttpStatusCode.valueOf(this.response.statusCode()); + return HttpStatusCode.valueOf(statusCode); } @Override