Skip to content

Commit c2c054a

Browse files
committed
Add support for client_credentials grant
Fixes gh-4982
1 parent 1a65abd commit c2c054a

File tree

12 files changed

+1040
-24
lines changed

12 files changed

+1040
-24
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.springframework.context.annotation.Import;
2121
import org.springframework.context.annotation.ImportSelector;
2222
import org.springframework.core.type.AnnotationMetadata;
23+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
2324
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
2425
import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
2526
import org.springframework.util.ClassUtils;
@@ -57,17 +58,26 @@ public String[] selectImports(AnnotationMetadata importingClassMetadata) {
5758

5859
@Configuration
5960
static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer {
61+
private ClientRegistrationRepository clientRegistrationRepository;
6062
private OAuth2AuthorizedClientRepository authorizedClientRepository;
6163

6264
@Override
6365
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
64-
if (this.authorizedClientRepository != null) {
66+
if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) {
6567
OAuth2AuthorizedClientArgumentResolver authorizedClientArgumentResolver =
66-
new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
68+
new OAuth2AuthorizedClientArgumentResolver(
69+
this.clientRegistrationRepository, this.authorizedClientRepository);
6770
argumentResolvers.add(authorizedClientArgumentResolver);
6871
}
6972
}
7073

74+
@Autowired(required = false)
75+
public void setClientRegistrationRepository(List<ClientRegistrationRepository> clientRegistrationRepositories) {
76+
if (clientRegistrationRepositories.size() == 1) {
77+
this.clientRegistrationRepository = clientRegistrationRepositories.get(0);
78+
}
79+
}
80+
7181
@Autowired(required = false)
7282
public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository> authorizedClientRepositories) {
7383
if (authorizedClientRepositories.size() == 1) {

config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfigurationTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ public String authorizedClient(@RegisteredOAuth2AuthorizedClient("client1") OAut
9898
}
9999
}
100100

101+
@Bean
102+
public ClientRegistrationRepository clientRegistrationRepository() {
103+
return mock(ClientRegistrationRepository.class);
104+
}
105+
101106
@Bean
102107
public OAuth2AuthorizedClientRepository authorizedClientRepository() {
103108
return AUTHORIZED_CLIENT_REPOSITORY;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.client.endpoint;
17+
18+
import org.springframework.core.ParameterizedTypeReference;
19+
import org.springframework.http.HttpHeaders;
20+
import org.springframework.http.HttpMethod;
21+
import org.springframework.http.MediaType;
22+
import org.springframework.http.RequestEntity;
23+
import org.springframework.http.ResponseEntity;
24+
import org.springframework.http.client.ClientHttpResponse;
25+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
26+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
27+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
28+
import org.springframework.security.oauth2.core.OAuth2AccessToken;
29+
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
30+
import org.springframework.security.oauth2.core.OAuth2Error;
31+
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
32+
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
33+
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
34+
import org.springframework.util.Assert;
35+
import org.springframework.util.CollectionUtils;
36+
import org.springframework.util.LinkedMultiValueMap;
37+
import org.springframework.util.MultiValueMap;
38+
import org.springframework.util.StringUtils;
39+
import org.springframework.web.client.ResponseErrorHandler;
40+
import org.springframework.web.client.RestOperations;
41+
import org.springframework.web.client.RestTemplate;
42+
import org.springframework.web.util.UriComponentsBuilder;
43+
44+
import java.io.IOException;
45+
import java.net.URI;
46+
import java.util.Arrays;
47+
import java.util.Collections;
48+
import java.util.LinkedHashMap;
49+
import java.util.Map;
50+
import java.util.Set;
51+
import java.util.stream.Collectors;
52+
import java.util.stream.Stream;
53+
54+
/**
55+
* The default implementation of an {@link OAuth2AccessTokenResponseClient}
56+
* for the {@link AuthorizationGrantType#CLIENT_CREDENTIALS client_credentials} grant.
57+
* This implementation uses a {@link RestOperations} when requesting
58+
* an access token credential at the Authorization Server's Token Endpoint.
59+
*
60+
* @author Joe Grandja
61+
* @since 5.1
62+
* @see OAuth2AccessTokenResponseClient
63+
* @see OAuth2ClientCredentialsGrantRequest
64+
* @see OAuth2AccessTokenResponse
65+
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request (Client Credentials Grant)</a>
66+
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.3">Section 4.4.3 Access Token Response (Client Credentials Grant)</a>
67+
*/
68+
public class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
69+
private static final String INVALID_TOKEN_REQUEST_ERROR_CODE = "invalid_token_request";
70+
71+
private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
72+
73+
private static final String[] TOKEN_RESPONSE_PARAMETER_NAMES = {
74+
OAuth2ParameterNames.ACCESS_TOKEN,
75+
OAuth2ParameterNames.TOKEN_TYPE,
76+
OAuth2ParameterNames.EXPIRES_IN,
77+
OAuth2ParameterNames.SCOPE,
78+
OAuth2ParameterNames.REFRESH_TOKEN
79+
};
80+
81+
private RestOperations restOperations;
82+
83+
public DefaultClientCredentialsTokenResponseClient() {
84+
RestTemplate restTemplate = new RestTemplate();
85+
// Disable the ResponseErrorHandler as errors are handled directly within this class
86+
restTemplate.setErrorHandler(new NoOpResponseErrorHandler());
87+
this.restOperations = restTemplate;
88+
}
89+
90+
@Override
91+
public OAuth2AccessTokenResponse getTokenResponse(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest)
92+
throws OAuth2AuthenticationException {
93+
94+
Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null");
95+
96+
ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
97+
98+
// Headers
99+
HttpHeaders headers = new HttpHeaders();
100+
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
101+
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
102+
if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
103+
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
104+
}
105+
106+
// Form parameters
107+
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
108+
formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue());
109+
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
110+
formParameters.add(OAuth2ParameterNames.SCOPE,
111+
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
112+
}
113+
if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
114+
formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
115+
formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
116+
}
117+
118+
// Request
119+
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri())
120+
.build()
121+
.toUri();
122+
RequestEntity<MultiValueMap<String, String>> request =
123+
new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
124+
125+
ParameterizedTypeReference<Map<String, String>> typeReference =
126+
new ParameterizedTypeReference<Map<String, String>>() {};
127+
128+
// Exchange
129+
ResponseEntity<Map<String, String>> response;
130+
try {
131+
response = this.restOperations.exchange(request, typeReference);
132+
} catch (Exception ex) {
133+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_REQUEST_ERROR_CODE,
134+
"An error occurred while sending the Access Token Request: " + ex.getMessage(), null);
135+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
136+
}
137+
138+
Map<String, String> responseParameters = response.getBody();
139+
140+
// Check for Error Response
141+
if (response.getStatusCodeValue() != 200) {
142+
OAuth2Error oauth2Error = this.parseErrorResponse(responseParameters);
143+
if (oauth2Error == null) {
144+
oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR);
145+
}
146+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
147+
}
148+
149+
// Success Response
150+
OAuth2AccessTokenResponse tokenResponse;
151+
try {
152+
tokenResponse = this.parseTokenResponse(responseParameters);
153+
} catch (Exception ex) {
154+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
155+
"An error occurred parsing the Access Token response (200 OK): " + ex.getMessage(), null);
156+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
157+
}
158+
159+
if (tokenResponse == null) {
160+
// This should never happen as long as the provider
161+
// implements a Successful Response as defined in Section 5.1
162+
// https://tools.ietf.org/html/rfc6749#section-5.1
163+
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
164+
"An error occurred parsing the Access Token response (200 OK). " +
165+
"Missing required parameters: access_token and/or token_type", null);
166+
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
167+
}
168+
169+
if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
170+
// As per spec, in Section 5.1 Successful Access Token Response
171+
// https://tools.ietf.org/html/rfc6749#section-5.1
172+
// If AccessTokenResponse.scope is empty, then default to the scope
173+
// originally requested by the client in the Token Request
174+
tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse)
175+
.scopes(clientRegistration.getScopes())
176+
.build();
177+
}
178+
179+
return tokenResponse;
180+
}
181+
182+
/**
183+
* Sets the {@link RestOperations} used when requesting the access token response.
184+
*
185+
* @param restOperations the {@link RestOperations} used when requesting the access token response
186+
*/
187+
public final void setRestOperations(RestOperations restOperations) {
188+
Assert.notNull(restOperations, "restOperations cannot be null");
189+
this.restOperations = restOperations;
190+
}
191+
192+
private OAuth2Error parseErrorResponse(Map<String, String> responseParameters) {
193+
if (CollectionUtils.isEmpty(responseParameters) ||
194+
!responseParameters.containsKey(OAuth2ParameterNames.ERROR)) {
195+
return null;
196+
}
197+
198+
String errorCode = responseParameters.get(OAuth2ParameterNames.ERROR);
199+
String errorDescription = responseParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION);
200+
String errorUri = responseParameters.get(OAuth2ParameterNames.ERROR_URI);
201+
202+
return new OAuth2Error(errorCode, errorDescription, errorUri);
203+
}
204+
205+
private OAuth2AccessTokenResponse parseTokenResponse(Map<String, String> responseParameters) {
206+
if (CollectionUtils.isEmpty(responseParameters) ||
207+
!responseParameters.containsKey(OAuth2ParameterNames.ACCESS_TOKEN) ||
208+
!responseParameters.containsKey(OAuth2ParameterNames.TOKEN_TYPE)) {
209+
return null;
210+
}
211+
212+
String accessToken = responseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN);
213+
214+
OAuth2AccessToken.TokenType accessTokenType = null;
215+
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
216+
responseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) {
217+
accessTokenType = OAuth2AccessToken.TokenType.BEARER;
218+
}
219+
220+
long expiresIn = 0;
221+
if (responseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) {
222+
try {
223+
expiresIn = Long.valueOf(responseParameters.get(OAuth2ParameterNames.EXPIRES_IN));
224+
} catch (NumberFormatException ex) { }
225+
}
226+
227+
Set<String> scopes = Collections.emptySet();
228+
if (responseParameters.containsKey(OAuth2ParameterNames.SCOPE)) {
229+
String scope = responseParameters.get(OAuth2ParameterNames.SCOPE);
230+
scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")).collect(Collectors.toSet());
231+
}
232+
233+
Map<String, Object> additionalParameters = new LinkedHashMap<>();
234+
Set<String> tokenResponseParameterNames = Stream.of(TOKEN_RESPONSE_PARAMETER_NAMES).collect(Collectors.toSet());
235+
responseParameters.entrySet().stream()
236+
.filter(e -> !tokenResponseParameterNames.contains(e.getKey()))
237+
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue()));
238+
239+
return OAuth2AccessTokenResponse.withToken(accessToken)
240+
.tokenType(accessTokenType)
241+
.expiresIn(expiresIn)
242+
.scopes(scopes)
243+
.additionalParameters(additionalParameters)
244+
.build();
245+
}
246+
247+
private static class NoOpResponseErrorHandler implements ResponseErrorHandler {
248+
249+
@Override
250+
public boolean hasError(ClientHttpResponse response) throws IOException {
251+
return false;
252+
}
253+
254+
@Override
255+
public void handleError(ClientHttpResponse response) throws IOException {
256+
}
257+
}
258+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.client.endpoint;
17+
18+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
19+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
20+
import org.springframework.util.Assert;
21+
22+
/**
23+
* An OAuth 2.0 Client Credentials Grant request that holds
24+
* the client's credentials in {@link #getClientRegistration()}.
25+
*
26+
* @author Joe Grandja
27+
* @since 5.1
28+
* @see AbstractOAuth2AuthorizationGrantRequest
29+
* @see ClientRegistration
30+
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3.4">Section 1.3.4 Client Credentials Grant</a>
31+
*/
32+
public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
33+
private final ClientRegistration clientRegistration;
34+
35+
/**
36+
* Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided parameters.
37+
*
38+
* @param clientRegistration the client registration
39+
*/
40+
public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration) {
41+
super(AuthorizationGrantType.CLIENT_CREDENTIALS);
42+
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
43+
Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()),
44+
"clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS");
45+
this.clientRegistration = clientRegistration;
46+
}
47+
48+
/**
49+
* Returns the {@link ClientRegistration client registration}.
50+
*
51+
* @return the {@link ClientRegistration}
52+
*/
53+
public ClientRegistration getClientRegistration() {
54+
return this.clientRegistration;
55+
}
56+
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,9 @@ public Builder clientName(String clientName) {
448448
*/
449449
public ClientRegistration build() {
450450
Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null");
451-
if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) {
451+
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType)) {
452+
this.validateClientCredentialsGrantType();
453+
} else if (AuthorizationGrantType.IMPLICIT.equals(this.authorizationGrantType)) {
452454
this.validateImplicitGrantType();
453455
} else {
454456
this.validateAuthorizationCodeGrantType();
@@ -507,5 +509,15 @@ private void validateImplicitGrantType() {
507509
Assert.hasText(this.authorizationUri, "authorizationUri cannot be empty");
508510
Assert.hasText(this.clientName, "clientName cannot be empty");
509511
}
512+
513+
private void validateClientCredentialsGrantType() {
514+
Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(this.authorizationGrantType),
515+
() -> "authorizationGrantType must be " + AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
516+
Assert.hasText(this.registrationId, "registrationId cannot be empty");
517+
Assert.hasText(this.clientId, "clientId cannot be empty");
518+
Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
519+
Assert.notNull(this.clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
520+
Assert.hasText(this.tokenUri, "tokenUri cannot be empty");
521+
}
510522
}
511523
}

0 commit comments

Comments
 (0)