diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java new file mode 100644 index 00000000000..d7f6fd6edc4 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequest.java @@ -0,0 +1,194 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Represents a request the {@link OAuth2AuthorizedClientManager} uses to + * {@link OAuth2AuthorizedClientManager#authorize(OAuth2AuthorizeRequest) authorize} (or re-authorize) + * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientManager + */ +public final class OAuth2AuthorizeRequest { + private String clientRegistrationId; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; + + private OAuth2AuthorizeRequest() { + } + + /** + * Returns the identifier for the {@link ClientRegistration client registration}. + * + * @return the identifier for the client registration + */ + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + /** + * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. + * + * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided + */ + @Nullable + public OAuth2AuthorizedClient getAuthorizedClient() { + return this.authorizedClient; + } + + /** + * Returns the {@code Principal} (to be) associated to the authorized client. + * + * @return the {@code Principal} (to be) associated to the authorized client + */ + public Authentication getPrincipal() { + return this.principal; + } + + /** + * Returns the attributes associated to the request. + * + * @return a {@code Map} of the attributes associated to the request + */ + public Map getAttributes() { + return this.attributes; + } + + /** + * Returns the value of an attribute associated to the request or {@code null} if not available. + * + * @param name the name of the attribute + * @param the type of the attribute + * @return the value of the attribute associated to the request + */ + @Nullable + @SuppressWarnings("unchecked") + public T getAttribute(String name) { + return (T) this.getAttributes().get(name); + } + + /** + * Returns a new {@link Builder} initialized with the identifier for the {@link ClientRegistration client registration}. + * + * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} + * @return the {@link Builder} + */ + public static Builder withClientRegistrationId(String clientRegistrationId) { + return new Builder(clientRegistrationId); + } + + /** + * Returns a new {@link Builder} initialized with the {@link OAuth2AuthorizedClient authorized client}. + * + * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} + * @return the {@link Builder} + */ + public static Builder withAuthorizedClient(OAuth2AuthorizedClient authorizedClient) { + return new Builder(authorizedClient); + } + + /** + * A builder for {@link OAuth2AuthorizeRequest}. + */ + public static class Builder { + private String clientRegistrationId; + private OAuth2AuthorizedClient authorizedClient; + private Authentication principal; + private Map attributes; + + private Builder(String clientRegistrationId) { + Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); + this.clientRegistrationId = clientRegistrationId; + } + + private Builder(OAuth2AuthorizedClient authorizedClient) { + Assert.notNull(authorizedClient, "authorizedClient cannot be null"); + this.authorizedClient = authorizedClient; + } + + /** + * Sets the {@code Principal} (to be) associated to the authorized client. + * + * @param principal the {@code Principal} (to be) associated to the authorized client + * @return the {@link Builder} + */ + public Builder principal(Authentication principal) { + this.principal = principal; + return this; + } + + /** + * Sets the attributes associated to the request. + * + * @param attributes the attributes associated to the request + * @return the {@link Builder} + */ + public Builder attributes(Map attributes) { + this.attributes = attributes; + return this; + } + + /** + * Sets an attribute associated to the request. + * + * @param name the name of the attribute + * @param value the value of the attribute + * @return the {@link Builder} + */ + public Builder attribute(String name, Object value) { + if (this.attributes == null) { + this.attributes = new HashMap<>(); + } + this.attributes.put(name, value); + return this; + } + + /** + * Builds a new {@link OAuth2AuthorizeRequest}. + * + * @return a {@link OAuth2AuthorizeRequest} + */ + public OAuth2AuthorizeRequest build() { + Assert.notNull(this.principal, "principal cannot be null"); + OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest(); + if (this.authorizedClient != null) { + authorizeRequest.clientRegistrationId = this.authorizedClient.getClientRegistration().getRegistrationId(); + authorizeRequest.authorizedClient = this.authorizedClient; + } else { + authorizeRequest.clientRegistrationId = this.clientRegistrationId; + } + authorizeRequest.principal = this.principal; + authorizeRequest.attributes = Collections.unmodifiableMap( + CollectionUtils.isEmpty(this.attributes) ? + Collections.emptyMap() : new LinkedHashMap<>(this.attributes)); + return authorizeRequest; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java similarity index 89% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java index af90c1600e0..5634fcc0866 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManager.java @@ -13,12 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.web; +package org.springframework.security.oauth2.client; import org.springframework.lang.Nullable; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; /** * Implementations of this interface are responsible for the overall management @@ -30,13 +29,14 @@ *
  • Authorizing (or re-authorizing) an OAuth 2.0 Client * by leveraging an {@link OAuth2AuthorizedClientProvider}(s).
  • *
  • Managing the persistence of an {@link OAuth2AuthorizedClient} between requests, - * typically using an {@link OAuth2AuthorizedClientRepository}.
  • + * typically using an {@link OAuth2AuthorizedClientService} OR {@link OAuth2AuthorizedClientRepository}. * * * @author Joe Grandja * @since 5.2 * @see OAuth2AuthorizedClient * @see OAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientService * @see OAuth2AuthorizedClientRepository */ public interface OAuth2AuthorizedClientManager { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManagerService.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManagerService.java new file mode 100644 index 00000000000..180bf83ced8 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManagerService.java @@ -0,0 +1,142 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.springframework.lang.Nullable; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +/** + * An implementation of an {@link OAuth2AuthorizedClientManager} + * that is capable of operating outside of a {@code HttpServletRequest} context, + * e.g. in a scheduled/background thread and/or in the service-tier. + * + * @author Joe Grandja + * @since 5.2 + * @see OAuth2AuthorizedClientManager + * @see OAuth2AuthorizedClientProvider + * @see OAuth2AuthorizedClientService + */ +public final class OAuth2AuthorizedClientManagerService implements OAuth2AuthorizedClientManager { + private final ClientRegistrationRepository clientRegistrationRepository; + private final OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientProvider authorizedClientProvider = context -> null; + private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); + + /** + * Constructs an {@code OAuth2AuthorizedClientManagerService} using the provided parameters. + * + * @param clientRegistrationRepository the repository of client registrations + * @param authorizedClientService the authorized client service + */ + public OAuth2AuthorizedClientManagerService(ClientRegistrationRepository clientRegistrationRepository, + OAuth2AuthorizedClientService authorizedClientService) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientService, "authorizedClientService cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientService = authorizedClientService; + } + + @Nullable + @Override + public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) { + Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); + + String clientRegistrationId = authorizeRequest.getClientRegistrationId(); + OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); + Authentication principal = authorizeRequest.getPrincipal(); + + OAuth2AuthorizationContext.Builder contextBuilder; + if (authorizedClient != null) { + contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); + } else { + ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId); + Assert.notNull(clientRegistration, "Could not find ClientRegistration with id '" + clientRegistrationId + "'"); + authorizedClient = this.authorizedClientService.loadAuthorizedClient(clientRegistrationId, principal.getName()); + if (authorizedClient != null) { + contextBuilder = OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient); + } else { + contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); + } + } + OAuth2AuthorizationContext authorizationContext = contextBuilder + .principal(principal) + .attributes(this.contextAttributesMapper.apply(authorizeRequest)) + .build(); + + authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); + if (authorizedClient != null) { + this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal); + } else { + // In the case of re-authorization, the returned `authorizedClient` may be null if re-authorization is not supported. + // For these cases, return the provided `authorizationContext.authorizedClient`. + if (authorizationContext.getAuthorizedClient() != null) { + return authorizationContext.getAuthorizedClient(); + } + } + + return authorizedClient; + } + + /** + * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. + * + * @param authorizedClientProvider the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client + */ + public void setAuthorizedClientProvider(OAuth2AuthorizedClientProvider authorizedClientProvider) { + Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null"); + this.authorizedClientProvider = authorizedClientProvider; + } + + /** + * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes + * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. + * + * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes + * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} + */ + public void setContextAttributesMapper(Function> contextAttributesMapper) { + Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); + this.contextAttributesMapper = contextAttributesMapper; + } + + /** + * The default implementation of the {@link #setContextAttributesMapper(Function) contextAttributesMapper}. + */ + public static class DefaultContextAttributesMapper implements Function> { + + @Override + public Map apply(OAuth2AuthorizeRequest authorizeRequest) { + Map contextAttributes = Collections.emptyMap(); + String scope = authorizeRequest.getAttribute(OAuth2ParameterNames.SCOPE); + if (StringUtils.hasText(scope)) { + contextAttributes = new HashMap<>(); + contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, + StringUtils.delimitedListToStringArray(scope, " ")); + } + return contextAttributes; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ServerOAuth2AuthorizedClientManager.java similarity index 82% rename from oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientManager.java rename to oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ServerOAuth2AuthorizedClientManager.java index dd24f9832ee..1542abbdf2d 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/ServerOAuth2AuthorizedClientManager.java @@ -13,11 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.web.server; +package org.springframework.security.oauth2.client; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import reactor.core.publisher.Mono; /** @@ -43,13 +42,13 @@ public interface ServerOAuth2AuthorizedClientManager { /** * Attempt to authorize or re-authorize (if required) the {@link ClientRegistration client} - * identified by the provided {@link ServerOAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. + * identified by the provided {@link OAuth2AuthorizeRequest#getClientRegistrationId() clientRegistrationId}. * Implementations must return an empty {@code Mono} if authorization is not supported for the specified client, * e.g. the associated {@link ReactiveOAuth2AuthorizedClientProvider}(s) does not support * the {@link ClientRegistration#getAuthorizationGrantType() authorization grant} type configured for the client. * *

    - * In the case of re-authorization, implementations must return the provided {@link ServerOAuth2AuthorizeRequest#getAuthorizedClient() authorized client} + * In the case of re-authorization, implementations must return the provided {@link OAuth2AuthorizeRequest#getAuthorizedClient() authorized client} * if re-authorization is not supported for the client OR is not required, * e.g. a {@link OAuth2AuthorizedClient#getRefreshToken() refresh token} is not available OR * the {@link OAuth2AuthorizedClient#getAccessToken() access token} is not expired. @@ -57,6 +56,6 @@ public interface ServerOAuth2AuthorizedClientManager { * @param authorizeRequest the authorize request * @return the {@link OAuth2AuthorizedClient} or an empty {@code Mono} if authorization is not supported for the specified client */ - Mono authorize(ServerOAuth2AuthorizeRequest authorizeRequest); + Mono authorize(OAuth2AuthorizeRequest authorizeRequest); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java index d3644186136..688d6ffb634 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManager.java @@ -18,13 +18,17 @@ import org.springframework.lang.Nullable; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; import org.springframework.util.StringUtils; +import org.springframework.web.context.request.RequestContextHolder; +import org.springframework.web.context.request.ServletRequestAttributes; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -69,8 +73,11 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) String clientRegistrationId = authorizeRequest.getClientRegistrationId(); OAuth2AuthorizedClient authorizedClient = authorizeRequest.getAuthorizedClient(); Authentication principal = authorizeRequest.getPrincipal(); - HttpServletRequest servletRequest = authorizeRequest.getServletRequest(); - HttpServletResponse servletResponse = authorizeRequest.getServletResponse(); + + HttpServletRequest servletRequest = getHttpServletRequestOrDefault(authorizeRequest.getAttributes()); + Assert.notNull(servletRequest, "servletRequest cannot be null"); + HttpServletResponse servletResponse = getHttpServletResponseOrDefault(authorizeRequest.getAttributes()); + Assert.notNull(servletResponse, "servletResponse cannot be null"); OAuth2AuthorizationContext.Builder contextBuilder; if (authorizedClient != null) { @@ -105,6 +112,28 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizeRequest authorizeRequest) return authorizedClient; } + private static HttpServletRequest getHttpServletRequestOrDefault(Map attributes) { + HttpServletRequest servletRequest = (HttpServletRequest) attributes.get(HttpServletRequest.class.getName()); + if (servletRequest == null) { + ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (context != null) { + servletRequest = context.getRequest(); + } + } + return servletRequest; + } + + private static HttpServletResponse getHttpServletResponseOrDefault(Map attributes) { + HttpServletResponse servletResponse = (HttpServletResponse) attributes.get(HttpServletResponse.class.getName()); + if (servletResponse == null) { + ServletRequestAttributes context = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes(); + if (context != null) { + servletResponse = context.getResponse(); + } + } + return servletResponse; + } + /** * Sets the {@link OAuth2AuthorizedClientProvider} used for authorizing (or re-authorizing) an OAuth 2.0 Client. * @@ -135,7 +164,8 @@ public static class DefaultContextAttributesMapper implements Function apply(OAuth2AuthorizeRequest authorizeRequest) { Map contextAttributes = Collections.emptyMap(); - String scope = authorizeRequest.getServletRequest().getParameter(OAuth2ParameterNames.SCOPE); + HttpServletRequest servletRequest = getHttpServletRequestOrDefault(authorizeRequest.getAttributes()); + String scope = servletRequest.getParameter(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope)) { contextAttributes = new HashMap<>(); contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java deleted file mode 100644 index 7f221183855..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequest.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.web; - -import org.springframework.lang.Nullable; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.util.Assert; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -/** - * Represents a request the {@link OAuth2AuthorizedClientManager} uses to - * {@link OAuth2AuthorizedClientManager#authorize(OAuth2AuthorizeRequest) authorize} (or re-authorize) - * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. - * - * @author Joe Grandja - * @since 5.2 - * @see OAuth2AuthorizedClientManager - */ -public class OAuth2AuthorizeRequest { - private final String clientRegistrationId; - private final OAuth2AuthorizedClient authorizedClient; - private final Authentication principal; - private final HttpServletRequest servletRequest; - private final HttpServletResponse servletResponse; - - /** - * Constructs an {@code OAuth2AuthorizeRequest} using the provided parameters. - * - * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} - * @param principal the {@code Principal} (to be) associated to the authorized client - * @param servletRequest the {@code HttpServletRequest} - * @param servletResponse the {@code HttpServletResponse} - */ - public OAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal, - HttpServletRequest servletRequest, HttpServletResponse servletResponse) { - Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); - Assert.notNull(principal, "principal cannot be null"); - Assert.notNull(servletRequest, "servletRequest cannot be null"); - Assert.notNull(servletResponse, "servletResponse cannot be null"); - this.clientRegistrationId = clientRegistrationId; - this.authorizedClient = null; - this.principal = principal; - this.servletRequest = servletRequest; - this.servletResponse = servletResponse; - } - - /** - * Constructs an {@code OAuth2AuthorizeRequest} using the provided parameters. - * - * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} - * @param principal the {@code Principal} associated to the authorized client - * @param servletRequest the {@code HttpServletRequest} - * @param servletResponse the {@code HttpServletResponse} - */ - public OAuth2AuthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, - HttpServletRequest servletRequest, HttpServletResponse servletResponse) { - Assert.notNull(authorizedClient, "authorizedClient cannot be null"); - Assert.notNull(principal, "principal cannot be null"); - Assert.notNull(servletRequest, "servletRequest cannot be null"); - Assert.notNull(servletResponse, "servletResponse cannot be null"); - this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); - this.authorizedClient = authorizedClient; - this.principal = principal; - this.servletRequest = servletRequest; - this.servletResponse = servletResponse; - } - - /** - * Returns the identifier for the {@link ClientRegistration client registration}. - * - * @return the identifier for the client registration - */ - public String getClientRegistrationId() { - return this.clientRegistrationId; - } - - /** - * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. - * - * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided - */ - @Nullable - public OAuth2AuthorizedClient getAuthorizedClient() { - return this.authorizedClient; - } - - /** - * Returns the {@code Principal} (to be) associated to the authorized client. - * - * @return the {@code Principal} (to be) associated to the authorized client - */ - public Authentication getPrincipal() { - return this.principal; - } - - /** - * Returns the {@code HttpServletRequest}. - * - * @return the {@code HttpServletRequest} - */ - public HttpServletRequest getServletRequest() { - return this.servletRequest; - } - - /** - * Returns the {@code HttpServletResponse}. - * - * @return the {@code HttpServletResponse} - */ - public HttpServletResponse getServletResponse() { - return this.servletResponse; - } -} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index b88f2ecad6c..cf3e811cec2 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -33,8 +33,8 @@ import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -142,8 +142,11 @@ public Object resolveArgument(MethodParameter parameter, HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - clientRegistrationId, principal, servletRequest, servletResponse); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) + .principal(principal) + .attribute(HttpServletRequest.class.getName(), servletRequest) + .attribute(HttpServletResponse.class.getName(), servletResponse) + .build(); return this.authorizedClientManager.authorize(authorizeRequest); } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index 9719fe98b5e..2e4c28043ff 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -22,18 +22,18 @@ import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizeRequest; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.web.reactive.function.client.ClientRequest; @@ -310,7 +310,7 @@ private Mono authorizedClient(ClientRequest request) { reauthorizeRequest(request, authorizedClient).flatMap(this.authorizedClientManager::authorize)); } - private Mono authorizeRequest(ClientRequest request) { + private Mono authorizeRequest(ClientRequest request) { Mono authentication = currentAuthentication(); Mono clientRegistrationId = Mono.justOrEmpty(clientRegistrationId(request)) @@ -323,10 +323,16 @@ private Mono authorizeRequest(ClientRequest reques .defaultIfEmpty(Optional.empty()); return Mono.zip(clientRegistrationId, authentication, serverWebExchange) - .map(t3 -> new ServerOAuth2AuthorizeRequest(t3.getT1(), t3.getT2(), t3.getT3().orElse(null))); + .map(t3 -> { + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()).principal(t3.getT2()); + if (t3.getT3().isPresent()) { + builder.attribute(ServerWebExchange.class.getName(), t3.getT3().get()); + } + return builder.build(); + }); } - private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { + private Mono reauthorizeRequest(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { Mono authentication = currentAuthentication(); Mono> serverWebExchange = Mono.justOrEmpty(serverWebExchange(request)) @@ -335,7 +341,13 @@ private Mono reauthorizeRequest(ClientRequest requ .defaultIfEmpty(Optional.empty()); return Mono.zip(authentication, serverWebExchange) - .map(t2 -> new ServerOAuth2AuthorizeRequest(authorizedClient, t2.getT1(), t2.getT2().orElse(null))); + .map(t2 -> { + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient).principal(t2.getT1()); + if (t2.getT2().isPresent()) { + builder.attribute(ServerWebExchange.class.getName(), t2.getT2().get()); + } + return builder.build(); + }); } private Mono currentAuthentication() { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java index d32b560b311..a48edd1f6bc 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java @@ -26,7 +26,9 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider; @@ -36,8 +38,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizeRequest; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.web.context.request.RequestContextHolder; @@ -449,8 +449,15 @@ private Mono authorizeClient(String clientRegistrationId } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - clientRegistrationId, authentication, servletRequest, servletResponse); + + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId).principal(authentication); + if (servletRequest != null) { + builder.attribute(HttpServletRequest.class.getName(), servletRequest); + } + if (servletResponse != null) { + builder.attribute(HttpServletResponse.class.getName(), servletResponse); + } + OAuth2AuthorizeRequest authorizeRequest = builder.build(); // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic()) // since it performs a blocking I/O operation using RestTemplate internally @@ -468,8 +475,15 @@ private Mono authorizedClient(OAuth2AuthorizedClient aut } HttpServletRequest servletRequest = getRequest(attrs); HttpServletResponse servletResponse = getResponse(attrs); - OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( - authorizedClient, authentication, servletRequest, servletResponse); + + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient).principal(authentication); + if (servletRequest != null) { + builder.attribute(HttpServletRequest.class.getName(), servletRequest); + } + if (servletResponse != null) { + builder.attribute(HttpServletResponse.class.getName(), servletResponse); + } + OAuth2AuthorizeRequest reauthorizeRequest = builder.build(); // NOTE: 'authorizedClientManager.authorize()' needs to be executed on a dedicated thread via subscribeOn(Schedulers.elastic()) // since it performs a blocking I/O operation using RestTemplate internally diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 0784ec40000..a835e48922b 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -23,15 +23,15 @@ import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder; +import org.springframework.security.oauth2.client.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizeRequest; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -125,7 +125,7 @@ public Mono resolveArgument(MethodParameter parameter, BindingContext bi }); } - private Mono authorizeRequest(String registrationId, ServerWebExchange exchange) { + private Mono authorizeRequest(String registrationId, ServerWebExchange exchange) { Mono defaultedAuthentication = currentAuthentication(); Mono defaultedRegistrationId = Mono.justOrEmpty(registrationId) @@ -136,7 +136,13 @@ private Mono authorizeRequest(String registrationI .switchIfEmpty(currentServerWebExchange()); return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) - .map(t3 -> new ServerOAuth2AuthorizeRequest(t3.getT1(), t3.getT2(), t3.getT3())); + .map(t3 -> { + OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()).principal(t3.getT2()); + if (t3.getT3() != null) { + builder.attribute(ServerWebExchange.class.getName(), t3.getT3()); + } + return builder.build(); + }); } private Mono currentAuthentication() { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java index 54d61c7a11f..82acbdfd4b4 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManager.java @@ -17,8 +17,10 @@ import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; +import org.springframework.security.oauth2.client.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.util.Assert; @@ -43,7 +45,7 @@ public final class DefaultServerOAuth2AuthorizedClientManager implements ServerO private final ReactiveClientRegistrationRepository clientRegistrationRepository; private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = context -> Mono.empty(); - private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); + private Function> contextAttributesMapper = new DefaultContextAttributesMapper(); /** * Constructs a {@code DefaultServerOAuth2AuthorizedClientManager} using the provided parameters. @@ -60,12 +62,14 @@ public DefaultServerOAuth2AuthorizedClientManager(ReactiveClientRegistrationRepo } @Override - public Mono authorize(ServerOAuth2AuthorizeRequest authorizeRequest) { + public Mono authorize(OAuth2AuthorizeRequest authorizeRequest) { Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); String clientRegistrationId = authorizeRequest.getClientRegistrationId(); Authentication principal = authorizeRequest.getPrincipal(); - ServerWebExchange serverWebExchange = authorizeRequest.getServerWebExchange(); + + ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); + Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) .switchIfEmpty(Mono.defer(() -> @@ -112,13 +116,13 @@ public void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider a } /** - * Sets the {@code Function} used for mapping attribute(s) from the {@link ServerOAuth2AuthorizeRequest} to a {@code Map} of attributes + * Sets the {@code Function} used for mapping attribute(s) from the {@link OAuth2AuthorizeRequest} to a {@code Map} of attributes * to be associated to the {@link OAuth2AuthorizationContext#getAttributes() authorization context}. * * @param contextAttributesMapper the {@code Function} used for supplying the {@code Map} of attributes * to the {@link OAuth2AuthorizationContext#getAttributes() authorization context} */ - public void setContextAttributesMapper(Function> contextAttributesMapper) { + public void setContextAttributesMapper(Function> contextAttributesMapper) { Assert.notNull(contextAttributesMapper, "contextAttributesMapper cannot be null"); this.contextAttributesMapper = contextAttributesMapper; } @@ -126,12 +130,13 @@ public void setContextAttributesMapper(Function> { + public static class DefaultContextAttributesMapper implements Function> { @Override - public Map apply(ServerOAuth2AuthorizeRequest authorizeRequest) { + public Map apply(OAuth2AuthorizeRequest authorizeRequest) { Map contextAttributes = Collections.emptyMap(); - String scope = authorizeRequest.getServerWebExchange().getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); + ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); + String scope = serverWebExchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); if (StringUtils.hasText(scope)) { contextAttributes = new HashMap<>(); contextAttributes.put(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequest.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequest.java deleted file mode 100644 index 6aee1feb4f3..00000000000 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequest.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.web.server; - -import org.springframework.lang.Nullable; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.util.Assert; -import org.springframework.web.server.ServerWebExchange; - -/** - * Represents a request the {@link ServerOAuth2AuthorizedClientManager} uses to - * {@link ServerOAuth2AuthorizedClientManager#authorize(ServerOAuth2AuthorizeRequest) authorize} (or re-authorize) - * the {@link ClientRegistration client} identified by the provided {@link #getClientRegistrationId() clientRegistrationId}. - * - * @author Joe Grandja - * @since 5.2 - * @see ServerOAuth2AuthorizedClientManager - */ -public class ServerOAuth2AuthorizeRequest { - private final String clientRegistrationId; - private final OAuth2AuthorizedClient authorizedClient; - private final Authentication principal; - private final ServerWebExchange serverWebExchange; - - /** - * Constructs a {@code ServerOAuth2AuthorizeRequest} using the provided parameters. - * - * @param clientRegistrationId the identifier for the {@link ClientRegistration client registration} - * @param principal the {@code Principal} (to be) associated to the authorized client - * @param serverWebExchange the {@code ServerWebExchange} - */ - public ServerOAuth2AuthorizeRequest(String clientRegistrationId, Authentication principal, - ServerWebExchange serverWebExchange) { - Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); - Assert.notNull(principal, "principal cannot be null"); - Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); - this.clientRegistrationId = clientRegistrationId; - this.authorizedClient = null; - this.principal = principal; - this.serverWebExchange = serverWebExchange; - } - - /** - * Constructs a {@code ServerOAuth2AuthorizeRequest} using the provided parameters. - * - * @param authorizedClient the {@link OAuth2AuthorizedClient authorized client} - * @param principal the {@code Principal} (to be) associated to the authorized client - * @param serverWebExchange the {@code ServerWebExchange} - */ - public ServerOAuth2AuthorizeRequest(OAuth2AuthorizedClient authorizedClient, Authentication principal, - ServerWebExchange serverWebExchange) { - Assert.notNull(authorizedClient, "authorizedClient cannot be null"); - Assert.notNull(principal, "principal cannot be null"); - Assert.notNull(serverWebExchange, "serverWebExchange cannot be null"); - this.clientRegistrationId = authorizedClient.getClientRegistration().getRegistrationId(); - this.authorizedClient = authorizedClient; - this.principal = principal; - this.serverWebExchange = serverWebExchange; - } - - /** - * Returns the identifier for the {@link ClientRegistration client registration}. - * - * @return the identifier for the client registration - */ - public String getClientRegistrationId() { - return this.clientRegistrationId; - } - - /** - * Returns the {@link OAuth2AuthorizedClient authorized client} or {@code null} if it was not provided. - * - * @return the {@link OAuth2AuthorizedClient} or {@code null} if it was not provided - */ - @Nullable - public OAuth2AuthorizedClient getAuthorizedClient() { - return this.authorizedClient; - } - - /** - * Returns the {@code Principal} (to be) associated to the authorized client. - * - * @return the {@code Principal} (to be) associated to the authorized client - */ - public Authentication getPrincipal() { - return this.principal; - } - - /** - * Returns the {@link ServerWebExchange}. - * - * @return the {@link ServerWebExchange} - */ - public ServerWebExchange getServerWebExchange() { - return this.serverWebExchange; - } -} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java similarity index 51% rename from oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequestTests.java rename to oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java index d4ab401972c..c984ceefb5b 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/ServerOAuth2AuthorizeRequestTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizeRequestTests.java @@ -13,81 +13,76 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.springframework.security.oauth2.client.web.server; +package org.springframework.security.oauth2.client; import org.junit.Test; -import org.springframework.mock.http.server.reactive.MockServerHttpRequest; -import org.springframework.mock.web.server.MockServerWebExchange; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.*; /** - * Tests for {@link ServerOAuth2AuthorizeRequest}. + * Tests for {@link OAuth2AuthorizeRequest}. * * @author Joe Grandja */ -public class ServerOAuth2AuthorizeRequestTests { +public class OAuth2AuthorizeRequestTests { private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); private Authentication principal = new TestingAuthenticationToken("principal", "password"); private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( this.clientRegistration, this.principal.getName(), TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); - private MockServerWebExchange serverWebExchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/")).build(); @Test - public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest((String) null, this.principal, this.serverWebExchange)) + public void withClientRegistrationIdWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("clientRegistrationId cannot be empty"); } @Test - public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest((OAuth2AuthorizedClient) null, this.principal, this.serverWebExchange)) + public void withAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizeRequest.withAuthorizedClient(null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("authorizedClient cannot be null"); } @Test - public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.serverWebExchange)) + public void withClientRegistrationIdWhenPrincipalIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()).build()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("principal cannot be null"); } @Test - public void constructorWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new ServerOAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("serverWebExchange cannot be null"); - } - - @Test - public void constructorClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { - ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + public void withClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute("name1", "value1") + .attribute("name2", "value2") + .build(); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); + assertThat(authorizeRequest.getAuthorizedClient()).isNull(); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizeRequest.getServerWebExchange()).isEqualTo(this.serverWebExchange); + assertThat(authorizeRequest.getAttributes()).contains(entry("name1", "value1"), entry("name2", "value2")); } @Test - public void constructorAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { - ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.serverWebExchange); + public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute("name1", "value1") + .attribute("name2", "value2") + .build(); assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizeRequest.getServerWebExchange()).isEqualTo(this.serverWebExchange); + assertThat(authorizeRequest.getAttributes()).contains(entry("name1", "value1"), entry("name2", "value2")); } } diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManagerServiceTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManagerServiceTests.java new file mode 100644 index 00000000000..f59e16e120d --- /dev/null +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientManagerServiceTests.java @@ -0,0 +1,280 @@ +/* + * Copyright 2002-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.client; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; +import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; +import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; + +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link OAuth2AuthorizedClientManagerService}. + * + * @author Joe Grandja + */ +public class OAuth2AuthorizedClientManagerServiceTests { + private ClientRegistrationRepository clientRegistrationRepository; + private OAuth2AuthorizedClientService authorizedClientService; + private OAuth2AuthorizedClientProvider authorizedClientProvider; + private Function contextAttributesMapper; + private OAuth2AuthorizedClientManagerService authorizedClientManager; + private ClientRegistration clientRegistration; + private Authentication principal; + private OAuth2AuthorizedClient authorizedClient; + private ArgumentCaptor authorizationContextCaptor; + + @SuppressWarnings("unchecked") + @Before + public void setup() { + this.clientRegistrationRepository = mock(ClientRegistrationRepository.class); + this.authorizedClientService = mock(OAuth2AuthorizedClientService.class); + this.authorizedClientProvider = mock(OAuth2AuthorizedClientProvider.class); + this.contextAttributesMapper = mock(Function.class); + this.authorizedClientManager = new OAuth2AuthorizedClientManagerService( + this.clientRegistrationRepository, this.authorizedClientService); + this.authorizedClientManager.setAuthorizedClientProvider(this.authorizedClientProvider); + this.authorizedClientManager.setContextAttributesMapper(this.contextAttributesMapper); + this.clientRegistration = TestClientRegistrations.clientRegistration().build(); + this.principal = new TestingAuthenticationToken("principal", "password"); + this.authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); + this.authorizationContextCaptor = ArgumentCaptor.forClass(OAuth2AuthorizationContext.class); + } + + @Test + public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientManagerService(null, this.authorizedClientService)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("clientRegistrationRepository cannot be null"); + } + + @Test + public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> new OAuth2AuthorizedClientManagerService(this.clientRegistrationRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientService cannot be null"); + } + + @Test + public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizedClientProvider cannot be null"); + } + + @Test + public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("contextAttributesMapper cannot be null"); + } + + @Test + public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.authorizedClientManager.authorize(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizeRequest cannot be null"); + } + + @Test + public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") + .principal(this.principal) + .build(); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isNull(); + verify(this.authorizedClientService, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isNull(); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(this.authorizedClient), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { + when(this.clientRegistrationRepository.findByRegistrationId( + eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); + when(this.authorizedClientService.loadAuthorizedClient( + eq(this.clientRegistration.getRegistrationId()), eq(this.principal.getName()))).thenReturn(this.authorizedClient); + + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(this.authorizedClient); + verify(this.authorizedClientService, never()).saveAuthorizedClient( + any(OAuth2AuthorizedClient.class), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenSupportedProviderThenReauthorized() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + } + + @SuppressWarnings("unchecked") + @Test + public void reauthorizeWhenRequestAttributeScopeThenMappedToContext() { + OAuth2AuthorizedClient reauthorizedClient = new OAuth2AuthorizedClient( + this.clientRegistration, this.principal.getName(), + TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); + + when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); + + // Override the mock with the default + this.authorizedClientManager.setContextAttributesMapper( + new OAuth2AuthorizedClientManagerService.DefaultContextAttributesMapper()); + + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(OAuth2ParameterNames.SCOPE, "read write") + .build(); + OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); + + verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); + + OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); + assertThat(authorizationContext.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); + assertThat(authorizationContext.getPrincipal()).isEqualTo(this.principal); + assertThat(authorizationContext.getAttributes()).containsKey(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + String[] requestScopeAttribute = authorizationContext.getAttribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME); + assertThat(requestScopeAttribute).contains("read", "write"); + + assertThat(authorizedClient).isSameAs(reauthorizedClient); + verify(this.authorizedClientService).saveAuthorizedClient( + eq(reauthorizedClient), eq(this.principal)); + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java index 3942df494c4..d214bddee77 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/DefaultOAuth2AuthorizedClientManagerTests.java @@ -23,6 +23,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -32,6 +33,8 @@ import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.util.function.Function; import static org.assertj.core.api.Assertions.assertThat; @@ -111,10 +114,34 @@ public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { .hasMessage("authorizeRequest cannot be null"); } + @Test + public void authorizeWhenHttpServletRequestIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletRequest cannot be null"); + } + + @Test + public void authorizeWhenHttpServletResponseIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .build(); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("servletResponse cannot be null"); + } + @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - "invalid-registration-id", this.principal, this.request, this.response); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); @@ -126,8 +153,11 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() when(this.clientRegistrationRepository.findByRegistrationId( eq(this.clientRegistration.getRegistrationId()))).thenReturn(this.clientRegistration); - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -151,8 +181,11 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(this.authorizedClient); - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -182,8 +215,11 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.request, this.response); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -202,8 +238,11 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { - OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -228,8 +267,11 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(reauthorizedClient); - OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -260,8 +302,11 @@ public void reauthorizeWhenRequestScopeParameterThenMappedToContext() { this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write"); - OAuth2AuthorizeRequest reauthorizeRequest = new OAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.request, this.response); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(HttpServletRequest.class.getName(), this.request) + .attribute(HttpServletResponse.class.getName(), this.response) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java deleted file mode 100644 index 6d7e687fcdd..00000000000 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizeRequestTests.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright 2002-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.client.web; - -import org.junit.Test; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.authentication.TestingAuthenticationToken; -import org.springframework.security.core.Authentication; -import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.registration.ClientRegistration; -import org.springframework.security.oauth2.client.registration.TestClientRegistrations; -import org.springframework.security.oauth2.core.TestOAuth2AccessTokens; -import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * Tests for {@link OAuth2AuthorizeRequest}. - * - * @author Joe Grandja - */ -public class OAuth2AuthorizeRequestTests { - private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); - private Authentication principal = new TestingAuthenticationToken("principal", "password"); - private OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - this.clientRegistration, this.principal.getName(), - TestOAuth2AccessTokens.scopes("read", "write"), TestOAuth2RefreshTokens.refreshToken()); - private MockHttpServletRequest servletRequest = new MockHttpServletRequest(); - private MockHttpServletResponse servletResponse = new MockHttpServletResponse(); - - @Test - public void constructorWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizeRequest((String) null, this.principal, this.servletRequest, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("clientRegistrationId cannot be empty"); - } - - @Test - public void constructorWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizeRequest((OAuth2AuthorizedClient) null, this.principal, this.servletRequest, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizedClient cannot be null"); - } - - @Test - public void constructorWhenPrincipalIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), null, this.servletRequest, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("principal cannot be null"); - } - - @Test - public void constructorWhenServletRequestIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, null, this.servletResponse)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("servletRequest cannot be null"); - } - - @Test - public void constructorWhenServletResponseIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizeRequest(this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("servletResponse cannot be null"); - } - - @Test - public void constructorClientRegistrationIdWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.servletRequest, this.servletResponse); - - assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.clientRegistration.getRegistrationId()); - assertThat(authorizeRequest.getAuthorizedClient()).isNull(); - assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); - assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); - } - - @Test - public void constructorAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { - OAuth2AuthorizeRequest authorizeRequest = new OAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.servletRequest, this.servletResponse); - - assertThat(authorizeRequest.getClientRegistrationId()).isEqualTo(this.authorizedClient.getClientRegistration().getRegistrationId()); - assertThat(authorizeRequest.getAuthorizedClient()).isEqualTo(this.authorizedClient); - assertThat(authorizeRequest.getPrincipal()).isEqualTo(this.principal); - assertThat(authorizeRequest.getServletRequest()).isEqualTo(this.servletRequest); - assertThat(authorizeRequest.getServletResponse()).isEqualTo(this.servletResponse); - } -} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java index 9729120cd58..a81cc3f74ea 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/server/DefaultServerOAuth2AuthorizedClientManagerTests.java @@ -23,6 +23,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.client.OAuth2AuthorizationContext; +import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider; import org.springframework.security.oauth2.client.registration.ClientRegistration; @@ -113,6 +114,16 @@ public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException( .hasMessage("contextAttributesMapper cannot be null"); } + @Test + public void authorizeWhenServerWebExchangeIsNullThenThrowIllegalArgumentException() { + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .build(); + assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("serverWebExchange cannot be null"); + } + @Test public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.authorizedClientManager.authorize(null).block()) @@ -122,8 +133,10 @@ public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { @Test public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { - ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( - "invalid-registration-id", this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId("invalid-registration-id") + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); assertThatThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block()) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); @@ -135,8 +148,10 @@ public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() when(this.clientRegistrationRepository.findByRegistrationId( eq(this.clientRegistration.getRegistrationId()))).thenReturn(Mono.just(this.clientRegistration)); - ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -161,8 +176,10 @@ public void authorizeWhenNotAuthorizedAndSupportedProviderThenAuthorized() { when(this.authorizedClientProvider.authorize( any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(this.authorizedClient)); - ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -192,8 +209,10 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - ServerOAuth2AuthorizeRequest authorizeRequest = new ServerOAuth2AuthorizeRequest( - this.clientRegistration.getRegistrationId(), this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(this.clientRegistration.getRegistrationId()) + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -212,8 +231,10 @@ public void authorizeWhenAuthorizedAndSupportedProviderThenReauthorized() { @SuppressWarnings("unchecked") @Test public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { - ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -238,8 +259,10 @@ public void reauthorizeWhenSupportedProviderThenReauthorized() { when(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).thenReturn(Mono.just(reauthorizedClient)); - ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); @@ -274,8 +297,10 @@ public void reauthorizeWhenRequestScopeParameterThenMappedToContext() { .queryParam(OAuth2ParameterNames.SCOPE, "read write")) .build(); - ServerOAuth2AuthorizeRequest reauthorizeRequest = new ServerOAuth2AuthorizeRequest( - this.authorizedClient, this.principal, this.serverWebExchange); + OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) + .principal(this.principal) + .attribute(ServerWebExchange.class.getName(), this.serverWebExchange) + .build(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest).block(); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); diff --git a/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java index 7fdba365b7f..48fa389fb2f 100644 --- a/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient-webflux/src/main/java/sample/config/WebClientConfig.java @@ -23,7 +23,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.ServerOAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.web.reactive.function.client.WebClient; diff --git a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java index 636bc53fd6f..6d767c0c646 100644 --- a/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java +++ b/samples/boot/oauth2webclient/src/main/java/sample/config/WebClientConfig.java @@ -22,7 +22,7 @@ import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder; import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository; import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager; -import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientManager; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager; import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction; import org.springframework.web.reactive.function.client.WebClient;