Skip to content

Commit 8e615d0

Browse files
jgrandjarwinch
authored andcommitted
Re-factor DefaultClientCredentialsTokenResponseClient
Fixes gh-5735
1 parent 713e1e3 commit 8e615d0

File tree

6 files changed

+276
-231
lines changed

6 files changed

+276
-231
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/DefaultClientCredentialsTokenResponseClient.java

Lines changed: 37 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -15,41 +15,25 @@
1515
*/
1616
package org.springframework.security.oauth2.client.endpoint;
1717

18-
import org.springframework.core.ParameterizedTypeReference;
19-
import org.springframework.http.HttpHeaders;
20-
import org.springframework.http.HttpMethod;
21-
import org.springframework.http.MediaType;
18+
import org.springframework.core.convert.converter.Converter;
2219
import org.springframework.http.RequestEntity;
2320
import org.springframework.http.ResponseEntity;
24-
import org.springframework.http.client.ClientHttpResponse;
25-
import org.springframework.security.oauth2.client.registration.ClientRegistration;
21+
import org.springframework.http.converter.FormHttpMessageConverter;
22+
import org.springframework.http.converter.HttpMessageConverter;
23+
import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
2624
import org.springframework.security.oauth2.core.AuthorizationGrantType;
27-
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
28-
import org.springframework.security.oauth2.core.OAuth2AccessToken;
2925
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
3026
import org.springframework.security.oauth2.core.OAuth2Error;
31-
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
3227
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
33-
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
28+
import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
3429
import org.springframework.util.Assert;
3530
import org.springframework.util.CollectionUtils;
36-
import org.springframework.util.LinkedMultiValueMap;
37-
import org.springframework.util.MultiValueMap;
38-
import org.springframework.util.StringUtils;
3931
import org.springframework.web.client.ResponseErrorHandler;
32+
import org.springframework.web.client.RestClientException;
4033
import org.springframework.web.client.RestOperations;
4134
import org.springframework.web.client.RestTemplate;
42-
import org.springframework.web.util.UriComponentsBuilder;
4335

44-
import java.io.IOException;
45-
import java.net.URI;
4636
import java.util.Arrays;
47-
import java.util.Collections;
48-
import java.util.LinkedHashMap;
49-
import java.util.Map;
50-
import java.util.Set;
51-
import java.util.stream.Collectors;
52-
import java.util.stream.Stream;
5337

5438
/**
5539
* The default implementation of an {@link OAuth2AccessTokenResponseClient}
@@ -65,25 +49,18 @@
6549
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request (Client Credentials Grant)</a>
6650
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.3">Section 4.4.3 Access Token Response (Client Credentials Grant)</a>
6751
*/
68-
public class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
69-
private static final String INVALID_TOKEN_REQUEST_ERROR_CODE = "invalid_token_request";
70-
52+
public final class DefaultClientCredentialsTokenResponseClient implements OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> {
7153
private static final String INVALID_TOKEN_RESPONSE_ERROR_CODE = "invalid_token_response";
7254

73-
private static final String[] TOKEN_RESPONSE_PARAMETER_NAMES = {
74-
OAuth2ParameterNames.ACCESS_TOKEN,
75-
OAuth2ParameterNames.TOKEN_TYPE,
76-
OAuth2ParameterNames.EXPIRES_IN,
77-
OAuth2ParameterNames.SCOPE,
78-
OAuth2ParameterNames.REFRESH_TOKEN
79-
};
55+
private Converter<OAuth2ClientCredentialsGrantRequest, RequestEntity<?>> requestEntityConverter =
56+
new OAuth2ClientCredentialsGrantRequestEntityConverter();
8057

8158
private RestOperations restOperations;
8259

8360
public DefaultClientCredentialsTokenResponseClient() {
84-
RestTemplate restTemplate = new RestTemplate();
85-
// Disable the ResponseErrorHandler as errors are handled directly within this class
86-
restTemplate.setErrorHandler(new NoOpResponseErrorHandler());
61+
RestTemplate restTemplate = new RestTemplate(Arrays.asList(
62+
new FormHttpMessageConverter(), new OAuth2AccessTokenResponseHttpMessageConverter()));
63+
restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
8764
this.restOperations = restTemplate;
8865
}
8966

@@ -93,50 +70,18 @@ public OAuth2AccessTokenResponse getTokenResponse(OAuth2ClientCredentialsGrantRe
9370

9471
Assert.notNull(clientCredentialsGrantRequest, "clientCredentialsGrantRequest cannot be null");
9572

96-
// Build request
97-
RequestEntity<MultiValueMap<String, String>> request = this.buildRequest(clientCredentialsGrantRequest);
73+
RequestEntity<?> request = this.requestEntityConverter.convert(clientCredentialsGrantRequest);
9874

99-
// Exchange
100-
ResponseEntity<Map<String, String>> response;
75+
ResponseEntity<OAuth2AccessTokenResponse> response;
10176
try {
102-
response = this.restOperations.exchange(
103-
request, new ParameterizedTypeReference<Map<String, String>>() {});
104-
} catch (Exception ex) {
105-
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_REQUEST_ERROR_CODE,
106-
"An error occurred while sending the Access Token Request: " + ex.getMessage(), null);
107-
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
108-
}
109-
110-
Map<String, String> responseParameters = response.getBody();
111-
112-
// Check for Error Response
113-
if (response.getStatusCodeValue() != 200) {
114-
OAuth2Error oauth2Error = this.parseErrorResponse(responseParameters);
115-
if (oauth2Error == null) {
116-
oauth2Error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR);
117-
}
118-
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
119-
}
120-
121-
// Success Response
122-
OAuth2AccessTokenResponse tokenResponse;
123-
try {
124-
tokenResponse = this.parseTokenResponse(responseParameters);
125-
} catch (Exception ex) {
77+
response = this.restOperations.exchange(request, OAuth2AccessTokenResponse.class);
78+
} catch (RestClientException ex) {
12679
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
127-
"An error occurred parsing the Access Token response (200 OK): " + ex.getMessage(), null);
80+
"An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response: " + ex.getMessage(), null);
12881
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
12982
}
13083

