Skip to content

Commit aec9826

Browse files
committed
maxResponseBody client filter
Issue: SPR-16989
1 parent 5095ec4 commit aec9826

File tree

2 files changed

+77
-17
lines changed

2 files changed

+77
-17
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@
2525
import java.util.function.Function;
2626
import java.util.function.Predicate;
2727

28+
import reactor.core.publisher.Flux;
2829
import reactor.core.publisher.Mono;
2930

31+
import org.springframework.core.io.buffer.DataBuffer;
32+
import org.springframework.core.io.buffer.DataBufferUtils;
3033
import org.springframework.http.HttpHeaders;
3134
import org.springframework.http.HttpStatus;
3235
import org.springframework.util.Assert;
36+
import org.springframework.web.reactive.function.BodyExtractors;
3337

3438
/**
3539
* Static factory methods providing access to built-in implementations of
@@ -50,6 +54,21 @@ public abstract class ExchangeFilterFunctions {
5054
public static final String BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE =
5155
ExchangeFilterFunctions.class.getName() + ".basicAuthenticationCredentials";
5256

57+
/**
58+
* Consume up to the specified number of bytes from the response body and
59+
* cancel if any more data arrives. Internally delegates to
60+
* {@link DataBufferUtils#takeUntilByteCount}.
61+
* @return the filter to limit the response size with
62+
* @since 5.1
63+
*/
64+
public static ExchangeFilterFunction limitResponseSize(long maxByteCount) {
65+
return (request, next) ->
66+
next.exchange(request).map(response -> {
67+
Flux<DataBuffer> body = response.body(BodyExtractors.toDataBuffers());
68+
body = DataBufferUtils.takeUntilByteCount(body, maxByteCount);
69+
return ClientResponse.from(response).body(body).build();
70+
});
71+
}
5372

