|
29 | 29 | import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
30 | 30 | import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
31 | 31 | import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
| 32 | +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; |
| 33 | +import org.springframework.security.oauth2.core.OAuth2AuthorizationException; |
| 34 | +import org.springframework.security.oauth2.core.OAuth2Error; |
32 | 35 | import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
33 | 36 | import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
34 | 37 | import org.springframework.security.web.server.savedrequest.ServerRequestCache;
|
|
45 | 48 |
|
46 | 49 | import static org.assertj.core.api.Assertions.assertThat;
|
47 | 50 | import static org.assertj.core.api.Assertions.assertThatCode;
|
| 51 | +import static org.assertj.core.api.Assertions.assertThatThrownBy; |
48 | 52 | import static org.mockito.ArgumentMatchers.any;
|
49 | 53 | import static org.mockito.Mockito.mock;
|
50 | 54 | import static org.mockito.Mockito.times;
|
@@ -279,6 +283,65 @@ public void filterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestC
|
279 | 283 | assertThat(exchange.getResponse().getHeaders().getLocation().toString()).isEqualTo("/saved-request");
|
280 | 284 | }
|
281 | 285 |
|
| 286 | + // gh-8609 |
| 287 | + @Test |
| 288 | + public void filterWhenAuthenticationConverterThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { |
| 289 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 290 | + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.empty()); |
| 291 | + |
| 292 | + MockServerHttpRequest authorizationRequest = |
| 293 | + createAuthorizationRequest("/authorization/callback"); |
| 294 | + OAuth2AuthorizationRequest oauth2AuthorizationRequest = |
| 295 | + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); |
| 296 | + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) |
| 297 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 298 | + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) |
| 299 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 300 | + |
| 301 | + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); |
| 302 | + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); |
| 303 | + DefaultWebFilterChain chain = new DefaultWebFilterChain( |
| 304 | + e -> e.getResponse().setComplete(), Collections.emptyList()); |
| 305 | + |
| 306 | + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) |
| 307 | + .isInstanceOf(OAuth2AuthenticationException.class) |
| 308 | + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) |
| 309 | + .extracting("errorCode") |
| 310 | + .isEqualTo("client_registration_not_found"); |
| 311 | + verifyNoInteractions(this.authenticationManager); |
| 312 | + } |
| 313 | + |
| 314 | + // gh-8609 |
| 315 | + @Test |
| 316 | + public void filterWhenAuthenticationManagerThrowsOAuth2AuthorizationExceptionThenMappedToOAuth2AuthenticationException() { |
| 317 | + ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); |
| 318 | + when(this.clientRegistrationRepository.findByRegistrationId(any())) |
| 319 | + .thenReturn(Mono.just(clientRegistration)); |
| 320 | + |
| 321 | + MockServerHttpRequest authorizationRequest = |
| 322 | + createAuthorizationRequest("/authorization/callback"); |
| 323 | + OAuth2AuthorizationRequest oauth2AuthorizationRequest = |
| 324 | + createOAuth2AuthorizationRequest(authorizationRequest, clientRegistration); |
| 325 | + when(this.authorizationRequestRepository.loadAuthorizationRequest(any())) |
| 326 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 327 | + when(this.authorizationRequestRepository.removeAuthorizationRequest(any())) |
| 328 | + .thenReturn(Mono.just(oauth2AuthorizationRequest)); |
| 329 | + |
| 330 | + when(this.authenticationManager.authenticate(any())) |
| 331 | + .thenReturn(Mono.error(new OAuth2AuthorizationException(new OAuth2Error("authorization_error")))); |
| 332 | + |
| 333 | + MockServerHttpRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); |
| 334 | + MockServerWebExchange exchange = MockServerWebExchange.from(authorizationResponse); |
| 335 | + DefaultWebFilterChain chain = new DefaultWebFilterChain( |
| 336 | + e -> e.getResponse().setComplete(), Collections.emptyList()); |
| 337 | + |
| 338 | + assertThatThrownBy(() -> this.filter.filter(exchange, chain).block()) |
| 339 | + .isInstanceOf(OAuth2AuthenticationException.class) |
| 340 | + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) |
| 341 | + .extracting("errorCode") |
| 342 | + .isEqualTo("authorization_error"); |
| 343 | + } |
| 344 | + |
282 | 345 | private static OAuth2AuthorizationRequest createOAuth2AuthorizationRequest(
|
283 | 346 | MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
|
284 | 347 | Map<String, Object> attributes = new HashMap<>();
|
|
0 commit comments