|
34 | 34 | import org.springframework.http.server.reactive.ServerHttpRequest;
|
35 | 35 | import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
|
36 | 36 | import org.springframework.security.authentication.TestingAuthenticationToken;
|
| 37 | +import org.springframework.security.core.authority.AuthorityUtils; |
37 | 38 | import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
38 | 39 | import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
| 40 | +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; |
39 | 41 | import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
40 | 42 | import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
41 | 43 | import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
42 | 44 | import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
43 | 45 | import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
44 | 46 | import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
45 | 47 | import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
| 48 | +import org.springframework.security.oauth2.core.user.DefaultOAuth2User; |
| 49 | +import org.springframework.security.oauth2.core.user.OAuth2User; |
46 | 50 | import org.springframework.web.reactive.function.BodyInserter;
|
47 | 51 | import org.springframework.web.reactive.function.client.ClientRequest;
|
48 | 52 | import reactor.core.publisher.Mono;
|
|
51 | 55 | import java.time.Duration;
|
52 | 56 | import java.time.Instant;
|
53 | 57 | import java.util.ArrayList;
|
| 58 | +import java.util.Collections; |
54 | 59 | import java.util.HashMap;
|
55 | 60 | import java.util.List;
|
56 | 61 | import java.util.Map;
|
|
60 | 65 | import static org.mockito.ArgumentMatchers.any;
|
61 | 66 | import static org.mockito.ArgumentMatchers.eq;
|
62 | 67 | import static org.mockito.Mockito.verify;
|
| 68 | +import static org.mockito.Mockito.verifyZeroInteractions; |
63 | 69 | import static org.mockito.Mockito.when;
|
64 | 70 | import static org.springframework.http.HttpMethod.GET;
|
65 | 71 | import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
|
@@ -293,6 +299,59 @@ public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
|
293 | 299 | assertThat(getBody(request0)).isEmpty();
|
294 | 300 | }
|
295 | 301 |
|
| 302 | + @Test |
| 303 | + public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() { |
| 304 | + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); |
| 305 | + this.function.setDefaultOAuth2AuthorizedClient(true); |
| 306 | + |
| 307 | + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); |
| 308 | + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, |
| 309 | + "principalName", this.accessToken, refreshToken); |
| 310 | + when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient)); |
| 311 | + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration)); |
| 312 | + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) |
| 313 | + .build(); |
| 314 | + |
| 315 | + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections |
| 316 | + .singletonMap("user", "rob"), "user"); |
| 317 | + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id"); |
| 318 | + this.function |
| 319 | + .filter(request, this.exchange) |
| 320 | + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) |
| 321 | + .block(); |
| 322 | + |
| 323 | + List<ClientRequest> requests = this.exchange.getRequests(); |
| 324 | + assertThat(requests).hasSize(1); |
| 325 | + |
| 326 | + ClientRequest request0 = requests.get(0); |
| 327 | + assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0"); |
| 328 | + assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com"); |
| 329 | + assertThat(request0.method()).isEqualTo(HttpMethod.GET); |
| 330 | + assertThat(getBody(request0)).isEmpty(); |
| 331 | + } |
| 332 | + |
| 333 | + @Test |
| 334 | + public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() { |
| 335 | + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); |
| 336 | + |
| 337 | + ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) |
| 338 | + .build(); |
| 339 | + |
| 340 | + OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections |
| 341 | + .singletonMap("user", "rob"), "user"); |
| 342 | + OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id"); |
| 343 | + |
| 344 | + this.function |
| 345 | + .filter(request, this.exchange) |
| 346 | + .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) |
| 347 | + .block(); |
| 348 | + |
| 349 | + List<ClientRequest> requests = this.exchange.getRequests(); |
| 350 | + assertThat(requests).hasSize(1); |
| 351 | + |
| 352 | + verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository); |
| 353 | + } |
| 354 | + |
296 | 355 | private static String getBody(ClientRequest request) {
|
297 | 356 | final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
|
298 | 357 | messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
|
|
0 commit comments