Skip to content

Commit 158b8aa

Browse files
committed
ServerOAuth2AuthorizedClientExchangeFilterFunction clientRegistrationId
Issue: gh-4921
1 parent 28537fa commit 158b8aa

File tree

3 files changed

+79
-18
lines changed

3 files changed

+79
-18
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@
2727
import org.springframework.security.core.context.SecurityContext;
2828
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2929
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
30-
import org.springframework.security.oauth2.client.OAuth2ClientException;
30+
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
31+
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
32+
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
3133
import org.springframework.security.oauth2.client.registration.ClientRegistration;
34+
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
3235
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
3336
import org.springframework.security.oauth2.core.AuthorizationGrantType;
34-
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3537
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
38+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3639
import org.springframework.util.Assert;
3740
import org.springframework.web.reactive.function.BodyInserters;
3841
import org.springframework.web.reactive.function.client.ClientRequest;
@@ -75,18 +78,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
7578
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
7679
*/
7780
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
78-
public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
81+
82+
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
7983
AuthorityUtils.createAuthorityList("ROLE_USER"));
8084

8185
private Clock clock = Clock.systemUTC();
8286

8387
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
8488

89+
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
90+
new WebClientReactiveClientCredentialsTokenResponseClient();
91+
92+
private ReactiveClientRegistrationRepository clientRegistrationRepository;
93+
8594
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
8695

8796
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
8897

89-
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
98+
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
99+
this.clientRegistrationRepository = clientRegistrationRepository;
90100
this.authorizedClientRepository = authorizedClientRepository;
91101
}
92102

@@ -164,6 +174,17 @@ public static Consumer<Map<String, Object>> clientRegistrationId(String clientRe
164174
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
165175
}
166176

177+
/**
178+
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
179+
* client_credentials grant.
180+
* @param clientCredentialsTokenResponseClient the client to use
181+
*/
182+
public void setClientCredentialsTokenResponseClient(
183+
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
184+
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
185+
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
186+
}
187+
167188
/**
168189
* An access token will be considered expired by comparing its expiration to now +
169190
* this skewed Duration. The default is 1 minute.
@@ -208,7 +229,39 @@ private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(Client
208229
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
209230
ServerWebExchange exchange, Authentication principal) {
210231
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
211-
.switchIfEmpty(Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)));
232+
.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
233+
}
234+
235+
private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
236+
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
237+
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
238+
.flatMap(clientRegistration -> {
239+
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
240+
return clientCredentials(clientRegistration, exchange);
241+
}
242+
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
243+
});
244+
}
245+
246+
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
247+
ClientRegistration clientRegistration, ServerWebExchange exchange) {
248+
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
249+
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
250+
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
251+
}
252+
253+
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
254+
return currentAuthentication()
255+
.flatMap(principal -> {
256+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
257+
clientRegistration, (principal != null ?
258+
principal.getName() :
259+
"anonymousUser"),
260+
tokenResponse.getAccessToken());
261+
262+
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
263+
.thenReturn(authorizedClient);
264+
});
212265
}
213266

214267
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
3838
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
3939
import org.springframework.security.oauth2.client.registration.ClientRegistration;
40+
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
4041
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
4142
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
4243
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -71,7 +72,10 @@
7172
@RunWith(MockitoJUnitRunner.class)
7273
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
7374
@Mock
74-
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
75+
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
76+
77+
@Mock
78+
private ReactiveClientRegistrationRepository clientRegistrationRepository;
7579

7680
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
7781

@@ -125,7 +129,7 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
125129

126130
@Test
127131
public void filterWhenRefreshRequiredThenRefresh() {
128-
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
132+
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
129133
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
130134
.tokenType(OAuth2AccessToken.TokenType.BEARER)
131135
.expiresIn(3600)
@@ -140,7 +144,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
140144
this.accessToken.getTokenValue(),
141145
issuedAt,
142146
accessTokenExpiresAt);
143-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
147+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
144148

145149
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
146150
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -154,7 +158,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
154158
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
155159
.block();
156160

157-
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
161+
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
158162

159163
List<ClientRequest> requests = this.exchange.getRequests();
160164
assertThat(requests).hasSize(2);
@@ -174,7 +178,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
174178

175179
@Test
176180
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
177-
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
181+
when(this.authorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
178182
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
179183
.tokenType(OAuth2AccessToken.TokenType.BEARER)
180184
.expiresIn(3600)
@@ -189,7 +193,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
189193
this.accessToken.getTokenValue(),
190194
issuedAt,
191195
accessTokenExpiresAt);
192-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
196+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
193197

194198
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
195199
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -201,7 +205,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
201205
this.function.filter(request, this.exchange)
202206
.block();
203207

204-
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
208+
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
205209

206210
List<ClientRequest> requests = this.exchange.getRequests();
207211
assertThat(requests).hasSize(2);
@@ -221,7 +225,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
221225

222226
@Test
223227
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
224-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
228+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
225229

226230
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
227231
"principalName", this.accessToken);
@@ -243,7 +247,7 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
243247

244248
@Test
245249
public void filterWhenNotExpiredThenShouldRefreshFalse() {
246-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
250+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
247251

248252
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
249253
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -266,12 +270,13 @@ public void filterWhenNotExpiredThenShouldRefreshFalse() {
266270

267271
@Test
268272
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
269-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
273+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
270274

271275
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
272276
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
273277
"principalName", this.accessToken, refreshToken);
274-
when(this.auth2AuthorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
278+
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
279+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
275280
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
276281
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
277282
.build();

samples/boot/authcodegrant-webflux/src/main/java/sample/config/WebClientConfig.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import org.springframework.context.annotation.Bean;
2020
import org.springframework.context.annotation.Configuration;
21+
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
2122
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
23+
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
2224
import org.springframework.web.reactive.function.client.WebClient;
2325

2426
/**
@@ -29,9 +31,10 @@
2931
public class WebClientConfig {
3032

3133
@Bean
32-
WebClient webClient() {
34+
WebClient webClient(ReactiveClientRegistrationRepository clientRegistrationRepository,
35+
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
3336
return WebClient.builder()
34-
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction())
37+
.filter(new ServerOAuth2AuthorizedClientExchangeFilterFunction(clientRegistrationRepository, authorizedClientRepository))
3538
.build();
3639
}
3740
}

0 commit comments

Comments
 (0)