Skip to content

Commit ac78258

Browse files
committed
ServerOAuth2AuthorizedClientExchangeFilterFunction defaultOAuth2AuthorizedClient
Defaults to use the OAuth2AuthenticationToken to resolve the authorized client Issue: gh-4921
1 parent 158b8aa commit ac78258

File tree

2 files changed

+90
-6
lines changed

2 files changed

+90
-6
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.security.core.context.SecurityContext;
2828
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2929
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
30+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
3031
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
3132
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
3233
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
@@ -86,6 +87,8 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
8687

8788
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
8889

90+
private boolean defaultOAuth2AuthorizedClient;
91+
8992
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
9093
new WebClientReactiveClientCredentialsTokenResponseClient();
9194

@@ -174,6 +177,17 @@ public static Consumer<Map<String, Object>> clientRegistrationId(String clientRe
174177
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
175178
}
176179

180+
/**
181+
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
182+
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
183+
* resolved from the current Authentication.
184+
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
185+
* Default is false.
186+
*/
187+
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
188+
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
189+
}
190+
177191
/**
178192
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
179193
* client_credentials grant.
@@ -216,14 +230,25 @@ private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(Client
216230
if (this.authorizedClientRepository == null) {
217231
return Mono.empty();
218232
}
219-
String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
220-
if (clientRegistrationId == null) {
221-
return Mono.empty();
222-
}
233+
223234
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
224235
return currentAuthentication()
225-
.flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal)
226-
);
236+
.flatMap(principal -> clientRegistrationId(request, principal)
237+
.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
238+
);
239+
}
240+
241+
private Mono<String> clientRegistrationId(ClientRequest request, Authentication authentication) {
242+
return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME))
243+
.cast(String.class)
244+
.switchIfEmpty(clientRegistrationId(authentication));
245+
}
246+
247+
private Mono<String> clientRegistrationId(Authentication authentication) {
248+
return Mono.justOrEmpty(authentication)
249+
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
250+
.cast(OAuth2AuthenticationToken.class)
251+
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
227252
}
228253

229254
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,19 @@
3434
import org.springframework.http.server.reactive.ServerHttpRequest;
3535
import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
3636
import org.springframework.security.authentication.TestingAuthenticationToken;
37+
import org.springframework.security.core.authority.AuthorityUtils;
3738
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
3839
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
40+
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
3941
import org.springframework.security.oauth2.client.registration.ClientRegistration;
4042
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
4143
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
4244
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
4345
import org.springframework.security.oauth2.core.OAuth2AccessToken;
4446
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
4547
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
48+
import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
49+
import org.springframework.security.oauth2.core.user.OAuth2User;
4650
import org.springframework.web.reactive.function.BodyInserter;
4751
import org.springframework.web.reactive.function.client.ClientRequest;
4852
import reactor.core.publisher.Mono;
@@ -51,6 +55,7 @@
5155
import java.time.Duration;
5256
import java.time.Instant;
5357
import java.util.ArrayList;
58+
import java.util.Collections;
5459
import java.util.HashMap;
5560
import java.util.List;
5661
import java.util.Map;
@@ -60,6 +65,7 @@
6065
import static org.mockito.ArgumentMatchers.any;
6166
import static org.mockito.ArgumentMatchers.eq;
6267
import static org.mockito.Mockito.verify;
68+
import static org.mockito.Mockito.verifyZeroInteractions;
6369
import static org.mockito.Mockito.when;
6470
import static org.springframework.http.HttpMethod.GET;
6571
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
@@ -293,6 +299,59 @@ public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
293299
assertThat(getBody(request0)).isEmpty();
294300
}
295301

302+
@Test
303+
public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() {
304+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
305+
this.function.setDefaultOAuth2AuthorizedClient(true);
306+
307+
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
308+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
309+
"principalName", this.accessToken, refreshToken);
310+
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
311+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
312+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
313+
.build();
314+
315+
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
316+
.singletonMap("user", "rob"), "user");
317+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id");
318+
this.function
319+
.filter(request, this.exchange)
320+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
321+
.block();
322+
323+
List<ClientRequest> requests = this.exchange.getRequests();
324+
assertThat(requests).hasSize(1);
325+
326+
ClientRequest request0 = requests.get(0);
327+
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
328+
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
329+
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
330+
assertThat(getBody(request0)).isEmpty();
331+
}
332+
333+
@Test
334+
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
335+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
336+
337+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
338+
.build();
339+
340+
OAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), Collections
341+
.singletonMap("user", "rob"), "user");
342+
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), "client-id");
343+
344+
this.function
345+
.filter(request, this.exchange)
346+
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
347+
.block();
348+
349+
List<ClientRequest> requests = this.exchange.getRequests();
350+
assertThat(requests).hasSize(1);
351+
352+
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
353+
}
354+
296355
private static String getBody(ClientRequest request) {
297356
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
298357
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)