Skip to content

Commit f5ad4ba

Browse files
committed
ServletOAuth2AuthorizedClientExchangeFilterFunction support client_credentials
Fixes: gh-5639
1 parent 2d497c7 commit f5ad4ba

File tree

3 files changed

+139
-18
lines changed

3 files changed

+139
-18
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,15 @@
2626
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
2727
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
2828
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
29+
import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
30+
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
31+
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
2932
import org.springframework.security.oauth2.client.registration.ClientRegistration;
33+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
3034
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
3135
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3236
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
37+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
3338
import org.springframework.util.Assert;
3439
import org.springframework.web.context.request.RequestContextHolder;
3540
import org.springframework.web.context.request.ServletRequestAttributes;
@@ -107,16 +112,35 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
107112

108113
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
109114

115+
private ClientRegistrationRepository clientRegistrationRepository;
116+
110117
private OAuth2AuthorizedClientRepository authorizedClientRepository;
111118

119+
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
120+
new DefaultClientCredentialsTokenResponseClient();
121+
112122
private boolean defaultOAuth2AuthorizedClient;
113123

114124
public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
115125

116-
public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository authorizedClientRepository) {
126+
public ServletOAuth2AuthorizedClientExchangeFilterFunction(
127+
ClientRegistrationRepository clientRegistrationRepository,
128+
OAuth2AuthorizedClientRepository authorizedClientRepository) {
129+
this.clientRegistrationRepository = clientRegistrationRepository;
117130
this.authorizedClientRepository = authorizedClientRepository;
118131
}
119132

133+
/**
134+
* Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
135+
* client_credentials grant.
136+
* @param clientCredentialsTokenResponseClient the client to use
137+
*/
138+
public void setClientCredentialsTokenResponseClient(
139+
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
140+
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
141+
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
142+
}
143+
120144
/**
121145
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
122146
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
@@ -277,18 +301,55 @@ private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
277301
clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
278302
}
279303
if (clientRegistrationId != null) {
280-
HttpServletRequest request = (HttpServletRequest) attrs.get(
281-
HTTP_SERVLET_REQUEST_ATTR_NAME);
304+
HttpServletRequest request = getRequest(attrs);
282305
OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository
283306
.loadAuthorizedClient(clientRegistrationId, authentication,
284307
request);
285308
if (authorizedClient == null) {
286-
throw new ClientAuthorizationRequiredException(clientRegistrationId);
309+
authorizedClient = getAuthorizedClient(clientRegistrationId, attrs);
287310
}
288311
oauth2AuthorizedClient(authorizedClient).accept(attrs);
289312
}
290313
}
291314

315+
private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) {
316+
ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
317+
if (clientRegistration == null) {
318+
throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
319+
}
320+
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
321+
return getAuthorizedClient(clientRegistration, attrs);
322+
}
323+
throw new ClientAuthorizationRequiredException(clientRegistrationId);
324+
}
325+
326+
327+
private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration,
328+
Map<String, Object> attrs) {
329+
330+
HttpServletRequest request = getRequest(attrs);
331+
HttpServletResponse response = getResponse(attrs);
332+
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest =
333+
new OAuth2ClientCredentialsGrantRequest(clientRegistration);
334+
OAuth2AccessTokenResponse tokenResponse =
335+
this.clientCredentialsTokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
336+
337+
Authentication principal = getAuthentication(attrs);
338+
339+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
340+
clientRegistration,
341+
(principal != null ? principal.getName() : "anonymousUser"),
342+
tokenResponse.getAccessToken());
343+
344+
this.authorizedClientRepository.saveAuthorizedClient(
345+
authorizedClient,
346+
principal,
347+
request,
348+
response);
349+
350+
return authorizedClient;
351+
}
352+
292353
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
293354
if (shouldRefresh(authorizedClient)) {
294355
return refreshAuthorizedClient(request, next, authorizedClient);

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/TestClientRegistrations.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,11 @@ public static ClientRegistration.Builder clientRegistration2() {
5454
.clientId("client-id-2")
5555
.clientSecret("client-secret");
5656
}
57+
58+
public static ClientRegistration.Builder clientCredentials() {
59+
return clientRegistration()
60+
.registrationId("client-credentials")
61+
.clientId("client-id")
62+
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS);
63+
}
5764
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@
4646
import org.springframework.security.core.context.SecurityContextHolder;
4747
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
4848
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
49+
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
50+
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
4951
import org.springframework.security.oauth2.client.registration.ClientRegistration;
52+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
5053
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
5154
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
5255
import org.springframework.security.oauth2.core.OAuth2AccessToken;
5356
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
5457
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
58+
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
5559
import org.springframework.security.oauth2.core.user.OAuth2User;
5660
import org.springframework.web.context.request.RequestContextHolder;
5761
import org.springframework.web.context.request.ServletRequestAttributes;
@@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
8993
@Mock
9094
private OAuth2AuthorizedClientRepository authorizedClientRepository;
9195
@Mock
96+
private ClientRegistrationRepository clientRegistrationRepository;
97+
@Mock
98+
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient;
99+
@Mock
92100
private WebClient.RequestHeadersSpec<?> spec;
93101
@Captor
94102
private ArgumentCaptor<Consumer<Map<String, Object>>> attrs;
@@ -148,7 +156,8 @@ public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticati
148156

149157
@Test
150158
public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet() {
151-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
159+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
160+
this.authorizedClientRepository);
152161
SecurityContextHolder.getContext().setAuthentication(this.authentication);
153162
Map<String, Object> attrs = getDefaultRequestAttributes();
154163
assertThat(getAuthentication(attrs)).isEqualTo(this.authentication);
@@ -157,7 +166,8 @@ public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationS
157166

158167
@Test
159168
public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride() {
160-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
169+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
170+
this.authorizedClientRepository);
161171
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
162172
"principalName", this.accessToken);
163173
oauth2AuthorizedClient(authorizedClient).accept(this.result);
@@ -168,15 +178,17 @@ public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAnd
168178

169179
@Test
170180
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
171-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
181+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
182+
this.authorizedClientRepository);
172183
Map<String, Object> attrs = getDefaultRequestAttributes();
173184
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
174185
verifyZeroInteractions(this.authorizedClientRepository);
175186
}
176187

177188
@Test
178189
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull() {
179-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
190+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
191+
this.authorizedClientRepository);
180192
Map<String, Object> attrs = getDefaultRequestAttributes();
181193
assertThat(getOAuth2AuthorizedClient(attrs)).isNull();
182194
verifyZeroInteractions(this.authorizedClientRepository);
@@ -196,7 +208,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2Auth
196208

197209
@Test
198210
public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
199-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
211+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
212+
this.authorizedClientRepository);
200213
this.function.setDefaultOAuth2AuthorizedClient(true);
201214
OAuth2User user = mock(OAuth2User.class);
202215
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
@@ -214,7 +227,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthentication
214227

215228
@Test
216229
public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient() {
217-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
230+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
231+
this.authorizedClientRepository);
218232
OAuth2User user = mock(OAuth2User.class);
219233
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
220234
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
@@ -227,7 +241,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticatio
227241

228242
@Test
229243
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit() {
230-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
244+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
245+
this.authorizedClientRepository);
231246
OAuth2User user = mock(OAuth2User.class);
232247
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
233248
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken(user, authorities, "id");
@@ -245,9 +260,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegis
245260

246261
@Test
247262
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient() {
248-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
249-
OAuth2User user = mock(OAuth2User.class);
250-
List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
263+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
264+
this.authorizedClientRepository);
251265
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
252266
"principalName", this.accessToken);
253267
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(authorizedClient);
@@ -259,6 +273,41 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientR
259273
verify(this.authorizedClientRepository).loadAuthorizedClient(eq("id"), any(), any());
260274
}
261275

276+
@Test
277+
public void defaultRequestWhenClientCredentialsThenAuthorizedClient() {
278+
this.registration = TestClientRegistrations.clientCredentials().build();
279+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
280+
this.authorizedClientRepository);
281+
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
282+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.registration);
283+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
284+
.accessTokenResponse().build();
285+
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(
286+
accessTokenResponse);
287+
288+
clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
289+
290+
Map<String, Object> attrs = getDefaultRequestAttributes();
291+
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
292+
293+
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
294+
assertThat(authorizedClient.getClientRegistration()).isEqualTo(this.registration);
295+
assertThat(authorizedClient.getPrincipalName()).isEqualTo("anonymousUser");
296+
assertThat(authorizedClient.getRefreshToken()).isEqualTo(accessTokenResponse.getRefreshToken());
297+
}
298+
299+
@Test
300+
public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException() {
301+
this.registration = TestClientRegistrations.clientCredentials().build();
302+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
303+
this.authorizedClientRepository);
304+
305+
clientRegistrationId(this.registration.getRegistrationId()).accept(this.result);
306+
307+
assertThatCode(() -> getDefaultRequestAttributes())
308+
.isInstanceOf(IllegalArgumentException.class);
309+
}
310+
262311
private Map<String, Object> getDefaultRequestAttributes() {
263312
this.function.defaultRequest().accept(this.spec);
264313
verify(this.spec).attributes(this.attrs.capture());
@@ -322,7 +371,8 @@ public void filterWhenRefreshRequiredThenRefresh() {
322371
this.accessToken.getTokenValue(),
323372
issuedAt,
324373
accessTokenExpiresAt);
325-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
374+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
375+
this.authorizedClientRepository);
326376

327377
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
328378
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -368,7 +418,8 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
368418
this.accessToken.getTokenValue(),
369419
issuedAt,
370420
accessTokenExpiresAt);
371-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
421+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
422+
this.authorizedClientRepository);
372423

373424
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
374425
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -400,7 +451,8 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
400451

401452
@Test
402453
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
403-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
454+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
455+
this.authorizedClientRepository);
404456

405457
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
406458
"principalName", this.accessToken);
@@ -422,7 +474,8 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
422474

423475
@Test
424476
public void filterWhenNotExpiredThenShouldRefreshFalse() {
425-
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientRepository);
477+
this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository,
478+
this.authorizedClientRepository);
426479

427480
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
428481
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,

0 commit comments

Comments
 (0)