Skip to content

Commit 04f9549

Browse files
committed
ServletServerHttpRequest.getURI() ignores malformed query string
The resolved URI instance is also being cached now. This should not make a difference in a real Servlet environment but does affect tests which assumed they could modify an HttpServletRequest path behind a pre-created ServletServerHttpRequest instance. Our WebSocket test base class has been revised accordingly, re-creating the ServletServerHttpRequest in such a case. Issue: SPR-16414 (cherry picked from commit 0e6f8df)
1 parent fe4472d commit 04f9549

File tree

9 files changed

+153
-101
lines changed

9 files changed

+153
-101
lines changed

spring-web/src/main/java/org/springframework/http/HttpRequest.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2015 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -35,7 +35,8 @@ public interface HttpRequest extends HttpMessage {
3535
HttpMethod getMethod();
3636

3737
/**
38-
* Return the URI of the request.
38+
* Return the URI of the request (including a query string if any,
39+
* but only if it is well-formed for a URI representation).
3940
* @return the URI of the request (never {@code null})
4041
*/
4142
URI getURI();

spring-web/src/main/java/org/springframework/http/server/ServletServerHttpRequest.java

+39-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -59,6 +59,8 @@ public class ServletServerHttpRequest implements ServerHttpRequest {
5959

6060
private final HttpServletRequest servletRequest;
6161

62+
private URI uri;
63+
6264
private HttpHeaders headers;
6365

6466
private ServerHttpAsyncRequestControl asyncRequestControl;
@@ -88,34 +90,54 @@ public HttpMethod getMethod() {
8890

8991
@Override
9092
public URI getURI() {
91-
try {
92-
StringBuffer url = this.servletRequest.getRequestURL();
93-
String query = this.servletRequest.getQueryString();
94-
if (StringUtils.hasText(query)) {
95-
url.append('?').append(query);
93+
if (this.uri == null) {
94+
String urlString = null;
95+
boolean hasQuery = false;
96+
try {
97+
StringBuffer url = this.servletRequest.getRequestURL();
98+
String query = this.servletRequest.getQueryString();
99+
hasQuery = StringUtils.hasText(query);
100+
if (hasQuery) {
101+
url.append('?').append(query);
102+
}
103+
urlString = url.toString();
104+
this.uri = new URI(urlString);
105+
}
106+
catch (URISyntaxException ex) {
107+
if (!hasQuery) {
108+
throw new IllegalStateException(
109+
"Could not resolve HttpServletRequest as URI: " + urlString, ex);
110+
}
111+
// Maybe a malformed query string... try plain request URL
112+
try {
113+
urlString = this.servletRequest.getRequestURL().toString();
114+
this.uri = new URI(urlString);
115+
}
116+
catch (URISyntaxException ex2) {
117+
throw new IllegalStateException(
118+
"Could not resolve HttpServletRequest as URI: " + urlString, ex2);
119+
}
96120
}
97-
return new URI(url.toString());
98-
}
99-
catch (URISyntaxException ex) {
100-
throw new IllegalStateException("Could not get HttpServletRequest URI: " + ex.getMessage(), ex);
101121
}
122+
return this.uri;
102123
}
103124

104125
@Override
105126
public HttpHeaders getHeaders() {
106127
if (this.headers == null) {
107128
this.headers = new HttpHeaders();
108129

109-
for (Enumeration<?> headerNames = this.servletRequest.getHeaderNames(); headerNames.hasMoreElements();) {
110-
String headerName = (String) headerNames.nextElement();
130+
for (Enumeration<?> names = this.servletRequest.getHeaderNames(); names.hasMoreElements();) {
131+
String headerName = (String) names.nextElement();
111132
for (Enumeration<?> headerValues = this.servletRequest.getHeaders(headerName);
112133
headerValues.hasMoreElements();) {
113134
String headerValue = (String) headerValues.nextElement();
114135
this.headers.add(headerName, headerValue);
115136
}
116137
}
117138

118-
// HttpServletRequest exposes some headers as properties: we should include those if not already present
139+
// HttpServletRequest exposes some headers as properties:
140+
// we should include those if not already present
119141
try {
120142
MediaType contentType = this.headers.getContentType();
121143
if (contentType == null) {
@@ -132,8 +154,8 @@ public HttpHeaders getHeaders() {
132154
Map<String, String> params = new LinkedCaseInsensitiveMap<String>();
133155
params.putAll(contentType.getParameters());
134156
params.put("charset", charSet.toString());
135-
MediaType newContentType = new MediaType(contentType.getType(), contentType.getSubtype(), params);
136-
this.headers.setContentType(newContentType);
157+
MediaType mediaType = new MediaType(contentType.getType(), contentType.getSubtype(), params);
158+
this.headers.setContentType(mediaType);
137159
}
138160
}
139161
}
@@ -181,7 +203,8 @@ public InputStream getBody() throws IOException {
181203
public ServerHttpAsyncRequestControl getAsyncRequestControl(ServerHttpResponse response) {
182204
if (this.asyncRequestControl == null) {
183205
if (!ServletServerHttpResponse.class.isInstance(response)) {
184-
throw new IllegalArgumentException("Response must be a ServletServerHttpResponse: " + response.getClass());
206+
throw new IllegalArgumentException(
207+
"Response must be a ServletServerHttpResponse: " + response.getClass());
185208
}
186209
ServletServerHttpResponse servletServerResponse = (ServletServerHttpResponse) response;
187210
this.asyncRequestControl = new ServletServerHttpAsyncRequestControl(this, servletServerResponse);

spring-web/src/test/java/org/springframework/http/server/ServletServerHttpRequestTests.java

+42-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2016 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,8 +16,10 @@
1616

1717
package org.springframework.http.server;
1818

19+
import java.io.IOException;
1920
import java.net.URI;
20-
import java.nio.charset.Charset;
21+
import java.net.URISyntaxException;
22+
import java.nio.charset.StandardCharsets;
2123
import java.util.List;
2224

2325
import org.junit.Before;
@@ -33,6 +35,7 @@
3335

3436
/**
3537
* @author Arjen Poutsma
38+
* @author Juergen Hoeller
3639
*/
3740
public class ServletServerHttpRequestTests {
3841

@@ -42,42 +45,68 @@ public class ServletServerHttpRequestTests {
4245

4346

4447
@Before
45-
public void create() throws Exception {
48+
public void create() {
4649
mockRequest = new MockHttpServletRequest();
4750
request = new ServletServerHttpRequest(mockRequest);
4851
}
4952

5053

5154
@Test
52-
public void getMethod() throws Exception {
55+
public void getMethod() {
5356
mockRequest.setMethod("POST");
5457
assertEquals("Invalid method", HttpMethod.POST, request.getMethod());
5558
}
5659

5760
@Test
58-
public void getURI() throws Exception {
61+
public void getUriForSimplePath() throws URISyntaxException {
62+
URI uri = new URI("http://example.com/path");
63+
mockRequest.setServerName(uri.getHost());
64+
mockRequest.setServerPort(uri.getPort());
65+
mockRequest.setRequestURI(uri.getPath());
66+
mockRequest.setQueryString(uri.getQuery());
67+
assertEquals(uri, request.getURI());
68+
}
69+
70+
@Test
71+
public void getUriWithQueryString() throws URISyntaxException {
5972
URI uri = new URI("http://example.com/path?query");
6073
mockRequest.setServerName(uri.getHost());
6174
mockRequest.setServerPort(uri.getPort());
6275
mockRequest.setRequestURI(uri.getPath());
6376
mockRequest.setQueryString(uri.getQuery());
64-
assertEquals("Invalid uri", uri, request.getURI());
77+
assertEquals(uri, request.getURI());
78+
}
79+
80+
@Test // SPR-16414
81+
public void getUriWithQueryParam() throws URISyntaxException {
82+
mockRequest.setServerName("example.com");
83+
mockRequest.setRequestURI("/path");
84+
mockRequest.setQueryString("query=foo");
85+
assertEquals(new URI("http://example.com/path?query=foo"), request.getURI());
86+
}
87+
88+
@Test // SPR-16414
89+
public void getUriWithMalformedQueryParam() throws URISyntaxException {
90+
mockRequest.setServerName("example.com");
91+
mockRequest.setRequestURI("/path");
92+
mockRequest.setQueryString("query=foo%%x");
93+
assertEquals(new URI("http://example.com/path"), request.getURI());
6594
}
6695

6796
@Test // SPR-13876
68-
public void getUriWithEncoding() throws Exception {
97+
public void getUriWithEncoding() throws URISyntaxException {
6998
URI uri = new URI("https://example.com/%E4%B8%AD%E6%96%87" +
7099
"?redirect=https%3A%2F%2Fgithub.com%2Fspring-projects%2Fspring-framework");
71100
mockRequest.setScheme(uri.getScheme());
72101
mockRequest.setServerName(uri.getHost());
73102
mockRequest.setServerPort(uri.getPort());
74103
mockRequest.setRequestURI(uri.getRawPath());
75104
mockRequest.setQueryString(uri.getRawQuery());
76-
assertEquals("Invalid uri", uri, request.getURI());
105+
assertEquals(uri, request.getURI());
77106
}
78107

79108
@Test
80-
public void getHeaders() throws Exception {
109+
public void getHeaders() {
81110
String headerName = "MyHeader";
82111
String headerValue1 = "value1";
83112
String headerValue2 = "value2";
@@ -93,12 +122,12 @@ public void getHeaders() throws Exception {
93122
assertEquals("Invalid header values returned", 2, headerValues.size());
94123
assertTrue("Invalid header values returned", headerValues.contains(headerValue1));
95124
assertTrue("Invalid header values returned", headerValues.contains(headerValue2));
96-
assertEquals("Invalid Content-Type", new MediaType("text", "plain", Charset.forName("UTF-8")),
125+
assertEquals("Invalid Content-Type", new MediaType("text", "plain", StandardCharsets.UTF_8),
97126
headers.getContentType());
98127
}
99128

100129
@Test
101-
public void getHeadersWithEmptyContentTypeAndEncoding() throws Exception {
130+
public void getHeadersWithEmptyContentTypeAndEncoding() {
102131
String headerName = "MyHeader";
103132
String headerValue1 = "value1";
104133
String headerValue2 = "value2";
@@ -118,7 +147,7 @@ public void getHeadersWithEmptyContentTypeAndEncoding() throws Exception {
118147
}
119148

120149
@Test
121-
public void getBody() throws Exception {
150+
public void getBody() throws IOException {
122151
byte[] content = "Hello World".getBytes("UTF-8");
123152
mockRequest.setContent(content);
124153

@@ -127,7 +156,7 @@ public void getBody() throws Exception {
127156
}
128157

129158
@Test
130-
public void getFormBody() throws Exception {
159+
public void getFormBody() throws IOException {
131160
// Charset (SPR-8676)
132161
mockRequest.setContentType("application/x-www-form-urlencoded; charset=UTF-8");
133162
mockRequest.setMethod("POST");

spring-websocket/src/test/java/org/springframework/web/socket/AbstractHttpRequestTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2015 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -45,13 +45,14 @@ public abstract class AbstractHttpRequestTests {
4545

4646

4747
@Before
48-
public void setUp() {
48+
public void setup() {
4949
resetRequestAndResponse();
5050
}
5151

5252
protected void setRequest(String method, String requestUri) {
5353
this.servletRequest.setMethod(method);
5454
this.servletRequest.setRequestURI(requestUri);
55+
this.request = new ServletServerHttpRequest(this.servletRequest);
5556
}
5657

5758
protected void resetRequestAndResponse() {

spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java

+21-27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2014 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -50,19 +50,18 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
5050

5151

5252
@Before
53-
public void setup() throws Exception {
53+
public void setup() {
54+
super.setup();
55+
5456
MockitoAnnotations.initMocks(this);
5557
this.handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy);
5658
}
5759

5860

5961
@Test
60-
public void supportedSubProtocols() throws Exception {
61-
62+
public void supportedSubProtocols() {
6263
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
63-
6464
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
65-
6665
this.servletRequest.setMethod("GET");
6766

6867
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
@@ -73,22 +72,20 @@ public void supportedSubProtocols() throws Exception {
7372
headers.setSecWebSocketProtocol("STOMP");
7473

7574
WebSocketHandler handler = new TextWebSocketHandler();
76-
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
75+
Map<String, Object> attributes = Collections.emptyMap();
7776
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
7877

79-
verify(this.upgradeStrategy).upgrade(this.request, this.response,
80-
"STOMP", Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
78+
verify(this.upgradeStrategy).upgrade(this.request, this.response, "STOMP",
79+
Collections.emptyList(), null, handler, attributes);
8180
}
8281

83-
8482
@Test
85-
public void supportedExtensions() throws Exception {
86-
83+
public void supportedExtensions() {
8784
WebSocketExtension extension1 = new WebSocketExtension("ext1");
8885
WebSocketExtension extension2 = new WebSocketExtension("ext2");
8986

9087
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
91-
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Arrays.asList(extension1));
88+
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1));
9289

9390
this.servletRequest.setMethod("GET");
9491

@@ -103,14 +100,13 @@ public void supportedExtensions() throws Exception {
103100
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
104101
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
105102

106-
verify(this.upgradeStrategy).upgrade(this.request, this.response, null, Arrays.asList(extension1),
107-
null, handler, attributes);
103+
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
104+
Collections.singletonList(extension1), null, handler, attributes);
108105
}
109106

110107
@Test
111-
public void subProtocolCapableHandler() throws Exception {
112-
113-
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[]{"13"});
108+
public void subProtocolCapableHandler() {
109+
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
114110

115111
this.servletRequest.setMethod("GET");
116112

@@ -125,14 +121,13 @@ public void subProtocolCapableHandler() throws Exception {
125121
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
126122
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
127123

128-
verify(this.upgradeStrategy).upgrade(this.request, this.response,
129-
"v11.stomp", Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
124+
verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp",
125+
Collections.emptyList(), null, handler, attributes);
130126
}
131127

132128
@Test
133-
public void subProtocolCapableHandlerNoMatch() throws Exception {
134-
135-
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[]{"13"});
129+
public void subProtocolCapableHandlerNoMatch() {
130+
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
136131

137132
this.servletRequest.setMethod("GET");
138133

@@ -147,17 +142,16 @@ public void subProtocolCapableHandlerNoMatch() throws Exception {
147142
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
148143
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
149144

150-
verify(this.upgradeStrategy).upgrade(this.request, this.response,
151-
null, Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
145+
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
146+
Collections.emptyList(), null, handler, attributes);
152147
}
153148

154149

155150
private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable {
156151

157152
private final List<String> subProtocols;
158153

159-
160-
private SubProtocolCapableHandler(String... subProtocols) {
154+
public SubProtocolCapableHandler(String... subProtocols) {
161155
this.subProtocols = Arrays.asList(subProtocols);
162156
}
163157

0 commit comments

Comments
 (0)