Skip to content

Commit f23612c

Browse files
committed
Add ResolvableType to HttpEntity for multipart Publishers
This commit adds a ResolvableType field to HttpEntity, in order to support Publishers as multipart data. Without the type, the MultipartHttpMessageWriter does not know which delegate writer to use to write the part. Issue: SPR-16307
1 parent 9d27e86 commit f23612c

File tree

5 files changed

+216
-20
lines changed

5 files changed

+216
-20
lines changed

spring-web/src/main/java/org/springframework/http/HttpEntity.java

+66
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616

1717
package org.springframework.http;
1818

19+
import org.reactivestreams.Publisher;
20+
21+
import org.springframework.core.ParameterizedTypeReference;
22+
import org.springframework.core.ResolvableType;
1923
import org.springframework.lang.Nullable;
24+
import org.springframework.util.Assert;
2025
import org.springframework.util.MultiValueMap;
2126
import org.springframework.util.ObjectUtils;
2227

@@ -67,6 +72,9 @@ public class HttpEntity<T> {
6772
@Nullable
6873
private final T body;
6974

75+
@Nullable
76+
private final ResolvableType bodyType;
77+
7078

7179
/**
7280
* Create a new, empty {@code HttpEntity}.
@@ -97,7 +105,18 @@ public HttpEntity(MultiValueMap<String, String> headers) {
97105
* @param headers the entity headers
98106
*/
99107
public HttpEntity(@Nullable T body, @Nullable MultiValueMap<String, String> headers) {
108+
this(body, null, headers);
109+
}
110+
111+
private HttpEntity(@Nullable T body, @Nullable ResolvableType bodyType,
112+
@Nullable MultiValueMap<String, String> headers) {
100113
this.body = body;
114+
115+
if (bodyType == null && body != null) {
116+
bodyType = ResolvableType.forClass(body.getClass());
117+
}
118+
this.bodyType = bodyType ;
119+
101120
HttpHeaders tempHeaders = new HttpHeaders();
102121
if (headers != null) {
103122
tempHeaders.putAll(headers);
@@ -128,6 +147,13 @@ public boolean hasBody() {
128147
return (this.body != null);
129148
}
130149

150+
/**
151+
* Returns the type of the body.
152+
*/
153+
@Nullable
154+
public ResolvableType getBodyType() {
155+
return this.bodyType;
156+
}
131157

132158
@Override
133159
public boolean equals(@Nullable Object other) {
@@ -159,4 +185,44 @@ public String toString() {
159185
return builder.toString();
160186
}
161187

188+
189+
// Static builder methods
190+
191+
/**
192+
* Create a new {@code HttpEntity} with the given {@link Publisher} as body, class contained in
193+
* {@code publisher}, and headers.
194+
* @param publisher the publisher to use as body
195+
* @param elementClass the class of elements contained in the publisher
196+
* @param headers the entity headers
197+
* @param <S> the type of the elements contained in the publisher
198+
* @param <P> the type of the {@code Publisher}
199+
* @return the created entity
200+
*/
201+
public static <S, P extends Publisher<S>> HttpEntity<P> fromPublisher(P publisher,
202+
Class<S> elementClass, @Nullable MultiValueMap<String, String> headers) {
203+
204+
Assert.notNull(publisher, "'publisher' must not be null");
205+
Assert.notNull(elementClass, "'elementClass' must not be null");
206+
return new HttpEntity<>(publisher, ResolvableType.forClass(elementClass), headers);
207+
}
208+
209+
/**
210+
* Create a new {@code HttpEntity} with the given {@link Publisher} as body, type contained in
211+
* {@code publisher}, and headers.
212+
* @param publisher the publisher to use as body
213+
* @param typeReference the type of elements contained in the publisher
214+
* @param headers the entity headers
215+
* @param <S> the type of the elements contained in the publisher
216+
* @param <P> the type of the {@code Publisher}
217+
* @return the created entity
218+
*/
219+
public static <S, P extends Publisher<S>> HttpEntity<P> fromPublisher(P publisher,
220+
ParameterizedTypeReference<S> typeReference,
221+
@Nullable MultiValueMap<String, String> headers) {
222+
223+
Assert.notNull(publisher, "'publisher' must not be null");
224+
Assert.notNull(typeReference, "'typeReference' must not be null");
225+
return new HttpEntity<>(publisher, ResolvableType.forType(typeReference), headers);
226+
}
227+
162228
}

spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java

+99-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
import java.util.List;
2121
import java.util.Map;
2222

23+
import org.reactivestreams.Publisher;
24+
25+
import org.springframework.core.ParameterizedTypeReference;
26+
import org.springframework.core.ResolvableType;
2327
import org.springframework.http.HttpEntity;
2428
import org.springframework.http.HttpHeaders;
2529
import org.springframework.http.MediaType;
@@ -96,6 +100,11 @@ public PartBuilder part(String name, Object part, @Nullable MediaType contentTyp
96100
Assert.hasLength(name, "'name' must not be empty");
97101
Assert.notNull(part, "'part' must not be null");
98102

103+
if (part instanceof Publisher) {
104+
throw new IllegalArgumentException("Use publisher(String, Publisher, Class) or " +
105+
"publisher(String, Publisher, ParameterizedTypeReference) for adding Publisher parts");
106+
}
107+
99108
Object partBody;
100109
HttpHeaders partHeaders = new HttpHeaders();
101110

@@ -116,6 +125,54 @@ public PartBuilder part(String name, Object part, @Nullable MediaType contentTyp
116125
return builder;
117126
}
118127

128+
/**
129+
* Adds a {@link Publisher} part to this builder, allowing for further header customization with
130+
* the returned {@link PartBuilder}.
131+
* @param name the name of the part to add (may not be empty)
132+
* @param publisher the contents of the part to add
133+
* @param elementClass the class of elements contained in the publisher
134+
* @return a builder that allows for further header customization
135+
*/
136+
public <T, P extends Publisher<T>> PartBuilder asyncPart(String name, P publisher,
137+
Class<T> elementClass) {
138+
139+
Assert.notNull(elementClass, "'elementClass' must not be null");
140+
ResolvableType elementType = ResolvableType.forClass(elementClass);
141+
Assert.hasLength(name, "'name' must not be empty");
142+
Assert.notNull(publisher, "'publisher' must not be null");
143+
Assert.notNull(elementType, "'elementType' must not be null");
144+
145+
HttpHeaders partHeaders = new HttpHeaders();
146+
PublisherClassPartBuilder<T, P> builder =
147+
new PublisherClassPartBuilder<>(publisher, elementClass, partHeaders);
148+
this.parts.add(name, builder);
149+
return builder;
150+
151+
}
152+
153+
/**
154+
* Adds a {@link Publisher} part to this builder, allowing for further header customization with
155+
* the returned {@link PartBuilder}.
156+
* @param name the name of the part to add (may not be empty)
157+
* @param publisher the contents of the part to add
158+
* @param elementType the type of elements contained in the publisher
159+
* @return a builder that allows for further header customization
160+
*/
161+
public <T, P extends Publisher<T>> PartBuilder asyncPart(String name, P publisher,
162+
ParameterizedTypeReference<T> elementType) {
163+
164+
Assert.notNull(elementType, "'elementType' must not be null");
165+
ResolvableType elementType1 = ResolvableType.forType(elementType);
166+
Assert.hasLength(name, "'name' must not be empty");
167+
Assert.notNull(publisher, "'publisher' must not be null");
168+
Assert.notNull(elementType1, "'elementType' must not be null");
169+
170+
HttpHeaders partHeaders = new HttpHeaders();
171+
PublisherTypReferencePartBuilder<T, P> builder =
172+
new PublisherTypReferencePartBuilder<>(publisher, elementType, partHeaders);
173+
this.parts.add(name, builder);
174+
return builder;
175+
}
119176

120177
/**
121178
* Builder interface that allows for customization of part headers.
@@ -136,10 +193,9 @@ public interface PartBuilder {
136193
private static class DefaultPartBuilder implements PartBuilder {
137194

138195
@Nullable
139-
private final Object body;
140-
141-
private final HttpHeaders headers;
196+
protected final Object body;
142197

198+
protected final HttpHeaders headers;
143199

144200
public DefaultPartBuilder(@Nullable Object body, HttpHeaders headers) {
145201
this.body = body;
@@ -157,4 +213,44 @@ public HttpEntity<?> build() {
157213
}
158214
}
159215

216+
private static class PublisherClassPartBuilder<S, P extends Publisher<S>>
217+
extends DefaultPartBuilder {
218+
219+
private final Class<S> bodyType;
220+
221+
public PublisherClassPartBuilder(P body, Class<S> bodyType, HttpHeaders headers) {
222+
super(body, headers);
223+
this.bodyType = bodyType;
224+
}
225+
226+
@Override
227+
@SuppressWarnings("unchecked")
228+
public HttpEntity<?> build() {
229+
P body = (P) this.body;
230+
Assert.state(body != null, "'body' must not be null");
231+
return HttpEntity.fromPublisher(body, this.bodyType, this.headers);
232+
}
233+
}
234+
235+
private static class PublisherTypReferencePartBuilder<S, P extends Publisher<S>>
236+
extends DefaultPartBuilder {
237+
238+
private final ParameterizedTypeReference<S> bodyType;
239+
240+
public PublisherTypReferencePartBuilder(P body, ParameterizedTypeReference<S> bodyType,
241+
HttpHeaders headers) {
242+
243+
super(body, headers);
244+
this.bodyType = bodyType;
245+
}
246+
247+
@Override
248+
@SuppressWarnings("unchecked")
249+
public HttpEntity<?> build() {
250+
P body = (P) this.body;
251+
Assert.state(body != null, "'body' must not be null");
252+
return HttpEntity.fromPublisher(body, this.bodyType, this.headers);
253+
}
254+
}
255+
160256
}

spring-web/src/main/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriter.java

+15-5
Original file line numberDiff line numberDiff line change
@@ -230,31 +230,41 @@ private <T> Flux<DataBuffer> encodePart(byte[] boundary, String name, T value) {
230230
MultipartHttpOutputMessage outputMessage = new MultipartHttpOutputMessage(this.bufferFactory, getCharset());
231231

232232
T body;
233+
ResolvableType bodyType = null;
233234
if (value instanceof HttpEntity) {
234-
outputMessage.getHeaders().putAll(((HttpEntity<T>) value).getHeaders());
235-
body = ((HttpEntity<T>) value).getBody();
235+
HttpEntity<T> httpEntity = (HttpEntity<T>) value;
236+
outputMessage.getHeaders().putAll(httpEntity.getHeaders());
237+
body = httpEntity.getBody();
236238
Assert.state(body != null, "MultipartHttpMessageWriter only supports HttpEntity with body");
239+
bodyType = httpEntity.getBodyType();
237240
}
238241
else {
239242
body = value;
240243
}
241244

245+
if (bodyType == null) {
246+
bodyType = ResolvableType.forClass(body.getClass());
247+
}
248+
242249
String filename = (body instanceof Resource ? ((Resource) body).getFilename() : null);
243250
outputMessage.getHeaders().setContentDispositionFormData(name, filename);
244251

245-
ResolvableType bodyType = ResolvableType.forClass(body.getClass());
246252
MediaType contentType = outputMessage.getHeaders().getContentType();
247253

254+
final ResolvableType finalBodyType = bodyType;
248255
Optional<HttpMessageWriter<?>> writer = this.partWriters.stream()
249-
.filter(partWriter -> partWriter.canWrite(bodyType, contentType))
256+
.filter(partWriter -> partWriter.canWrite(finalBodyType, contentType))
250257
.findFirst();
251258

252259
if (!writer.isPresent()) {
253260
return Flux.error(new CodecException("No suitable writer found for part: " + name));
254261
}
255262

263+
Publisher<T> bodyPublisher =
264+
body instanceof Publisher ? (Publisher<T>) body : Mono.just(body);
265+
256266
Mono<Void> partWritten = ((HttpMessageWriter<T>) writer.get())
257-
.write(Mono.just(body), bodyType, contentType, outputMessage, Collections.emptyMap());
267+
.write(bodyPublisher, bodyType, contentType, outputMessage, Collections.emptyMap());
258268

259269
// partWritten.subscribe() is required in order to make sure MultipartHttpOutputMessage#getBody()
260270
// returns a non-null value (occurs with ResourceHttpMessageWriter that invokes

spring-web/src/test/java/org/springframework/http/client/MultipartBodyBuilderTests.java

+16-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
package org.springframework.http.client;
1818

1919
import org.junit.Test;
20+
import org.reactivestreams.Publisher;
21+
import reactor.core.publisher.Flux;
2022

23+
import org.springframework.core.ResolvableType;
2124
import org.springframework.core.io.ClassPathResource;
2225
import org.springframework.core.io.Resource;
2326
import org.springframework.http.HttpEntity;
@@ -34,23 +37,25 @@ public class MultipartBodyBuilderTests {
3437

3538
@Test
3639
public void builder() throws Exception {
37-
MultiValueMap<String, String> form = new LinkedMultiValueMap<>();
38-
form.add("form field", "form value");
40+
MultiValueMap<String, String> multipartData = new LinkedMultiValueMap<>();
41+
multipartData.add("form field", "form value");
3942
Resource logo = new ClassPathResource("/org/springframework/http/converter/logo.jpg");
4043
HttpHeaders entityHeaders = new HttpHeaders();
4144
entityHeaders.add("foo", "bar");
4245
HttpEntity<String> entity = new HttpEntity<>("body", entityHeaders);
46+
Publisher<String> publisher = Flux.just("foo", "bar", "baz");
4347

4448
MultipartBodyBuilder builder = new MultipartBodyBuilder();
45-
builder.part("key", form).header("foo", "bar");
49+
builder.part("key", multipartData).header("foo", "bar");
4650
builder.part("logo", logo).header("baz", "qux");
4751
builder.part("entity", entity).header("baz", "qux");
52+
builder.asyncPart("publisher", publisher, String.class).header("baz", "qux");
4853

4954
MultiValueMap<String, HttpEntity<?>> result = builder.build();
5055

51-
assertEquals(3, result.size());
56+
assertEquals(4, result.size());
5257
assertNotNull(result.getFirst("key"));
53-
assertEquals(form, result.getFirst("key").getBody());
58+
assertEquals(multipartData, result.getFirst("key").getBody());
5459
assertEquals("bar", result.getFirst("key").getHeaders().getFirst("foo"));
5560

5661
assertNotNull(result.getFirst("logo"));
@@ -61,6 +66,12 @@ public void builder() throws Exception {
6166
assertEquals("body", result.getFirst("entity").getBody());
6267
assertEquals("bar", result.getFirst("entity").getHeaders().getFirst("foo"));
6368
assertEquals("qux", result.getFirst("entity").getHeaders().getFirst("baz"));
69+
70+
assertNotNull(result.getFirst("publisher"));
71+
assertEquals(publisher, result.getFirst("publisher").getBody());
72+
assertEquals(ResolvableType.forClass(String.class), result.getFirst("publisher").getBodyType());
73+
assertEquals("bar", result.getFirst("entity").getHeaders().getFirst("foo"));
74+
assertEquals("qux", result.getFirst("entity").getHeaders().getFirst("baz"));
6475
}
6576

6677

0 commit comments

Comments
 (0)