Skip to content

Commit 4c902bb

Browse files
committed
OAuth2AuthorizationCodeGrantWebFilter should handle OAuth2AuthorizationException
Fixes gh-8609
1 parent bb0fac6 commit 4c902bb

File tree

6 files changed

+110
-59
lines changed

6 files changed

+110
-59
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
import org.springframework.security.core.AuthenticationException;
2121
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
2222
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
23+
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
24+
import org.springframework.security.oauth2.core.OAuth2Error;
2325
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
26+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
27+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
2428
import org.springframework.util.Assert;
2529

2630
/**
@@ -40,6 +44,7 @@
4044
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a>
4145
*/
4246
public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
47+
private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
4348
private final OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
4449

4550
/**
@@ -59,8 +64,18 @@ public Authentication authenticate(Authentication authentication) throws Authent
5964
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
6065
(OAuth2AuthorizationCodeAuthenticationToken) authentication;
6166

62-
OAuth2AuthorizationExchangeValidator.validate(
63-
authorizationCodeAuthentication.getAuthorizationExchange());
67+
OAuth2AuthorizationResponse authorizationResponse = authorizationCodeAuthentication
68+
.getAuthorizationExchange().getAuthorizationResponse();
69+
if (authorizationResponse.statusError()) {
70+
throw new OAuth2AuthorizationException(authorizationResponse.getError());
71+
}
72+
73+
OAuth2AuthorizationRequest authorizationRequest = authorizationCodeAuthentication
74+
.getAuthorizationExchange().getAuthorizationRequest();
75+
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
76+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
77+
throw new OAuth2AuthorizationException(oauth2Error);
78+
}
6479

6580
OAuth2AccessTokenResponse accessTokenResponse =
6681
this.accessTokenResponseClient.getTokenResponse(

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationCodeReactiveAuthenticationManager.java

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2018 the original author or authors.
2+
* Copyright 2002-2020 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,9 +22,13 @@
2222
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2323
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
2424
import org.springframework.security.oauth2.core.OAuth2AccessToken;
25+
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
26+
import org.springframework.security.oauth2.core.OAuth2Error;
2527
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
2628
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
2729
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
30+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
31+
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
2832
import org.springframework.security.oauth2.core.user.OAuth2User;
2933
import org.springframework.util.Assert;
3034
import reactor.core.publisher.Mono;
@@ -55,8 +59,8 @@
5559
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
5660
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.4">Section 4.1.4 Access Token Response</a>
5761
*/
58-
public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements
59-
ReactiveAuthenticationManager {
62+
public class OAuth2AuthorizationCodeReactiveAuthenticationManager implements ReactiveAuthenticationManager {
63+
private static final String INVALID_STATE_PARAMETER_ERROR_CODE = "invalid_state_parameter";
6064
private final ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
6165

6266
public OAuth2AuthorizationCodeReactiveAuthenticationManager(
@@ -70,7 +74,16 @@ public Mono<Authentication> authenticate(Authentication authentication) {
7074
return Mono.defer(() -> {
7175
OAuth2AuthorizationCodeAuthenticationToken token = (OAuth2AuthorizationCodeAuthenticationToken) authentication;
7276

73-
OAuth2AuthorizationExchangeValidator.validate(token.getAuthorizationExchange());
77+
OAuth2AuthorizationResponse authorizationResponse = token.getAuthorizationExchange().getAuthorizationResponse();
78+
if (authorizationResponse.statusError()) {
79+
return Mono.error(new OAuth2AuthorizationException(authorizationResponse.getError()));
80+
}
81+
82+
OAuth2AuthorizationRequest authorizationRequest = token.getAuthorizationExchange().getAuthorizationRequest();
83+
if (!authorizationResponse.getState().equals(authorizationRequest.getState())) {
84+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_STATE_PARAMETER_ERROR_CODE);
85+
return Mono.error(new OAuth2AuthorizationException(oauth2Error));
86+
}
7487

7588
OAuth2AuthorizationCodeGrantRequest authzRequest = new OAuth2AuthorizationCodeGrantRequest(
7689
token.getClientRegistration(),

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/authentication/OAuth2AuthorizationExchangeValidator.java

Lines changed: 0 additions & 47 deletions
This file was deleted.

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilter.java

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
2828
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
2929
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
30+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
31+
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
3032
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3133
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
3234
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -201,15 +203,21 @@ private void updateDefaultAuthenticationSuccessHandler() {
201203
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
202204
return this.requiresAuthenticationMatcher.matches(exchange)
203205
.filter(ServerWebExchangeMatcher.MatchResult::isMatch)
204-
.flatMap(matchResult -> this.authenticationConverter.convert(exchange))
206+
.flatMap(matchResult ->
207+
this.authenticationConverter.convert(exchange)
208+
.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(
209+
e.getError(), e.getError().toString())))
205210
.switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
206-
.flatMap(token -> authenticate(exchange, chain, token));
211+
.flatMap(token -> authenticate(exchange, chain, token))
212+
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler
213+
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e));
207214
}
208215

209-
private Mono<Void> authenticate(ServerWebExchange exchange,
210-
WebFilterChain chain, Authentication token) {
216+
private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
211217
WebFilterExchange webFilterExchange = new WebFilterExchange(exchange, chain);
212218
return this.authenticationManager.authenticate(token)
219+
.onErrorMap(OAuth2AuthorizationException.class, e -> new OAuth2AuthenticationException(
220+
e.getError(), e.getError().toString()))
213221
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalStateException("No provider found for " + token.getClass()))))
214222
.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
215223
.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizationCodeAuthenticationTokenConverter.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import org.springframework.security.core.Authentication;
2020
import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
21-
import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
2221
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
2322
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
2423
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -33,7 +32,7 @@
3332
import reactor.core.publisher.Mono;
3433

3534
/**
36-
* Converts from a {@link ServerWebExchange} to an {@link OAuth2LoginAuthenticationToken} that can be authenticated. The
35+
* Converts from a {@link ServerWebExchange} to an {@link OAuth2AuthorizationCodeAuthenticationToken} that can be authenticated. The
3736
* converter does not validate any errors it only performs a conversion.
3837
* @author Rob Winch
3938
* @since 5.1

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/OAuth2AuthorizationCodeGrantWebFilterTests.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
import org.springframework.security.oauth2.client.registration.ClientRegistration;
3030
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
3131
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
32+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
33+
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
34+
import org.springframework.security.oauth2.core.OAuth2Error;
3235
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3336
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
3437
import org.springframework.security.web.server.savedrequest.ServerRequestCache;
@@ -45,6 +48,7 @@
4548

4649
import static org.assertj.core.api.Assertions.assertThat;
4750
import static org.assertj.core.api.Assertions.assertThatCode;
51+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4852
import static org.mockito.ArgumentMatchers.any;
4953
import static org.mockito.Mockito.mock;
5054
import static org.mockito.Mockito.times;
@@ -279,6 +283,65 @@ public void filterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestC
279283
assertThat(exchange.getResponse().getHeaders().getLocation().toString()).isEqualTo("/saved-request");
280284
}
281285

286+
// gh-8609
287+
@Test
288+
public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() {
289+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
290+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty());
291+
292+
MockServerHttpRequest authorizationRequest =
293+
createAuthorizationRequest("/authorization/callback");
294+
OAuth2AuthorizationRequest oauth2AuthorizationRequest =
295+
createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
296+
when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
297+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
298+
when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
299+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
300+
301+
MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
302+
MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
303+
DefaultWebFilterChain chain = new DefaultWebFilterChain(
304+
e -> e.getResponse().setComplete(), Collections.emptyList());
305+
306+
assertThatThrownBy(() -> this.filter.filter(exchange, chain).block())
307+
.isInstanceOf(OAuth2AuthenticationException.class)
308+
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
309+
.extracting("errorCode")
310+
.isEqualTo("client_registration_not_found");
311+
verifyNoInteractions(this.authenticationManager);
312+
}
313+
314+
// gh-8609
315+
@Test
316+
public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() {
317+
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
318+
when(this.clientRegistrationRepository.findByRegistrationId(any()))
319+
.thenReturn(Mono.just(clientRegistration));
320+
321+
MockServerHttpRequest authorizationRequest =
322+
createAuthorizationRequest("/authorization/callback");
323+
OAuth2AuthorizationRequest oauth2AuthorizationRequest =
324+
createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration);
325+
when(this.authorizationRequestRepository.loadAuthorizationRequest(any()))
326+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
327+
when(this.authorizationRequestRepository.removeAuthorizationRequest(any()))
328+
.thenReturn(Mono.just(oauth2AuthorizationRequest));
329+
330+
when(this.authenticationManager.authenticate(any()))
331+
.thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error"))));
332+
333+
MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest);
334+
MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse);
335+
DefaultWebFilterChain chain = new DefaultWebFilterChain(
336+
e -> e.getResponse().setComplete(), Collections.emptyList());
337+
338+
assertThatThrownBy(() -> this.filter.filter(exchange, chain).block())
339+
.isInstanceOf(OAuth2AuthenticationException.class)
340+
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
341+
.extracting("errorCode")
342+
.isEqualTo("authorization_error");
343+
}
344+
282345
private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
283346
MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
284347
Map<String, Object> attributes = new HashMap<>();

0 commit comments

Comments
 (0)