Skip to content

Commit 89f2874

Browse files
committed
ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId
You can now provide the clientRegistrationId and ServerOAuth2AuthorizedClientExchangeFilterFunction will look up the authorized client automatically. Issue: gh-4921
1 parent 5bcbb1c commit 89f2874

File tree

2 files changed

+87
-9
lines changed

2 files changed

+87
-9
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,19 @@
1919
import org.springframework.http.HttpHeaders;
2020
import org.springframework.http.HttpMethod;
2121
import org.springframework.http.MediaType;
22+
import org.springframework.security.authentication.AnonymousAuthenticationToken;
2223
import org.springframework.security.core.Authentication;
2324
import org.springframework.security.core.GrantedAuthority;
25+
import org.springframework.security.core.authority.AuthorityUtils;
2426
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
2527
import org.springframework.security.core.context.SecurityContext;
28+
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2629
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
30+
import org.springframework.security.oauth2.client.OAuth2ClientException;
2731
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2832
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
2933
import org.springframework.security.oauth2.core.AuthorizationGrantType;
34+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3035
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
3136
import org.springframework.util.Assert;
3237
import org.springframework.web.reactive.function.BodyInserters;
@@ -61,10 +66,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
6166
*/
6267
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
6368

69+
/**
70+
* The client request attribute name used to locate the {@link ClientRegistration#getRegistrationId()}
71+
*/
72+
private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
73+
6474
/**
6575
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
6676
*/
6777
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
78+
public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
79+
AuthorityUtils.createAuthorityList("ROLE_USER"));
6880

6981
private Clock clock = Clock.systemUTC();
7082

@@ -74,8 +86,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
7486

7587
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
7688

77-
public ServerOAuth2AuthorizedClientExchangeFilterFunction(
78-
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
89+
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
7990
this.authorizedClientRepository = authorizedClientRepository;
8091
}
8192

@@ -141,6 +152,18 @@ public static Consumer<Map<String, Object>> serverWebExchange(ServerWebExchange
141152
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
142153
}
143154

155+
/**
156+
* Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
157+
* be used to look up the {@link OAuth2AuthorizedClient}.
158+
*
159+
* @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
160+
* be used to look up the {@link OAuth2AuthorizedClient}.
161+
* @return the {@link Consumer} to populate the attributes
162+
*/
163+
public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
164+
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
165+
}
166+
144167
/**
145168
* An access token will be considered expired by comparing its expiration to now +
146169
* this skewed Duration. The default is 1 minute.
@@ -153,17 +176,42 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
153176

154177
@Override
155178
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
156-
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
157-
.map(OAuth2AuthorizedClient.class::cast);
158179
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
159-
return Mono.justOrEmpty(attribute)
160-
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
180+
return authorizedClient(request)
181+
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
161182
.map(authorizedClient -> bearer(request, authorizedClient))
162183
.flatMap(next::exchange)
163184
.switchIfEmpty(next.exchange(request));
164185
}
165186

166-
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
187+
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
188+
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
189+
.map(OAuth2AuthorizedClient.class::cast);
190+
return Mono.justOrEmpty(attribute)
191+
.switchIfEmpty(findAuthorizedClientByRegistrationId(request));
192+
}
193+
194+
private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(ClientRequest request) {
195+
if (this.authorizedClientRepository == null) {
196+
return Mono.empty();
197+
}
198+
String clientRegistrationId = (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
199+
if (clientRegistrationId == null) {
200+
return Mono.empty();
201+
}
202+
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
203+
return currentAuthentication()
204+
.flatMap(principal -> loadAuthorizedClient(clientRegistrationId, exchange, principal)
205+
);
206+
}
207+
208+
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
209+
ServerWebExchange exchange, Authentication principal) {
210+
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
211+
.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
212+
}
213+
214+
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
167215
if (shouldRefresh(authorizedClient)) {
168216
return refreshAuthorizedClient(next, authorizedClient, exchange);
169217
}
@@ -184,13 +232,18 @@ private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction ne
184232
return next.exchange(request)
185233
.flatMap(response -> response.body(oauth2AccessTokenResponse()))
186234
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
187-
.flatMap(result -> ReactiveSecurityContextHolder.getContext()
188-
.map(SecurityContext::getAuthentication)
235+
.flatMap(result -> currentAuthentication()
189236
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
190237
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
191238
.thenReturn(result));
192239
}
193240

241+
private Mono<Authentication> currentAuthentication() {
242+
return ReactiveSecurityContextHolder.getContext()
243+
.map(SecurityContext::getAuthentication)
244+
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
245+
}
246+
194247
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
195248
if (this.authorizedClientRepository == null) {
196249
return false;

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import static org.mockito.Mockito.verify;
6262
import static org.mockito.Mockito.when;
6363
import static org.springframework.http.HttpMethod.GET;
64+
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
6465
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
6566

6667
/**
@@ -263,6 +264,30 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() {
263264
assertThat(getBody(request0)).isEmpty();
264265
}
265266

267+
@Test
268+
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
269+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
270+
271+
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
272+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
273+
"principalName", this.accessToken, refreshToken);
274+
when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
275+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
276+
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
277+
.build();
278+
279+
this.function.filter(request, this.exchange).block();
280+
281+
List<ClientRequest> requests = this.exchange.getRequests();
282+
assertThat(requests).hasSize(1);
283+
284+
ClientRequest request0 = requests.get(0);
285+
assertThat(request0.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
286+
assertThat(request0.url().toASCIIString()).isEqualTo("https://example.com");
287+
assertThat(request0.method()).isEqualTo(HttpMethod.GET);
288+
assertThat(getBody(request0)).isEmpty();
289+
}
290+
266291
private static String getBody(ClientRequest request) {
267292
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
268293
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)