Skip to content

Make OAuth's RestTemplate and WebClient customizable #8624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,24 +67,32 @@ static class OAuth2ClientWebMvcSecurityConfiguration implements WebMvcConfigurer
private ClientRegistrationRepository clientRegistrationRepository;
private OAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient;
private OAuth2AuthorizedClientProviderBuilder.RefreshTokenGrantBuilderCustomizer refreshTokenGrantBuilderCustomizer;
private OAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilderCustomizer clientCredentialsGrantBuilderCustomizer;
private OAuth2AuthorizedClientProviderBuilder.PasswordGrantBuilderCustomizer passwordGrantBuilderCustomizer;

@Override
public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
if (this.clientRegistrationRepository != null && this.authorizedClientRepository != null) {
OAuth2AuthorizedClientProviderBuilder authorizedClientProviderBuilder =
OAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.password();
.refreshToken(refreshTokenGrantBuilderCustomizer)
.password(passwordGrantBuilderCustomizer);
if (this.accessTokenResponseClient != null) {
authorizedClientProviderBuilder.clientCredentials(clientCredentialsGrantBuilderCustomizer);
authorizedClientProviderBuilder.clientCredentials(configurer ->
configurer.accessTokenResponseClient(this.accessTokenResponseClient));
configurer.accessTokenResponseClient(this.accessTokenResponseClient));
} else {
authorizedClientProviderBuilder.clientCredentials();
authorizedClientProviderBuilder.clientCredentials(clientCredentialsGrantBuilderCustomizer);
}
OAuth2AuthorizedClientProvider authorizedClientProvider = authorizedClientProviderBuilder.build();
DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository);
this.clientRegistrationRepository,
this.authorizedClientRepository,
this.refreshTokenGrantBuilderCustomizer,
this.clientCredentialsGrantBuilderCustomizer,
this.passwordGrantBuilderCustomizer);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
argumentResolvers.add(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager));
}
Expand All @@ -104,6 +112,30 @@ public void setAuthorizedClientRepository(List<OAuth2AuthorizedClientRepository>
}
}

@Autowired(required = false)
public void setRefreshTokenGrantBuilderCustomizer(
List<OAuth2AuthorizedClientProviderBuilder.RefreshTokenGrantBuilderCustomizer> beans) {
if (beans.size() == 1) {
this.refreshTokenGrantBuilderCustomizer = beans.get(0);
}
}

@Autowired(required = false)
public void setClientCredentialsGrantBuilderCustomizer(
List<OAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilderCustomizer> beans) {
if (beans.size() == 1) {
this.clientCredentialsGrantBuilderCustomizer = beans.get(0);
}
}

@Autowired(required = false)
public void setPasswordGrantBuilderCustomizer(
List<OAuth2AuthorizedClientProviderBuilder.PasswordGrantBuilderCustomizer> beans) {
if (beans.size() == 1) {
this.passwordGrantBuilderCustomizer = beans.get(0);
}
}

