Skip to content

Commit 28c0d14

Browse files
committed
Request Cache supports matchingRequestParameterName
1 parent 38cb6c3 commit 28c0d14

File tree

8 files changed

+294
-22
lines changed

8 files changed

+294
-22
lines changed

web/src/main/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixin.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.security.web.jackson2;
1818

1919
import com.fasterxml.jackson.annotation.JsonAutoDetect;
20+
import com.fasterxml.jackson.annotation.JsonInclude;
2021
import com.fasterxml.jackson.annotation.JsonTypeInfo;
2122
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
2223

@@ -43,4 +44,7 @@
4344
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
4445
abstract class DefaultSavedRequestMixin {
4546

47+
@JsonInclude(JsonInclude.Include.NON_NULL)
48+
String matchingRequestParameterName;
49+
4650
}

web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,15 @@ public class DefaultSavedRequest implements SavedRequest {
9898

9999
private final int serverPort;
100100

101-
@SuppressWarnings("unchecked")
101+
private final String matchingRequestParameterName;
102+
102103
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) {
104+
this(request, portResolver, null);
105+
}
106+
107+
@SuppressWarnings("unchecked")
108+
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver,
109+
String matchingRequestParameterName) {
103110
Assert.notNull(request, "Request required");
104111
Assert.notNull(portResolver, "PortResolver required");
105112
// Cookies
@@ -132,6 +139,7 @@ public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver
132139
this.serverName = request.getServerName();
133140
this.contextPath = request.getContextPath();
134141
this.servletPath = request.getServletPath();
142+
this.matchingRequestParameterName = matchingRequestParameterName;
135143
}
136144

137145
/**
@@ -148,6 +156,7 @@ private DefaultSavedRequest(Builder builder) {
148156
this.serverName = builder.serverName;
149157
this.servletPath = builder.servletPath;
150158
this.serverPort = builder.serverPort;
159+
this.matchingRequestParameterName = builder.matchingRequestParameterName;
151160
}
152161

153162
/**
@@ -265,8 +274,9 @@ public List<Cookie> getCookies() {
265274
*/
266275
@Override
267276
public String getRedirectUrl() {
277+
String queryString = createQueryString(this.queryString, this.matchingRequestParameterName);
268278
return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI,
269-
this.queryString);
279+
queryString);
270280
}
271281

272282
@Override
@@ -354,6 +364,19 @@ public String toString() {
354364
return "DefaultSavedRequest [" + getRedirectUrl() + "]";
355365
}
356366