5473
/**
5574
* Return a filter for HTTP Basic Authentication that adds an authorization

spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,36 @@
1717
package org.springframework.web.reactive.function.client;
1818

1919
import java.net.URI;
20+
import java.nio.charset.StandardCharsets;
2021

2122
import org.junit.Test;
23+
import reactor.core.publisher.Flux;
2224
import reactor.core.publisher.Mono;
2325
import reactor.test.StepVerifier;
2426

27+
import org.springframework.core.io.buffer.DataBuffer;
28+
import org.springframework.core.io.buffer.DataBufferUtils;
29+
import org.springframework.core.io.buffer.DefaultDataBufferFactory;
30+
import org.springframework.core.io.buffer.support.DataBufferTestUtils;
2531
import org.springframework.http.HttpHeaders;
32+
import org.springframework.http.HttpMethod;
2633
import org.springframework.http.HttpStatus;
34+
import org.springframework.web.reactive.function.BodyExtractors;
2735

2836
import static org.junit.Assert.*;
2937
import static org.mockito.Mockito.*;
30-
import static org.springframework.http.HttpMethod.GET;
31-
import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
3238

3339
/**
3440
* @author Arjen Poutsma
3541
*/
36-
@SuppressWarnings("deprecation")
3742
public class ExchangeFilterFunctionsTests {
3843

44+
private static final URI DEFAULT_URL = URI.create("http://example.com");
45+
46+
3947
@Test
4048
public void andThen() {
41-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
49+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
4250
ClientResponse response = mock(ClientResponse.class);
4351
ExchangeFunction exchange = r -> Mono.just(response);
4452

@@ -68,7 +76,7 @@ public void andThen() {
6876

6977
@Test
7078
public void apply() {
71-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
79+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
7280
ClientResponse response = mock(ClientResponse.class);
7381
ExchangeFunction exchange = r -> Mono.just(response);
7482

@@ -86,8 +94,9 @@ public void apply() {
8694
}
8795

8896
@Test
97+
@SuppressWarnings("deprecation")
8998
public void basicAuthenticationUsernamePassword() {
90-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
99+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
91100
ClientResponse response = mock(ClientResponse.class);
92101

93102
ExchangeFunction exchange = r -> {
@@ -109,9 +118,11 @@ public void basicAuthenticationInvalidCharacters() {
109118
}
110119

111120
@Test
121+
@SuppressWarnings("deprecation")
112122
public void basicAuthenticationAttributes() {
113-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com"))
114-
.attributes(basicAuthenticationCredentials("foo", "bar"))
123+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL)
124+
.attributes(org.springframework.web.reactive.function.client.ExchangeFilterFunctions
125+
.Credentials.basicAuthenticationCredentials("foo", "bar"))
115126
.build();
116127
ClientResponse response = mock(ClientResponse.class);
117128

@@ -128,8 +139,9 @@ public void basicAuthenticationAttributes() {
128139
}
129140

130141
@Test
142+
@SuppressWarnings("deprecation")
131143
public void basicAuthenticationAbsentAttributes() {
132-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
144+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
133145
ClientResponse response = mock(ClientResponse.class);
134146

135147
ExchangeFunction exchange = r -> {
@@ -145,7 +157,7 @@ public void basicAuthenticationAbsentAttributes() {
145157

146158
@Test
147159
public void statusHandlerMatch() {
148-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
160+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
149161
ClientResponse response = mock(ClientResponse.class);
150162
when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND);
151163

@@ -163,23 +175,52 @@ public void statusHandlerMatch() {
163175

164176
@Test
165177
public void statusHandlerNoMatch() {
166-
ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build();
178+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
167179
ClientResponse response = mock(ClientResponse.class);
168180
when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND);
169181

170-
ExchangeFunction exchange = r -> Mono.just(response);
171-
172-
ExchangeFilterFunction errorHandler = ExchangeFilterFunctions.statusError(
173-
HttpStatus::is5xxServerError, r -> new MyException());
174-
175-
Mono<ClientResponse> result = errorHandler.filter(request, exchange);
182+
Mono<ClientResponse> result = ExchangeFilterFunctions
183+
.statusError(HttpStatus::is5xxServerError, req -> new MyException())
184+
.filter(request, req -> Mono.just(response));
176185

177186
StepVerifier.create(result)
178187
.expectNext(response)
179188
.expectComplete()
180189
.verify();
181190
}
182191

192+
@Test
193+
public void limitResponseSize() {
194+
DefaultDataBufferFactory bufferFactory = new DefaultDataBufferFactory();
195+
DataBuffer b1 = dataBuffer("foo", bufferFactory);
196+
DataBuffer b2 = dataBuffer("bar", bufferFactory);
197+
DataBuffer b3 = dataBuffer("baz", bufferFactory);
198+
199+
ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build();
200+
ClientResponse response = ClientResponse.create(HttpStatus.OK).body(Flux.just(b1, b2, b3)).build();
201+
202+
Mono<ClientResponse> result = ExchangeFilterFunctions.limitResponseSize(5)
203+
.filter(request, req -> Mono.just(response));
204+
205+
StepVerifier.create(result.flatMapMany(res -> res.body(BodyExtractors.toDataBuffers())))
206+
.consumeNextWith(buffer -> assertEquals("foo", string(buffer)))
207+
.consumeNextWith(buffer -> assertEquals("ba", string(buffer)))
208+
.expectComplete()
209+
.verify();
210+
211+
}
212+
213+
private String string(DataBuffer buffer) {
214+
String value = DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8);
215+
DataBufferUtils.release(buffer);
216+
return value;
217+
}
218+
219+
private DataBuffer dataBuffer(String foo, DefaultDataBufferFactory bufferFactory) {
220+
return bufferFactory.wrap(foo.getBytes(StandardCharsets.UTF_8));
221+
}
222+
223+
183224
@SuppressWarnings("serial")
184225
private static class MyException extends Exception {
185226

0 commit comments

Comments
 (0)