131-
if (tokenResponse == null) {
132-
// This should never happen as long as the provider
133-
// implements a Successful Response as defined in Section 5.1
134-
// https://tools.ietf.org/html/rfc6749#section-5.1
135-
OAuth2Error oauth2Error = new OAuth2Error(INVALID_TOKEN_RESPONSE_ERROR_CODE,
136-
"An error occurred parsing the Access Token response (200 OK). " +
137-
"Missing required parameters: access_token and/or token_type", null);
138-
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
139-
}
84+
OAuth2AccessTokenResponse tokenResponse = response.getBody();
14085

14186
if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
14287
// As per spec, in Section 5.1 Successful Access Token Response
@@ -151,120 +96,31 @@ public OAuth2AccessTokenResponse getTokenResponse(OAuth2ClientCredentialsGrantRe
15196
return tokenResponse;
15297
}
15398

154-
private RequestEntity<MultiValueMap<String, String>> buildRequest(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
155-
HttpHeaders headers = this.buildHeaders(clientCredentialsGrantRequest);
156-
MultiValueMap<String, String> formParameters = this.buildFormParameters(clientCredentialsGrantRequest);
157-
URI uri = UriComponentsBuilder.fromUriString(clientCredentialsGrantRequest.getClientRegistration().getProviderDetails().getTokenUri())
158-
.build()
159-
.toUri();
160-
161-
return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
162-
}
163-
164-
private HttpHeaders buildHeaders(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
165-
ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
166-
167-
HttpHeaders headers = new HttpHeaders();
168-
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
169-
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
170-
if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
171-
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
172-
}
173-
174-
return headers;
175-
}
176-
177-
private MultiValueMap<String, String> buildFormParameters(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
178-
ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
179-
180-
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
181-
formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue());
182-
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
183-
formParameters.add(OAuth2ParameterNames.SCOPE,
184-
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
185-
}
186-
if (ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
187-
formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
188-
formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
189-
}
190-
191-
return formParameters;
192-
}
193-
194-
private OAuth2Error parseErrorResponse(Map<String, String> responseParameters) {
195-
if (CollectionUtils.isEmpty(responseParameters) ||
196-
!responseParameters.containsKey(OAuth2ParameterNames.ERROR)) {
197-
return null;
198-
}
199-
200-
String errorCode = responseParameters.get(OAuth2ParameterNames.ERROR);
201-
String errorDescription = responseParameters.get(OAuth2ParameterNames.ERROR_DESCRIPTION);
202-
String errorUri = responseParameters.get(OAuth2ParameterNames.ERROR_URI);
203-
204-
return new OAuth2Error(errorCode, errorDescription, errorUri);
205-
}
206-
207-
private OAuth2AccessTokenResponse parseTokenResponse(Map<String, String> responseParameters) {
208-
if (CollectionUtils.isEmpty(responseParameters) ||
209-
!responseParameters.containsKey(OAuth2ParameterNames.ACCESS_TOKEN) ||
210-
!responseParameters.containsKey(OAuth2ParameterNames.TOKEN_TYPE)) {
211-
return null;
212-
}
213-
214-
String accessToken = responseParameters.get(OAuth2ParameterNames.ACCESS_TOKEN);
215-
216-
OAuth2AccessToken.TokenType accessTokenType = null;
217-
if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
218-
responseParameters.get(OAuth2ParameterNames.TOKEN_TYPE))) {
219-
accessTokenType = OAuth2AccessToken.TokenType.BEARER;
220-
}
221-
222-
long expiresIn = 0;
223-
if (responseParameters.containsKey(OAuth2ParameterNames.EXPIRES_IN)) {
224-
try {
225-
expiresIn = Long.valueOf(responseParameters.get(OAuth2ParameterNames.EXPIRES_IN));
226-
} catch (NumberFormatException ex) { }
227-
}
228-
229-
Set<String> scopes = Collections.emptySet();
230-
if (responseParameters.containsKey(OAuth2ParameterNames.SCOPE)) {
231-
String scope = responseParameters.get(OAuth2ParameterNames.SCOPE);
232-
scopes = Arrays.stream(StringUtils.delimitedListToStringArray(scope, " ")).collect(Collectors.toSet());
233-
}
234-
235-
Map<String, Object> additionalParameters = new LinkedHashMap<>();
236-
Set<String> tokenResponseParameterNames = Stream.of(TOKEN_RESPONSE_PARAMETER_NAMES).collect(Collectors.toSet());
237-
responseParameters.entrySet().stream()
238-
.filter(e -> !tokenResponseParameterNames.contains(e.getKey()))
239-
.forEach(e -> additionalParameters.put(e.getKey(), e.getValue()));
240-
241-
return OAuth2AccessTokenResponse.withToken(accessToken)
242-
.tokenType(accessTokenType)
243-
.expiresIn(expiresIn)
244-
.scopes(scopes)
245-
.additionalParameters(additionalParameters)
246-
.build();
99+
/**
100+
* Sets the {@link Converter} used for converting the {@link OAuth2ClientCredentialsGrantRequest}
101+
* to a {@link RequestEntity} representation of the OAuth 2.0 Access Token Request.
102+
*
103+
* @param requestEntityConverter the {@link Converter} used for converting to a {@link RequestEntity} representation of the Access Token Request
104+
*/
105+
public void setRequestEntityConverter(Converter<OAuth2ClientCredentialsGrantRequest, RequestEntity<?>> requestEntityConverter) {
106+
Assert.notNull(requestEntityConverter, "requestEntityConverter cannot be null");
107+
this.requestEntityConverter = requestEntityConverter;
247108
}
248109

