Skip to content

Commit 2cd5482

Browse files
committed
Allow configuring a custom OAuth2AuthorizationRequestResolver
Fixes gh-5521
1 parent becff23 commit 2cd5482

File tree

4 files changed

+130
-15
lines changed

4 files changed

+130
-15
lines changed

config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
2929
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
3030
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
31+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
3132
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
3233
import org.springframework.security.web.savedrequest.RequestCache;
3334
import org.springframework.util.Assert;
@@ -147,6 +148,7 @@ public AuthorizationEndpointConfig authorizationEndpoint() {
147148
*/
148149
public class AuthorizationEndpointConfig {
149150
private String authorizationRequestBaseUri;
151+
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
150152
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
151153

152154
private AuthorizationEndpointConfig() {
@@ -164,6 +166,18 @@ public AuthorizationEndpointConfig baseUri(String authorizationRequestBaseUri) {
164166
return this;
165167
}
166168

169+
/**
170+
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
171+
*
172+
* @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
173+
* @return the {@link AuthorizationEndpointConfig} for further configuration
174+
*/
175+
public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
176+
Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
177+
this.authorizationRequestResolver = authorizationRequestResolver;
178+
return this;
179+
}
180+
167181
/**
168182
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
169183
*
@@ -267,14 +281,20 @@ private void init(B builder, AuthorizationCodeGrantConfigurer authorizationCodeG
267281
}
268282

269283
private void configure(B builder, AuthorizationCodeGrantConfigurer authorizationCodeGrantConfigurer) throws Exception {
270-
String authorizationRequestBaseUri = authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestBaseUri;
271-
if (authorizationRequestBaseUri == null) {
272-
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
284+
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter;
285+
286+
if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestResolver != null) {
287+
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
288+
authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestResolver);
289+
} else {
290+
String authorizationRequestBaseUri = authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestBaseUri;
291+
if (authorizationRequestBaseUri == null) {
292+
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
293+
}
294+
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
295+
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), authorizationRequestBaseUri);
273296
}
274297

275-
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
276-
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(builder), authorizationRequestBaseUri);
277-
278298
if (authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository != null) {
279299
authorizationRequestFilter.setAuthorizationRequestRepository(
280300
authorizationCodeGrantConfigurer.authorizationEndpointConfig.authorizationRequestRepository);

config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
4545
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
4646
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
47+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
4748
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
4849
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
4950
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -178,6 +179,7 @@ public AuthorizationEndpointConfig authorizationEndpoint() {
178179
*/
179180
public class AuthorizationEndpointConfig {
180181
private String authorizationRequestBaseUri;
182+
private OAuth2AuthorizationRequestResolver authorizationRequestResolver;
181183
private AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
182184

183185
private AuthorizationEndpointConfig() {
@@ -195,6 +197,19 @@ public AuthorizationEndpointConfig baseUri(String authorizationRequestBaseUri) {
195197
return this;
196198
}
197199

200+
/**
201+
* Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
202+
*
203+
* @since 5.1
204+
* @param authorizationRequestResolver the resolver used for resolving {@link OAuth2AuthorizationRequest}'s
205+
* @return the {@link AuthorizationEndpointConfig} for further configuration
206+
*/
207+
public AuthorizationEndpointConfig authorizationRequestResolver(OAuth2AuthorizationRequestResolver authorizationRequestResolver) {
208+
Assert.notNull(authorizationRequestResolver, "authorizationRequestResolver cannot be null");
209+
this.authorizationRequestResolver = authorizationRequestResolver;
210+
return this;
211+
}
212+
198213
/**
199214
* Sets the repository used for storing {@link OAuth2AuthorizationRequest}'s.
200215
*
@@ -444,13 +459,19 @@ public void init(B http) throws Exception {
444459

445460
@Override
446461
public void configure(B http) throws Exception {
447-
String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri;
448-
if (authorizationRequestBaseUri == null) {
449-
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
450-
}
462+
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter;
451463

452-
OAuth2AuthorizationRequestRedirectFilter authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
453-
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri);
464+
if (this.authorizationEndpointConfig.authorizationRequestResolver != null) {
465+
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
466+
this.authorizationEndpointConfig.authorizationRequestResolver);
467+
} else {
468+
String authorizationRequestBaseUri = this.authorizationEndpointConfig.authorizationRequestBaseUri;
469+
if (authorizationRequestBaseUri == null) {
470+
authorizationRequestBaseUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
471+
}
472+
authorizationRequestFilter = new OAuth2AuthorizationRequestRedirectFilter(
473+
OAuth2ClientConfigurerUtils.getClientRegistrationRepository(this.getBuilder()), authorizationRequestBaseUri);
474+
}
454475

455476
if (this.authorizationEndpointConfig.authorizationRequestRepository != null) {
456477
authorizationRequestFilter.setAuthorizationRequestRepository(

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
3838
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
3939
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
40+
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
4041
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
42+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
4143
import org.springframework.security.oauth2.core.AuthorizationGrantType;
4244
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
4345
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -74,6 +76,8 @@ public class OAuth2ClientConfigurerTests {
7476

7577
private static OAuth2AuthorizedClientService authorizedClientService;
7678

79+
private static OAuth2AuthorizationRequestResolver authorizationRequestResolver;
80+
7781
private static OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient;
7882

7983
private static RequestCache requestCache;
@@ -103,6 +107,8 @@ public void setup() {
103107
.build();
104108
clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1);
105109
authorizedClientService = new InMemoryOAuth2AuthorizedClientService(clientRegistrationRepository);
110+
authorizationRequestResolver = new DefaultOAuth2AuthorizationRequestResolver(
111+
clientRegistrationRepository, "/oauth2/authorization");
106112

107113
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("access-token-1234")
108114
.tokenType(OAuth2AccessToken.TokenType.BEARER)
@@ -173,6 +179,28 @@ public void configureWhenRequestCacheProvidedAndClientAuthorizationRequiredExcep
173179
verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class));
174180
}
175181

182+
// gh-5521
183+
@Test
184+
public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception {
185+
// Override default resolver
186+
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver = authorizationRequestResolver;
187+
authorizationRequestResolver = request -> {
188+
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
189+
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
190+
additionalParameters.put("param1", "value1");
191+
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
192+
.additionalParameters(additionalParameters)
193+
.build();
194+
};
195+
196+
this.spring.register(OAuth2ClientConfig.class).autowire();
197+
198+
MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorization/registration-1"))
199+
.andExpect(status().is3xxRedirection())
200+
.andReturn();
201+
assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://provider.com/oauth2/authorize\\?response_type=code&client_id=client-1&scope=user&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Fclient-1&param1=value1");
202+
}
203+
176204
@EnableWebSecurity
177205
@EnableWebMvc
178206
static class OAuth2ClientConfig extends WebSecurityConfigurerAdapter {
@@ -188,6 +216,9 @@ protected void configure(HttpSecurity http) throws Exception {
188216
.oauth2()
189217
.client()
190218
.authorizationCodeGrant()
219+
.authorizationEndpoint()
220+
.authorizationRequestResolver(authorizationRequestResolver)
221+
.and()
191222
.tokenEndpoint()
192223
.accessTokenResponseClient(accessTokenResponseClient);
193224
}

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
4343
import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
4444
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
45+
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizationRequestResolver;
4546
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
47+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
4648
import org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter;
4749
import org.springframework.security.oauth2.core.OAuth2AccessToken;
4850
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -105,11 +107,9 @@ public class OAuth2LoginConfigurerTests {
105107
@Before
106108
public void setup() {
107109
this.request = new MockHttpServletRequest("GET", "");
110+
this.request.setServletPath("/login/oauth2/code/google");
108111
this.response = new MockHttpServletResponse();
109112
this.filterChain = new MockFilterChain();
110-
111-
this.request.setMethod("GET");
112-
this.request.setServletPath("/login/oauth2/code/google");
113113
}
114114

115115
@After
@@ -225,6 +225,20 @@ public void oauth2LoginConfigLoginProcessingUrl() throws Exception {
225225
.isInstanceOf(OAuth2UserAuthority.class).hasToString("ROLE_USER");
226226
}
227227

228+
// gh-5521
229+
@Test
230+
public void oauth2LoginWithCustomAuthorizationRequestParameters() throws Exception {
231+
loadConfig(OAuth2LoginConfigCustomAuthorizationRequestResolver.class);
232+
233+
String requestUri = "/oauth2/authorization/google";
234+
this.request = new MockHttpServletRequest("GET", requestUri);
235+
this.request.setServletPath(requestUri);
236+
237+
this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
238+
239+
assertThat(this.response.getRedirectedUrl()).matches("https://accounts.google.com/o/oauth2/v2/auth\\?response_type=code&client_id=clientId&scope=openid\\+profile\\+email&state=.{15,}&redirect_uri=http%3A%2F%2Flocalhost%2Flogin%2Foauth2%2Fcode%2Fgoogle&custom-param1=custom-value1");
240+
}
241+
228242
@Test
229243
public void oidcLogin() throws Exception {
230244
// setup application context
@@ -406,6 +420,35 @@ protected void configure(HttpSecurity http) throws Exception {
406420
}
407421
}
408422

423+
@EnableWebSecurity
424+
static class OAuth2LoginConfigCustomAuthorizationRequestResolver extends CommonWebSecurityConfigurerAdapter {
425+
private ClientRegistrationRepository clientRegistrationRepository =
426+
new InMemoryClientRegistrationRepository(CLIENT_REGISTRATION);
427+
428+
@Override
429+
protected void configure(HttpSecurity http) throws Exception {
430+
http
431+
.oauth2Login()
432+
.clientRegistrationRepository(this.clientRegistrationRepository)
433+
.authorizationEndpoint()
434+
.authorizationRequestResolver(this.getAuthorizationRequestResolver());
435+
super.configure(http);
436+
}
437+
438+
private OAuth2AuthorizationRequestResolver getAuthorizationRequestResolver() {
439+
OAuth2AuthorizationRequestResolver defaultAuthorizationRequestResolver =
440+
new DefaultOAuth2AuthorizationRequestResolver(this.clientRegistrationRepository, "/oauth2/authorization");
441+
return request -> {
442+
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
443+
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
444+
additionalParameters.put("custom-param1", "custom-value1");
445+
return OAuth2AuthorizationRequest.from(defaultAuthorizationRequest)
446+
.additionalParameters(additionalParameters)
447+
.build();
448+
};
449+
}
450+
}
451+
409452
private static abstract class CommonWebSecurityConfigurerAdapter extends WebSecurityConfigurerAdapter {
410453
@Override
411454
protected void configure(HttpSecurity http) throws Exception {

0 commit comments

Comments
 (0)