367+
private static String createQueryString(String queryString, String matchingRequestParameterName) {
368+
if (matchingRequestParameterName == null) {
369+
return queryString;
370+
}
371+
if (queryString == null || queryString.length() == 0) {
372+
return matchingRequestParameterName;
373+
}
374+
if (queryString.endsWith("&")) {
375+
return queryString + matchingRequestParameterName;
376+
}
377+
return queryString + "&" + matchingRequestParameterName;
378+
}
379+
357380
/**
358381
* @since 4.2
359382
*/
@@ -389,6 +412,8 @@ public static class Builder {
389412

390413
private int serverPort = 80;
391414

415+
private String matchingRequestParameterName;
416+
392417
public Builder setCookies(List<SavedCookie> cookies) {
393418
this.cookies = cookies;
394419
return this;
@@ -459,6 +484,11 @@ public Builder setServerPort(int serverPort) {
459484
return this;
460485
}
461486

487+
public Builder setMatchingRequestParameterName(String matchingRequestParameterName) {
488+
this.matchingRequestParameterName = matchingRequestParameterName;
489+
return this;
490+
}
491+
462492
public DefaultSavedRequest build() {
463493
DefaultSavedRequest savedRequest = new DefaultSavedRequest(this);
464494
if (!ObjectUtils.isEmpty(this.cookies)) {

web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ public class HttpSessionRequestCache implements RequestCache {
5353

5454
private String sessionAttrName = SAVED_REQUEST;
5555

56+
private String matchingRequestParameterName;
57+
5658
/**
5759
* Stores the current request, provided the configuration properties allow it.
5860
*/
@@ -65,7 +67,8 @@ public void saveRequest(HttpServletRequest request, HttpServletResponse response
6567
}
6668
return;
6769
}
68-
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver);
70+
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver,
71+
this.matchingRequestParameterName);
6972
if (this.createSessionAllowed || request.getSession(false) != null) {
7073
// Store the HTTP request itself. Used by
7174
// AbstractAuthenticationProcessingFilter
@@ -97,6 +100,12 @@ public void removeRequest(HttpServletRequest currentRequest, HttpServletResponse
97100

98101
@Override
99102
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
103+
if (this.matchingRequestParameterName != null
104+
&& request.getParameter(this.matchingRequestParameterName) == null) {
105+
this.logger.trace(
106+
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
107+
return null;
108+
}
100109
SavedRequest saved = getRequest(request, response);
101110
if (saved == null) {
102111
this.logger.trace("No saved request");
@@ -162,4 +171,16 @@ public void setSessionAttrName(String sessionAttrName) {
162171
this.sessionAttrName = sessionAttrName;
163172
}
164173

174+
/**
175+
* Specify the name of a query parameter that is added to the URL that specifies the
176+
* request cache should be checked in
177+
* {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)}
178+
* @param matchingRequestParameterName the parameter name that must be in the request
179+
* for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check
180+
* the session.
181+
*/
182+
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
183+
this.matchingRequestParameterName = matchingRequestParameterName;
184+
}
185+
165186
}

web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
3535
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
3636
import org.springframework.util.Assert;
37+
import org.springframework.util.MultiValueMap;
3738
import org.springframework.web.server.ServerWebExchange;
3839
import org.springframework.web.server.WebSession;
40+
import org.springframework.web.util.UriComponentsBuilder;
3941

4042
/**
4143
* An implementation of {@link ServerRequestCache} that saves the
@@ -57,6 +59,8 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
5759

5860
private ServerWebExchangeMatcher saveRequestMatcher = createDefaultRequestMacher();
5961

62+
private String matchingRequestParameterName;
63+
6064
/**
6165
* Sets the matcher to determine if the request should be saved. The default is to
6266
* match on any GET request.
@@ -81,19 +85,53 @@ public Mono<Void> saveRequest(ServerWebExchange exchange) {
8185
public Mono<URI> getRedirectUri(ServerWebExchange exchange) {
8286
return exchange.getSession()
8387
.flatMap((session) -> Mono.justOrEmpty(session.<String>getAttribute(this.sessionAttrName)))
84-
.map(URI::create);
88+
.map(this::createRedirectUri);
8589
}
8690

8791
@Override
8892
public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) {
93+
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
94+
if (this.matchingRequestParameterName != null && !queryParams.containsKey(this.matchingRequestParameterName)) {
95+
this.logger.trace(
96+
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
97+
return Mono.empty();
98+
}
99+
ServerHttpRequest request = stripMatchingRequestParameterName(exchange.getRequest());
89100
return exchange.getSession().map(WebSession::getAttributes).filter((attributes) -> {
90-
String requestPath = pathInApplication(exchange.getRequest());
101+
String requestPath = pathInApplication(request);
91102
boolean removed = attributes.remove(this.sessionAttrName, requestPath);
92103
if (removed) {
93104
logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath));
94105
}
95106
return removed;
96-
}).map((attributes) -> exchange.getRequest());
107+
}).map((attributes) -> request);
108+
}
109+
110+
/**
111+
* Specify the name of a query parameter that is added to the URL in
112+
* {@link #getRedirectUri(ServerWebExchange)} and is required for
113+
* {@link #removeMatchingRequest(ServerWebExchange)} to look up the
114+
* {@link ServerHttpRequest}.
115+
* @param matchingRequestParameterName the parameter name that must be in the request
116+
* for {@link #removeMatchingRequest(ServerWebExchange)} to check the session.
117+
*/
118+
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
119+
this.matchingRequestParameterName = matchingRequestParameterName;
120+
}
121+
122+
private ServerHttpRequest stripMatchingRequestParameterName(ServerHttpRequest request) {
123+
if (this.matchingRequestParameterName == null) {
124+
return request;
125+
}
126+
// @formatter:off
127+
URI uri = UriComponentsBuilder.fromUri(request.getURI())
128+
.replaceQueryParam(this.matchingRequestParameterName)
129+
.build()
130+
.toUri();
131+
return request.mutate()
132+
.uri(uri)
133+
.build();
134+
// @formatter:on
97135
}
98136

99137
private static String pathInApplication(ServerHttpRequest request) {
@@ -102,6 +140,18 @@ private static String pathInApplication(ServerHttpRequest request) {
102140
return path + ((query != null) ? "?" + query : "");
103141
}
104142

143+
private URI createRedirectUri(String uri) {
144+
if (this.matchingRequestParameterName == null) {
145+
return URI.create(uri);
146+
}
147+
// @formatter:off
148+
return UriComponentsBuilder.fromUriString(uri)
149+
.queryParam(this.matchingRequestParameterName)
150+
.build()
151+
.toUri();
152+
// @formatter:on
153+
}
154+
105155
private static ServerWebExchangeMatcher createDefaultRequestMacher() {
106156
ServerWebExchangeMatcher get = ServerWebExchangeMatchers.pathMatchers(HttpMethod.GET, "/**");
107157
ServerWebExchangeMatcher notFavicon = new NegatedServerWebExchangeMatcher(
@@ -111,4 +161,17 @@ private static ServerWebExchangeMatcher createDefaultRequestMacher() {
111161
return new AndServerWebExchangeMatcher(get, notFavicon, html);
112162
}
113163

164+
private static String createQueryString(String queryString, String matchingRequestParameterName) {
165+
if (matchingRequestParameterName == null) {
166+
return queryString;
167+
}
168+
if (queryString == null || queryString.length() == 0) {
169+
return matchingRequestParameterName;
170+
}
171+
if (queryString.endsWith("&")) {
172+
return queryString + matchingRequestParameterName;
173+
}
174+
return queryString + "&" + matchingRequestParameterName;
175+
}
176+
114177
}

web/src/test/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixinTests.java

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,42 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests {
5656
// @formatter:on
5757
// @formatter:off
5858
private static final String REQUEST_JSON = "{" +
59-
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
60-
+ "\"cookies\": " + COOKIES_JSON + ","
61-
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
62-
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
63-
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
64-
+ "\"contextPath\": \"\", "
65-
+ "\"method\": \"\", "
66-
+ "\"pathInfo\": null, "
67-
+ "\"queryString\": null, "
68-
+ "\"requestURI\": \"\", "
69-
+ "\"requestURL\": \"http://localhost\", "
70-
+ "\"scheme\": \"http\", "
71-
+ "\"serverName\": \"localhost\", "
72-
+ "\"servletPath\": \"\", "
73-
+ "\"serverPort\": 80"
74-
+ "}";
59+
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
60+
+ "\"cookies\": " + COOKIES_JSON + ","
61+
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
62+
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
63+
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
64+
+ "\"contextPath\": \"\", "
65+
+ "\"method\": \"\", "
66+
+ "\"pathInfo\": null, "
67+
+ "\"queryString\": null, "
68+
+ "\"requestURI\": \"\", "
69+
+ "\"requestURL\": \"http://localhost\", "
70+
+ "\"scheme\": \"http\", "
71+
+ "\"serverName\": \"localhost\", "
72+
+ "\"servletPath\": \"\", "
73+
+ "\"serverPort\": 80"
74+
+ "}";
75+
// @formatter:on
76+
// @formatter:off
77+
private static final String REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON = "{" +
78+
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
79+
+ "\"cookies\": " + COOKIES_JSON + ","
80+
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
81+
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
82+
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
83+
+ "\"contextPath\": \"\", "
84+
+ "\"method\": \"\", "
85+
+ "\"pathInfo\": null, "
86+
+ "\"queryString\": null, "
87+
+ "\"requestURI\": \"\", "
88+
+ "\"requestURL\": \"http://localhost\", "
89+
+ "\"scheme\": \"http\", "
90+
+ "\"serverName\": \"localhost\", "
91+
+ "\"servletPath\": \"\", "
92+
+ "\"serverPort\": 80, "
93+
+ "\"matchingRequestParameterName\": \"success\""
94+
+ "}";
7595
// @formatter:on
7696
@Test
7797
public void matchRequestBuildWithConstructorAndBuilder() {
@@ -126,4 +146,17 @@ public void deserializeDefaultSavedRequest() throws IOException {
126146
assertThat(request.getHeaderValues("x-auth-token")).hasSize(1).contains("12");
127147
}
128148

149+
@Test
150+
public void deserializeWhenMatchingRequestParameterNameThenRedirectUrlContainsParam() throws IOException {
151+
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper
152+
.readValue(REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON, Object.class);
153+
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost?success");
154+
}
155+
156+
@Test
157+
public void deserializeWhenNullMatchingRequestParameterNameThenRedirectUrlDoesNotContainParam() throws IOException {
158+
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper.readValue(REQUEST_JSON, Object.class);
159+
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost");
160+
}
161+
129162
}

0 commit comments

Comments
 (0)