@Autowired
public void setAccessTokenResponseClient(
Optional<OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest>> accessTokenResponseClient) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RestTemplateFactory;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
Expand Down Expand Up @@ -156,6 +157,7 @@ public class AuthorizationCodeGrantConfigurer {
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private OAuth2RestTemplateFactory restTemplateFactory;

private AuthorizationCodeGrantConfigurer() {
}
Expand Down Expand Up @@ -200,6 +202,12 @@ public AuthorizationCodeGrantConfigurer accessTokenResponseClient(
return this;
}

public AuthorizationCodeGrantConfigurer restTemplateFactory(OAuth2RestTemplateFactory restTemplateFactory) {
Assert.notNull(restTemplateFactory, "restTemplateFactory cannot be null");
this.restTemplateFactory = restTemplateFactory;
return this;
}

/**
* Returns the {@link OAuth2ClientConfigurer} for further configuration.
*
Expand All @@ -211,7 +219,7 @@ public OAuth2ClientConfigurer<B> and() {

private void init(B builder) {
OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
new OAuth2AuthorizationCodeAuthenticationProvider(getAccessTokenResponseClient());
new OAuth2AuthorizationCodeAuthenticationProvider(getOrCreateAccessTokenResponseClient());
builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider));
}

Expand Down Expand Up @@ -264,11 +272,14 @@ private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B
return authorizationCodeGrantFilter;
}

private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getAccessTokenResponseClient() {
if (this.accessTokenResponseClient != null) {
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getOrCreateAccessTokenResponseClient() {
if (this.accessTokenResponseClient == null) {
return restTemplateFactory == null
? new DefaultAuthorizationCodeTokenResponseClient()
: new DefaultAuthorizationCodeTokenResponseClient(restTemplateFactory);
} else {
return this.accessTokenResponseClient;
}
return new DefaultAuthorizationCodeTokenResponseClient();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,19 @@
import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
import org.springframework.security.oauth2.client.endpoint.OAuth2RestTemplateFactory;
import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeAuthenticationProvider;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.userinfo.CustomUserTypesOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserServiceRestTemplateFactory;
import org.springframework.security.oauth2.client.userinfo.DelegatingOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserServiceRestTemplateFactory;
import org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
Expand Down Expand Up @@ -299,6 +302,7 @@ public OAuth2LoginConfigurer<B> tokenEndpoint(Customizer<TokenEndpointConfig> to
*/
public class TokenEndpointConfig {
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
private OAuth2RestTemplateFactory restTemplateFactory;

private TokenEndpointConfig() {
}
Expand All @@ -317,6 +321,12 @@ public TokenEndpointConfig accessTokenResponseClient(
return this;
}

public TokenEndpointConfig restTemplateFactory(OAuth2RestTemplateFactory restTemplateFactory) {
Assert.notNull(restTemplateFactory, "restTemplateFactory cannot be null");
this.restTemplateFactory = restTemplateFactory;
return this;
}

/**
* Returns the {@link OAuth2LoginConfigurer} for further configuration.
*
Expand All @@ -325,6 +335,16 @@ public TokenEndpointConfig accessTokenResponseClient(
public OAuth2LoginConfigurer<B> and() {
return OAuth2LoginConfigurer.this;
}

private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getOrCreateAccessTokenResponseClient() {
if (accessTokenResponseClient == null) {
return restTemplateFactory == null
? new DefaultAuthorizationCodeTokenResponseClient()
: new DefaultAuthorizationCodeTokenResponseClient(restTemplateFactory);
} else {
return accessTokenResponseClient;
}
}
}

/**
Expand Down Expand Up @@ -407,8 +427,10 @@ public class UserInfoEndpointConfig {
private OAuth2UserService<OAuth2UserRequest, OAuth2User> userService;
private OAuth2UserService<OidcUserRequest, OidcUser> oidcUserService;
private Map<String, Class<? extends OAuth2User>> customUserTypes = new HashMap<>();
private OAuth2UserServiceRestTemplateFactory restTemplateFactory;

private UserInfoEndpointConfig() {
this.restTemplateFactory = DefaultOAuth2UserServiceRestTemplateFactory.DEFAULT;
}

/**
Expand Down Expand Up @@ -462,6 +484,12 @@ public UserInfoEndpointConfig userAuthoritiesMapper(GrantedAuthoritiesMapper use
return this;
}

public UserInfoEndpointConfig restTemplateFactory(OAuth2UserServiceRestTemplateFactory restTemplateFactory) {
Assert.notNull(restTemplateFactory, "restTemplateFactory cannot be null");
this.restTemplateFactory = restTemplateFactory;
return this;
}

/**
* Returns the {@link OAuth2LoginConfigurer} for further configuration.
*
Expand Down Expand Up @@ -501,10 +529,7 @@ public void init(B http) throws Exception {
}

OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient =
this.tokenEndpointConfig.accessTokenResponseClient;
if (accessTokenResponseClient == null) {
accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
}
this.tokenEndpointConfig.getOrCreateAccessTokenResponseClient();

OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService = getOAuth2UserService();
OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider =
Expand Down Expand Up @@ -619,7 +644,7 @@ private OAuth2UserService<OidcUserRequest, OidcUser> getOidcUserService() {
ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OidcUserRequest.class, OidcUser.class);
OAuth2UserService<OidcUserRequest, OidcUser> bean = getBeanOrNull(type);
if (bean == null) {
return new OidcUserService();
return new OidcUserService(userInfoEndpointConfig.restTemplateFactory);
}

return bean;
Expand All @@ -634,11 +659,13 @@ private OAuth2UserService<OAuth2UserRequest, OAuth2User> getOAuth2UserService()
if (bean == null) {
if (!this.userInfoEndpointConfig.customUserTypes.isEmpty()) {
List<OAuth2UserService<OAuth2UserRequest, OAuth2User>> userServices = new ArrayList<>();
userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes));
userServices.add(new DefaultOAuth2UserService());
userServices.add(new CustomUserTypesOAuth2UserService(
this.userInfoEndpointConfig.customUserTypes,
this.userInfoEndpointConfig.restTemplateFactory));
userServices.add(new DefaultOAuth2UserService(this.userInfoEndpointConfig.restTemplateFactory));
return new DelegatingOAuth2UserService<>(userServices);
} else {
return new DefaultOAuth2UserService();
return new DefaultOAuth2UserService(this.userInfoEndpointConfig.restTemplateFactory);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.springframework.security.oauth2.server.resource.authentication.OpaqueTokenAuthenticationProvider;
import org.springframework.security.oauth2.server.resource.introspection.NimbusOpaqueTokenIntrospector;
import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospector;
import org.springframework.security.oauth2.server.resource.introspection.OpaqueTokenIntrospectorRestTemplateFactory;
import org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationEntryPoint;
import org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationFilter;
import org.springframework.security.oauth2.server.resource.web.BearerTokenResolver;
Expand All @@ -49,6 +50,7 @@
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.client.RestOperations;

import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;

Expand Down Expand Up @@ -283,6 +285,8 @@ public class JwtConfigurer {

private Converter<Jwt, ? extends AbstractAuthenticationToken> jwtAuthenticationConverter;

private RestOperations restOperations;

JwtConfigurer(ApplicationContext context) {
this.context = context;
}
Expand All @@ -299,7 +303,15 @@ public JwtConfigurer decoder(JwtDecoder decoder) {
}

public JwtConfigurer jwkSetUri(String uri) {
this.decoder = withJwkSetUri(uri).build();
final NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = withJwkSetUri(uri);
this.decoder = restOperations == null
? builder.build()
: builder.restOperations(restOperations).build();
return this;
}

public JwtConfigurer restOperations(RestOperations restOperations) {
this.restOperations = restOperations;
return this;
}

Expand Down Expand Up @@ -366,6 +378,7 @@ public class OpaqueTokenConfigurer {
private String clientId;
private String clientSecret;
private Supplier<OpaqueTokenIntrospector> introspector;
private OpaqueTokenIntrospectorRestTemplateFactory restTemplateFactory;

OpaqueTokenConfigurer(ApplicationContext context) {
this.context = context;
Expand All @@ -380,8 +393,7 @@ public OpaqueTokenConfigurer authenticationManager(AuthenticationManager authent
public OpaqueTokenConfigurer introspectionUri(String introspectionUri) {
Assert.notNull(introspectionUri, "introspectionUri cannot be null");
this.introspectionUri = introspectionUri;
this.introspector = () ->
new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, this.clientSecret);
this.introspector = createIntrospectorSupplier();
return this;
}

Expand All @@ -390,11 +402,23 @@ public OpaqueTokenConfigurer introspectionClientCredentials(String clientId, Str
Assert.notNull(clientSecret, "clientSecret cannot be null");
this.clientId = clientId;
this.clientSecret = clientSecret;
this.introspector = () ->
new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, this.clientSecret);
this.introspector = createIntrospectorSupplier();
return this;
}

public OpaqueTokenConfigurer restTemplateFactory(OpaqueTokenIntrospectorRestTemplateFactory restTemplateFactory) {
Assert.notNull(restTemplateFactory, "restTemplateFactory cannot be null");
this.restTemplateFactory = restTemplateFactory;
this.introspector = createIntrospectorSupplier();
return this;
}

private Supplier<OpaqueTokenIntrospector> createIntrospectorSupplier() {
return () -> restTemplateFactory == null
? new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, this.clientSecret)
: new NimbusOpaqueTokenIntrospector(this.introspectionUri, this.clientId, this.clientSecret, restTemplateFactory);
}

public OpaqueTokenConfigurer introspector(OpaqueTokenIntrospector introspector) {
Assert.notNull(introspector, "introspector cannot be null");
this.introspector = () -> introspector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.util.ClassUtils;
import org.springframework.web.reactive.config.WebFluxConfigurer;
Expand Down Expand Up @@ -63,15 +63,19 @@ static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigur

private ReactiveOAuth2AuthorizedClientService authorizedClientService;

private ReactiveOAuth2AuthorizedClientProviderBuilder.RefreshTokenGrantBuilderCustomizer refreshTokenGrantBuilderCustomizer;
private ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilderCustomizer clientCredentialsGrantBuilderCustomizer;
private ReactiveOAuth2AuthorizedClientProviderBuilder.PasswordGrantBuilderCustomizer passwordGrantBuilderCustomizer;

@Override
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider =
ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.password()
.refreshToken(refreshTokenGrantBuilderCustomizer)
.clientCredentials(clientCredentialsGrantBuilderCustomizer)
.password(passwordGrantBuilderCustomizer)
.build();
DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, getAuthorizedClientRepository());
Expand All @@ -98,6 +102,27 @@ public void setAuthorizedClientService(List<ReactiveOAuth2AuthorizedClientServic
}
}

@Autowired(required = false)
public void setRefreshTokenGrantBuilderCustomizer(List<ReactiveOAuth2AuthorizedClientProviderBuilder.RefreshTokenGrantBuilderCustomizer> beans) {
if (beans.size() == 1) {
this.refreshTokenGrantBuilderCustomizer = beans.get(0);
}
}

@Autowired(required = false)
public void setClientCredentialsGrantBuilderCustomizer(List<ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilderCustomizer> beans) {
if (beans.size() == 1) {
this.clientCredentialsGrantBuilderCustomizer = beans.get(0);
}
}

@Autowired(required = false)
public void setPasswordGrantBuilderCustomizer(List<ReactiveOAuth2AuthorizedClientProviderBuilder.PasswordGrantBuilderCustomizer> beans) {
if (beans.size() == 1) {
this.passwordGrantBuilderCustomizer = beans.get(0);
}
}

private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
if (this.authorizedClientRepository != null) {
return this.authorizedClientRepository;
Expand Down
Loading