249110
/**
250-
* Sets the {@link RestOperations} used when requesting the access token response.
111+
* Sets the {@link RestOperations} used when requesting the OAuth 2.0 Access Token Response.
112+
*
113+
* <p>
114+
* <b>NOTE:</b> At a minimum, the supplied {@code restOperations} must be configured with the following:
115+
* <ol>
116+
* <li>{@link HttpMessageConverter}'s - {@link FormHttpMessageConverter} and {@link OAuth2AccessTokenResponseHttpMessageConverter}</li>
117+
* <li>{@link ResponseErrorHandler} - {@link OAuth2ErrorResponseErrorHandler}</li>
118+
* </ol>
251119
*
252-
* @param restOperations the {@link RestOperations} used when requesting the access token response
120+
* @param restOperations the {@link RestOperations} used when requesting the Access Token Response
253121
*/
254-
public final void setRestOperations(RestOperations restOperations) {
122+
public void setRestOperations(RestOperations restOperations) {
255123
Assert.notNull(restOperations, "restOperations cannot be null");
256124
this.restOperations = restOperations;
257125
}
258-
259-
private static class NoOpResponseErrorHandler implements ResponseErrorHandler {
260-
261-
@Override
262-
public boolean hasError(ClientHttpResponse response) throws IOException {
263-
return false;
264-
}
265-
266-
@Override
267-
public void handleError(ClientHttpResponse response) throws IOException {
268-
}
269-
}
270126
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/OAuth2AuthorizationCodeGrantRequestEntityConverter.java

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
import org.springframework.core.convert.converter.Converter;
1919
import org.springframework.http.HttpHeaders;
2020
import org.springframework.http.HttpMethod;
21-
import org.springframework.http.MediaType;
2221
import org.springframework.http.RequestEntity;
23-
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
2422
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2523
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
2624
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
@@ -30,9 +28,6 @@
3028
import org.springframework.web.util.UriComponentsBuilder;
3129

3230
import java.net.URI;
33-
import java.util.Collections;
34-
35-
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
3631

3732
/**
3833
* A {@link Converter} that converts the provided {@link OAuth2AuthorizationCodeGrantRequest}
@@ -57,7 +52,7 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverter implements Conve
5752
public RequestEntity<?> convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
5853
ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration();
5954

60-
HttpHeaders headers = this.buildHeaders(authorizationCodeGrantRequest);
55+
HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration);
6156
MultiValueMap<String, String> formParameters = this.buildFormParameters(authorizationCodeGrantRequest);
6257
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri())
6358
.build()
@@ -66,26 +61,6 @@ public RequestEntity<?> convert(OAuth2AuthorizationCodeGrantRequest authorizatio
6661
return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
6762
}
6863

69-
/**
70-
* Returns the {@link HttpHeaders} used for the Access Token Request.
71-
*
72-
* @param authorizationCodeGrantRequest the authorization code grant request
73-
* @return the {@link HttpHeaders} used for the Access Token Request
74-
*/
75-
private HttpHeaders buildHeaders(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
76-
ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration();
77-
78-
HttpHeaders headers = new HttpHeaders();
79-
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
80-
final MediaType contentType = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
81-
headers.setContentType(contentType);
82-
if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
83-
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
84-
}
85-
86-
return headers;
87-
}
88-
8964
/**
9065
* Returns a {@link MultiValueMap} of the form parameters used for the Access Token Request body.
9166
*
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.security.oauth2.client.endpoint;
17+
18+
import org.springframework.core.convert.converter.Converter;
19+
import org.springframework.http.HttpHeaders;
20+
import org.springframework.http.MediaType;
21+
import org.springframework.http.RequestEntity;
22+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
23+
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
24+
25+
import java.util.Collections;
26+
27+
import static org.springframework.http.MediaType.APPLICATION_FORM_URLENCODED_VALUE;
28+
29+
/**
30+
* Utility methods used by the {@link Converter}'s that convert
31+
* from an implementation of an {@link AbstractOAuth2AuthorizationGrantRequest}
32+
* to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request
33+
* for the specific Authorization Grant.
34+
*
35+
* @author Joe Grandja
36+
* @since 5.1
37+
* @see OAuth2AuthorizationCodeGrantRequestEntityConverter
38+
* @see OAuth2ClientCredentialsGrantRequestEntityConverter
39+
*/
40+
final class OAuth2AuthorizationGrantRequestEntityUtils {
41+
private static HttpHeaders DEFAULT_TOKEN_REQUEST_HEADERS = getDefaultTokenRequestHeaders();
42+
43+
static HttpHeaders getTokenRequestHeaders(ClientRegistration clientRegistration) {
44+
HttpHeaders headers = new HttpHeaders();
45+
headers.addAll(DEFAULT_TOKEN_REQUEST_HEADERS);
46+
if (ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
47+
headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret());
48+
}
49+
return headers;
50+
}
51+
52+
private static HttpHeaders getDefaultTokenRequestHeaders() {
53+
HttpHeaders headers = new HttpHeaders();
54+
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON_UTF8));
55+
final MediaType contentType = MediaType.valueOf(APPLICATION_FORM_URLENCODED_VALUE + ";charset=UTF-8");
56+
headers.setContentType(contentType);
57+
return headers;
58+
}
59+
}

0 commit comments

Comments
 (0)