Skip to content

Commit f73a522

Browse files
committed
Ensure client response is drained with onStatus hook
Issue: SPR-17473
1 parent 8a2262e commit f73a522

File tree

6 files changed

+216
-25
lines changed

6 files changed

+216
-25
lines changed

spring-core/src/test/java/org/springframework/core/io/buffer/AbstractDataBufferAllocatingTestCase.java

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.springframework.core.io.buffer;
1818

1919
import java.nio.charset.StandardCharsets;
20+
import java.time.Duration;
21+
import java.time.Instant;
2022
import java.util.Arrays;
2123
import java.util.List;
2224
import java.util.function.Consumer;
@@ -36,7 +38,11 @@
3638
import static org.junit.Assert.*;
3739

3840
/**
41+
* Base class for tests that read or write data buffers with a rule to check
42+
* that allocated buffers have been released.
43+
*
3944
* @author Arjen Poutsma
45+
* @author Rossen Stoyanchev
4046
*/
4147
@RunWith(Parameterized.class)
4248
public abstract class AbstractDataBufferAllocatingTestCase {
@@ -61,6 +67,7 @@ public static Object[][] dataBufferFactories() {
6167
@Rule
6268
public final Verifier leakDetector = new LeakDetector();
6369

70+
6471
protected DataBuffer createDataBuffer(int capacity) {
6572
return this.bufferFactory.allocateBuffer(capacity);
6673
}
@@ -85,30 +92,45 @@ protected Consumer<DataBuffer> stringConsumer(String expected) {
8592
};
8693
}
8794

88-
89-
private class LeakDetector extends Verifier {
90-
91-
@Override
92-
protected void verify() throws Throwable {
93-
if (bufferFactory instanceof NettyDataBufferFactory) {
94-
ByteBufAllocator byteBufAllocator =
95-
((NettyDataBufferFactory) bufferFactory).getByteBufAllocator();
96-
if (byteBufAllocator instanceof PooledByteBufAllocator) {
97-
PooledByteBufAllocator pooledByteBufAllocator =
98-
(PooledByteBufAllocator) byteBufAllocator;
99-
PooledByteBufAllocatorMetric metric = pooledByteBufAllocator.metric();
100-
long allocations = calculateAllocations(metric.directArenas()) +
101-
calculateAllocations(metric.heapArenas());
102-
assertTrue("ByteBuf leak detected: " + allocations +
103-
" allocations were not released", allocations == 0);
104-
}
95+
/**
96+
* Wait until allocations are at 0, or the given duration elapses.
97+
*/
98+
protected void waitForDataBufferRelease(Duration duration) throws InterruptedException {
99+
Instant start = Instant.now();
100+
while (Instant.now().isBefore(start.plus(duration))) {
101+
try {
102+
verifyAllocations();
103+
break;
104+
}
105+
catch (AssertionError ex) {
106+
// ignore;
105107
}
108+
Thread.sleep(50);
106109
}
110+
}
107111

108-
private long calculateAllocations(List<PoolArenaMetric> metrics) {
109-
return metrics.stream().mapToLong(PoolArenaMetric::numActiveAllocations).sum();
112+
private void verifyAllocations() {
113+
if (this.bufferFactory instanceof NettyDataBufferFactory) {
114+
ByteBufAllocator allocator = ((NettyDataBufferFactory) this.bufferFactory).getByteBufAllocator();
115+
if (allocator instanceof PooledByteBufAllocator) {
116+
PooledByteBufAllocatorMetric metric = ((PooledByteBufAllocator) allocator).metric();
117+
long total = getAllocations(metric.directArenas()) + getAllocations(metric.heapArenas());
118+
assertEquals("ByteBuf Leak: " + total + " unreleased allocations", 0, total);
119+
}
110120
}
121+
}
122+
123+
private static long getAllocations(List<PoolArenaMetric> metrics) {
124+
return metrics.stream().mapToLong(PoolArenaMetric::numActiveAllocations).sum();
125+
}
126+
111127

128+
protected class LeakDetector extends Verifier {
129+
130+
@Override
131+
public void verify() {
132+
AbstractDataBufferAllocatingTestCase.this.verifyAllocations();
133+
}
112134
}
113135

114136
}

spring-web/src/main/java/org/springframework/http/client/reactive/ReactorClientHttpResponse.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.http.client.reactive;
1818

1919
import java.util.Collection;
20+
import java.util.concurrent.atomic.AtomicBoolean;
2021

2122
import reactor.core.publisher.Flux;
2223
import reactor.ipc.netty.http.client.HttpClientResponse;
@@ -26,6 +27,7 @@
2627
import org.springframework.http.HttpHeaders;
2728
import org.springframework.http.HttpStatus;
2829
import org.springframework.http.ResponseCookie;
30+
import org.springframework.util.Assert;
2931
import org.springframework.util.CollectionUtils;
3032
import org.springframework.util.LinkedMultiValueMap;
3133
import org.springframework.util.MultiValueMap;
@@ -43,6 +45,8 @@ class ReactorClientHttpResponse implements ClientHttpResponse {
4345

4446
private final HttpClientResponse response;
4547

48+
private final AtomicBoolean bodyConsumed = new AtomicBoolean();
49+
4650

4751
public ReactorClientHttpResponse(HttpClientResponse response) {
4852
this.response = response;
@@ -53,6 +57,13 @@ public ReactorClientHttpResponse(HttpClientResponse response) {
5357
@Override
5458
public Flux<DataBuffer> getBody() {
5559
return response.receive()
60+
.doOnSubscribe(s ->
61+
// WebClient's onStatus handling tries to drain the body, which may
62+
// have also been done by application code in the onStatus callback.
63+
// That relies on the 2nd subscriber being rejected but FluxReceive
64+
// isn't consistent in doing so and may hang without completion.
65+
Assert.state(this.bodyConsumed.compareAndSet(false, true),
66+
"The client response body can only be consumed once."))
5667
.map(buf -> {
5768
buf.retain();
5869
return dataBufferFactory.wrap(buf);

spring-webflux/src/main/java/org/springframework/web/reactive/function/client/DefaultWebClient.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -433,12 +433,22 @@ private <T> Flux<T> monoThrowableToFlux(Mono<? extends Throwable> mono) {
433433
private <T extends Publisher<?>> T bodyToPublisher(ClientResponse response,
434434
T bodyPublisher, Function<Mono<? extends Throwable>, T> errorFunction) {
435435

436-
return this.statusHandlers.stream()
437-
.filter(statusHandler -> statusHandler.test(response.statusCode()))
438-
.findFirst()
439-
.map(statusHandler -> statusHandler.apply(response))
440-
.map(errorFunction::apply)
441-
.orElse(bodyPublisher);
436+
for (StatusHandler handler : this.statusHandlers) {
437+
if (handler.test(response.statusCode())) {
438+
Mono<? extends Throwable> exMono = handler.apply(response);
439+
exMono = exMono.flatMap(ex -> drainBody(response, ex));
440+
exMono = exMono.onErrorResume(ex -> drainBody(response, ex));
441+
return errorFunction.apply(exMono);
442+
}
443+
}
444+
return bodyPublisher;
445+
}
446+
447+
@SuppressWarnings("unchecked")
448+
private <T> Mono<T> drainBody(ClientResponse response, Throwable ex) {
449+
// Ensure the body is drained, even if the StatusHandler didn't consume it,
450+
// but ignore errors in case it did consume it.
451+
return (Mono<T>) response.bodyToMono(Void.class).onErrorMap(ex2 -> ex).thenReturn(ex);
442452
}
443453

444454
private static Mono<WebClientResponseException> createResponseException(ClientResponse response) {

spring-webflux/src/main/java/org/springframework/web/reactive/function/client/WebClient.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,9 @@ interface ResponseSpec {
596596
* {@link WebClientResponseException} when the response status code is 4xx or 5xx.
597597
* @param statusPredicate a predicate that indicates whether {@code exceptionFunction}
598598
* applies
599+
* <p><strong>NOTE:</strong> if the response is expected to have content,
600+
* the exceptionFunction should consume it. If not, the content will be
601+
* automatically drained to ensure resources are released.
599602
* @param exceptionFunction the function that returns the exception
600603
* @return this builder
601604
*/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.web.reactive.function.client;
17+
18+
import java.time.Duration;
19+
import java.util.function.Function;
20+
21+
import io.netty.buffer.ByteBufAllocator;
22+
import io.netty.channel.ChannelOption;
23+
import okhttp3.mockwebserver.MockResponse;
24+
import okhttp3.mockwebserver.MockWebServer;
25+
import org.junit.After;
26+
import org.junit.Before;
27+
import org.junit.Test;
28+
import reactor.core.publisher.Mono;
29+
import reactor.test.StepVerifier;
30+
31+
import org.springframework.core.io.buffer.AbstractDataBufferAllocatingTestCase;
32+
import org.springframework.core.io.buffer.NettyDataBufferFactory;
33+
import org.springframework.http.HttpStatus;
34+
import org.springframework.http.MediaType;
35+
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
36+
37+
import static org.junit.Assert.*;
38+
39+
/**
40+
* WebClient integration tests focusing on data buffer management.
41+
* @author Rossen Stoyanchev
42+
*/
43+
public class WebClientDataBufferAllocatingTests extends AbstractDataBufferAllocatingTestCase {
44+
45+
private static final Duration DELAY = Duration.ofSeconds(5);
46+
47+
48+
private MockWebServer server;
49+
50+
private WebClient webClient;
51+
52+
53+
@Before
54+
public void setUp() {
55+
this.server = new MockWebServer();
56+
this.webClient = WebClient
57+
.builder()
58+
.clientConnector(initConnector())
59+
.baseUrl(this.server.url("/").toString())
60+
.build();
61+
}
62+
63+
private ReactorClientHttpConnector initConnector() {
64+
if (bufferFactory instanceof NettyDataBufferFactory) {
65+
ByteBufAllocator allocator = ((NettyDataBufferFactory) bufferFactory).getByteBufAllocator();
66+
return new ReactorClientHttpConnector(builder -> builder.option(ChannelOption.ALLOCATOR, allocator));
67+
}
68+
else {
69+
return new ReactorClientHttpConnector();
70+
}
71+
}
72+
73+
@After
74+
public void shutDown() throws InterruptedException {
75+
waitForDataBufferRelease(Duration.ofSeconds(2));
76+
}
77+
78+
79+
@Test
80+
public void bodyToMonoVoid() {
81+
82+
this.server.enqueue(new MockResponse()
83+
.setResponseCode(201)
84+
.setHeader("Content-Type", "application/json")
85+
.setChunkedBody("{\"foo\" : {\"bar\" : \"123\", \"baz\" : \"456\"}}", 5));
86+
87+
Mono<Void> mono = this.webClient.get()
88+
.uri("/json").accept(MediaType.APPLICATION_JSON)
89+
.retrieve()
90+
.bodyToMono(Void.class);
91+
92+
StepVerifier.create(mono).expectComplete().verify(Duration.ofSeconds(3));
93+
assertEquals(1, this.server.getRequestCount());
94+
}
95+
96+
97+
@Test
98+
public void onStatusWithBodyNotConsumed() {
99+
RuntimeException ex = new RuntimeException("response error");
100+
testOnStatus(ex, response -> Mono.just(ex));
101+
}
102+
103+
@Test
104+
public void onStatusWithBodyConsumed() {
105+
RuntimeException ex = new RuntimeException("response error");
106+
testOnStatus(ex, response -> response.bodyToMono(Void.class).thenReturn(ex));
107+
}
108+
109+
@Test // SPR-17473
110+
public void onStatusWithMonoErrorAndBodyNotConsumed() {
111+
RuntimeException ex = new RuntimeException("response error");
112+
testOnStatus(ex, response -> Mono.error(ex));
113+
}
114+
115+
@Test
116+
public void onStatusWithMonoErrorAndBodyConsumed() {
117+
RuntimeException ex = new RuntimeException("response error");
118+
testOnStatus(ex, response -> response.bodyToMono(Void.class).then(Mono.error(ex)));
119+
}
120+
121+
private void testOnStatus(Throwable expected,
122+
Function<ClientResponse, Mono<? extends Throwable>> exceptionFunction) {
123+
124+
HttpStatus errorStatus = HttpStatus.BAD_GATEWAY;
125+
126+
this.server.enqueue(new MockResponse()
127+
.setResponseCode(errorStatus.value())
128+
.setHeader("Content-Type", "application/json")
129+
.setChunkedBody("{\"error\" : {\"status\" : 502, \"message\" : \"Bad gateway.\"}}", 5));
130+
131+
Mono<String> mono = this.webClient.get()
132+
.uri("/json").accept(MediaType.APPLICATION_JSON)
133+
.retrieve()
134+
.onStatus(status -> status.equals(errorStatus), exceptionFunction)
135+
.bodyToMono(String.class);
136+
137+
StepVerifier.create(mono).expectErrorSatisfies(actual -> assertSame(expected, actual)).verify(DELAY);
138+
assertEquals(1, this.server.getRequestCount());
139+
}
140+
141+
}

src/docs/asciidoc/web/webflux-webclient.adoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ By default, responses with 4xx or 5xx status codes result in an error of type
7171
.bodyToMono(Person.class);
7272
----
7373

74+
When `onStatus` is used, if the response is expected to have content, then the `onStatus`
75+
callback should consume it. If not, the content will be automatically drained to ensure
76+
resources are released.
77+
7478

7579

7680

0 commit comments

Comments
 (0)