From 46169c4bf320b0eb3be5f9e4feee64b09e07ab13 Mon Sep 17 00:00:00 2001 From: Einar Pehrson Date: Thu, 29 Dec 2016 15:52:23 +0100 Subject: [PATCH] Allow request interceptor to add to headers set via entity Provide a fully mutable HttpHeaders to ClientHttpRequestInterceptors of a RestTemplate when headers are set using HttpEntity. This avoids UnsupportedOperationException if both HttpEntity and ClientHttpRequestInterceptor add values for the same HTTP header. Issue: SPR-15066 --- .../web/client/RestTemplate.java | 14 ++++++--- .../web/client/RestTemplateTests.java | 30 +++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 358ffc4011b9..89fa1135ca24 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -20,9 +20,11 @@ import java.lang.reflect.Type; import java.net.URI; import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpEntity; @@ -844,7 +846,7 @@ public void doWithRequest(ClientHttpRequest httpRequest) throws IOException { HttpHeaders httpHeaders = httpRequest.getHeaders(); HttpHeaders requestHeaders = this.requestEntity.getHeaders(); if (!requestHeaders.isEmpty()) { - httpHeaders.putAll(requestHeaders); + httpHeaders.putAll(readWriteHttpHeaderMap(requestHeaders)); } if (httpHeaders.getContentLength() < 0) { httpHeaders.setContentLength(0L); @@ -862,7 +864,7 @@ public void doWithRequest(ClientHttpRequest httpRequest) throws IOException { GenericHttpMessageConverter genericMessageConverter = (GenericHttpMessageConverter) messageConverter; if (genericMessageConverter.canWrite(requestBodyType, requestBodyClass, requestContentType)) { if (!requestHeaders.isEmpty()) { - httpRequest.getHeaders().putAll(requestHeaders); + httpRequest.getHeaders().putAll(readWriteHttpHeaderMap(requestHeaders)); } if (logger.isDebugEnabled()) { if (requestContentType != null) { @@ -881,7 +883,7 @@ public void doWithRequest(ClientHttpRequest httpRequest) throws IOException { } else if (messageConverter.canWrite(requestBodyClass, requestContentType)) { if (!requestHeaders.isEmpty()) { - httpRequest.getHeaders().putAll(requestHeaders); + httpRequest.getHeaders().putAll(readWriteHttpHeaderMap(requestHeaders)); } if (logger.isDebugEnabled()) { if (requestContentType != null) { @@ -906,8 +908,12 @@ else if (messageConverter.canWrite(requestBodyClass, requestContentType)) { throw new RestClientException(message); } } - } + private Map> readWriteHttpHeaderMap(HttpHeaders httpHeaders) { + return httpHeaders.entrySet().stream().collect( + Collectors.toMap(Map.Entry::getKey, (entry) -> new LinkedList<>(entry.getValue()))); + } + } /** * Response extractor for {@link HttpEntity}. diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index 1b3f4fe24c12..4f788c19eb52 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -39,11 +39,14 @@ import org.springframework.http.ResponseEntity; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.web.util.DefaultUriTemplateHandler; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; @@ -802,4 +805,31 @@ public void exchangeParameterizedType() throws Exception { verify(response).close(); } + + @Test // SPR-15066 + public void requestInterceptorCanAddHeaderValue() throws Exception { + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + request.getHeaders().add("MyHeader", "MyInterceptorValue"); + return execution.execute(request, body); + }; + template.setInterceptors(Collections.singletonList(interceptor)); + + given(requestFactory.createRequest(new URI("http://example.com"), HttpMethod.POST)) + .willReturn(request); + HttpHeaders requestHeaders = new HttpHeaders(); + given(request.getHeaders()).willReturn(requestHeaders); + given(request.execute()).willReturn(response); + given(errorHandler.hasError(response)).willReturn(false); + HttpStatus status = HttpStatus.OK; + given(response.getStatusCode()).willReturn(status); + given(response.getStatusText()).willReturn(status.getReasonPhrase()); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.add("MyHeader", "MyEntityValue"); + HttpEntity entity = new HttpEntity<>(null, entityHeaders); + template.exchange("http://example.com", HttpMethod.POST, entity, Void.class); + assertThat(requestHeaders.get("MyHeader"), contains("MyEntityValue", "MyInterceptorValue")); + + verify(response).close(); + } }