|
29 | 29 | import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
30 | 30 | import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
31 | 31 | import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
| 32 | +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; |
| 33 | +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; |
| 34 | +import org.springframework.security.oauth2.core.OAuth2Error; |
32 | 35 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
33 | 36 | import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
34 | 37 | import org.springframework.util.CollectionUtils;
|
|
41 | 44 | import java.util.Map;
|
42 | 45 |
|
43 | 46 | import static org.assertj.core.api.Assertions.assertThatCode;
|
| 47 | +import static org.assertj.core.api.Assertions.assertThatThrownBy; |
44 | 48 | import static org.mockito.ArgumentMatchers.any;
|
45 | 49 | import static org.mockito.Mockito.times;
|
46 | 50 | import static org.mockito.Mockito.verify;
|
@@ -233,6 +237,65 @@ public void filterWhenAuthorizationRequestRedirectUriParametersNotMatchThenNotPr
|
233 | 237 | verifyNoInteractions(this.authenticationManager);
|
234 | 238 | }
|
235 | 239 |
|
| 240 | + // gh-8609 |
| 241 | + @Test |
| 242 | + public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { |
| 243 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 244 | + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); |
| 245 | + |
| 246 | + MockServerHttpRequest authorizationRequest = |
| 247 | + createAuthorizationRequest("/authorization/callback"); |
| 248 | + OAuth2AuthorizationRequest oauth2AuthorizationRequest = |
| 249 | + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); |
| 250 | + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) |
| 251 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 252 | + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) |
| 253 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 254 | + |
| 255 | + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); |
| 256 | + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); |
| 257 | + DefaultWebFilterChain chain = new DefaultWebFilterChain( |
| 258 | + e -> e.getResponse().setComplete(), Collections.emptyList()); |
| 259 | + |
| 260 | + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) |
| 261 | + .isInstanceOf(OAuth2AuthenticationException.class) |
| 262 | + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) |
| 263 | + .extracting("errorCode") |
| 264 | + .isEqualTo("client_registration_not_found"); |
| 265 | + verifyNoInteractions(this.authenticationManager); |
| 266 | + } |
| 267 | + |
| 268 | + // gh-8609 |
| 269 | + @Test |
| 270 | + public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { |
| 271 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 272 | + when(this.clientRegistrationRepository.findByRegistrationId(any())) |
| 273 | + .thenReturn(Mono.just(clientRegistration)); |
| 274 | + |
| 275 | + MockServerHttpRequest authorizationRequest = |
| 276 | + createAuthorizationRequest("/authorization/callback"); |
| 277 | + OAuth2AuthorizationRequest oauth2AuthorizationRequest = |
| 278 | + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); |
| 279 | + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) |
| 280 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 281 | + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) |
| 282 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 283 | + |
| 284 | + when(this.authenticationManager.authenticate(any())) |
| 285 | + .thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error")))); |
| 286 | + |
| 287 | + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); |
| 288 | + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); |
| 289 | + DefaultWebFilterChain chain = new DefaultWebFilterChain( |
| 290 | + e -> e.getResponse().setComplete(), Collections.emptyList()); |
| 291 | + |
| 292 | + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) |
| 293 | + .isInstanceOf(OAuth2AuthenticationException.class) |
| 294 | + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) |
| 295 | + .extracting("errorCode") |
| 296 | + .isEqualTo("authorization_error"); |
| 297 | + } |
| 298 | + |
236 | 299 | private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
|
237 | 300 | MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
|
238 | 301 | Map<String, Object> attributes = new HashMap<>();
|
|
0 commit comments