46
46
import org .springframework .security .core .context .SecurityContextHolder ;
47
47
import org .springframework .security .oauth2 .client .OAuth2AuthorizedClient ;
48
48
import org .springframework .security .oauth2 .client .authentication .OAuth2AuthenticationToken ;
49
+ import org .springframework .security .oauth2 .client .endpoint .OAuth2AccessTokenResponseClient ;
50
+ import org .springframework .security .oauth2 .client .endpoint .OAuth2ClientCredentialsGrantRequest ;
49
51
import org .springframework .security .oauth2 .client .registration .ClientRegistration ;
52
+ import org .springframework .security .oauth2 .client .registration .ClientRegistrationRepository ;
50
53
import org .springframework .security .oauth2 .client .registration .TestClientRegistrations ;
51
54
import org .springframework .security .oauth2 .client .web .OAuth2AuthorizedClientRepository ;
52
55
import org .springframework .security .oauth2 .core .OAuth2AccessToken ;
53
56
import org .springframework .security .oauth2 .core .OAuth2RefreshToken ;
54
57
import org .springframework .security .oauth2 .core .endpoint .OAuth2AccessTokenResponse ;
58
+ import org .springframework .security .oauth2 .core .endpoint .TestOAuth2AccessTokenResponses ;
55
59
import org .springframework .security .oauth2 .core .user .OAuth2User ;
56
60
import org .springframework .web .context .request .RequestContextHolder ;
57
61
import org .springframework .web .context .request .ServletRequestAttributes ;
@@ -89,6 +93,10 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
89
93
@ Mock
90
94
private OAuth2AuthorizedClientRepository authorizedClientRepository ;
91
95
@ Mock
96
+ private ClientRegistrationRepository clientRegistrationRepository ;
97
+ @ Mock
98
+ private OAuth2AccessTokenResponseClient <OAuth2ClientCredentialsGrantRequest > clientCredentialsTokenResponseClient ;
99
+ @ Mock
92
100
private WebClient .RequestHeadersSpec <?> spec ;
93
101
@ Captor
94
102
private ArgumentCaptor <Consumer <Map <String , Object >>> attrs ;
@@ -148,7 +156,8 @@ public void defaultRequestAuthenticationWhenSecurityContextEmptyThenAuthenticati
148
156
149
157
@ Test
150
158
public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationSet () {
151
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
159
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
160
+ this .authorizedClientRepository );
152
161
SecurityContextHolder .getContext ().setAuthentication (this .authentication );
153
162
Map <String , Object > attrs = getDefaultRequestAttributes ();
154
163
assertThat (getAuthentication (attrs )).isEqualTo (this .authentication );
@@ -157,7 +166,8 @@ public void defaultRequestAuthenticationWhenAuthenticationSetThenAuthenticationS
157
166
158
167
@ Test
159
168
public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAndClientIdThenNotOverride () {
160
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
169
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
170
+ this .authorizedClientRepository );
161
171
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient (this .registration ,
162
172
"principalName" , this .accessToken );
163
173
oauth2AuthorizedClient (authorizedClient ).accept (this .result );
@@ -168,15 +178,17 @@ public void defaultRequestOAuth2AuthorizedClientWhenOAuth2AuthorizationClientAnd
168
178
169
179
@ Test
170
180
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull () {
171
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
181
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
182
+ this .authorizedClientRepository );
172
183
Map <String , Object > attrs = getDefaultRequestAttributes ();
173
184
assertThat (getOAuth2AuthorizedClient (attrs )).isNull ();
174
185
verifyZeroInteractions (this .authorizedClientRepository );
175
186
}
176
187
177
188
@ Test
178
189
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationWrongTypeAndClientRegistrationIdNullThenOAuth2AuthorizedClientNull () {
179
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
190
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
191
+ this .authorizedClientRepository );
180
192
Map <String , Object > attrs = getDefaultRequestAttributes ();
181
193
assertThat (getOAuth2AuthorizedClient (attrs )).isNull ();
182
194
verifyZeroInteractions (this .authorizedClientRepository );
@@ -196,7 +208,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenRepositoryNullThenOAuth2Auth
196
208
197
209
@ Test
198
210
public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient () {
199
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
211
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
212
+ this .authorizedClientRepository );
200
213
this .function .setDefaultOAuth2AuthorizedClient (true );
201
214
OAuth2User user = mock (OAuth2User .class );
202
215
List <GrantedAuthority > authorities = AuthorityUtils .createAuthorityList ("ROLE_USER" );
@@ -214,7 +227,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultTrueAndAuthentication
214
227
215
228
@ Test
216
229
public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticationAndClientRegistrationIdNullThenOAuth2AuthorizedClient () {
217
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
230
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
231
+ this .authorizedClientRepository );
218
232
OAuth2User user = mock (OAuth2User .class );
219
233
List <GrantedAuthority > authorities = AuthorityUtils .createAuthorityList ("ROLE_USER" );
220
234
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken (user , authorities , "id" );
@@ -227,7 +241,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenDefaultFalseAndAuthenticatio
227
241
228
242
@ Test
229
243
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegistrationIdThenIdIsExplicit () {
230
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
244
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
245
+ this .authorizedClientRepository );
231
246
OAuth2User user = mock (OAuth2User .class );
232
247
List <GrantedAuthority > authorities = AuthorityUtils .createAuthorityList ("ROLE_USER" );
233
248
OAuth2AuthenticationToken token = new OAuth2AuthenticationToken (user , authorities , "id" );
@@ -245,9 +260,8 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationAndClientRegis
245
260
246
261
@ Test
247
262
public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientRegistrationIdThenOAuth2AuthorizedClient () {
248
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
249
- OAuth2User user = mock (OAuth2User .class );
250
- List <GrantedAuthority > authorities = AuthorityUtils .createAuthorityList ("ROLE_USER" );
263
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
264
+ this .authorizedClientRepository );
251
265
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient (this .registration ,
252
266
"principalName" , this .accessToken );
253
267
when (this .authorizedClientRepository .loadAuthorizedClient (any (), any (), any ())).thenReturn (authorizedClient );
@@ -259,6 +273,41 @@ public void defaultRequestOAuth2AuthorizedClientWhenAuthenticationNullAndClientR
259
273
verify (this .authorizedClientRepository ).loadAuthorizedClient (eq ("id" ), any (), any ());
260
274
}
261
275
276
+ @ Test
277
+ public void defaultRequestWhenClientCredentialsThenAuthorizedClient () {
278
+ this .registration = TestClientRegistrations .clientCredentials ().build ();
279
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
280
+ this .authorizedClientRepository );
281
+ this .function .setClientCredentialsTokenResponseClient (this .clientCredentialsTokenResponseClient );
282
+ when (this .clientRegistrationRepository .findByRegistrationId (any ())).thenReturn (this .registration );
283
+ OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
284
+ .accessTokenResponse ().build ();
285
+ when (this .clientCredentialsTokenResponseClient .getTokenResponse (any ())).thenReturn (
286
+ accessTokenResponse );
287
+
288
+ clientRegistrationId (this .registration .getRegistrationId ()).accept (this .result );
289
+
290
+ Map <String , Object > attrs = getDefaultRequestAttributes ();
291
+ OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient (attrs );
292
+
293
+ assertThat (authorizedClient .getAccessToken ()).isEqualTo (accessTokenResponse .getAccessToken ());
294
+ assertThat (authorizedClient .getClientRegistration ()).isEqualTo (this .registration );
295
+ assertThat (authorizedClient .getPrincipalName ()).isEqualTo ("anonymousUser" );
296
+ assertThat (authorizedClient .getRefreshToken ()).isEqualTo (accessTokenResponse .getRefreshToken ());
297
+ }
298
+
299
+ @ Test
300
+ public void defaultRequestWhenClientIdNotFoundThenIllegalArgumentException () {
301
+ this .registration = TestClientRegistrations .clientCredentials ().build ();
302
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
303
+ this .authorizedClientRepository );
304
+
305
+ clientRegistrationId (this .registration .getRegistrationId ()).accept (this .result );
306
+
307
+ assertThatCode (() -> getDefaultRequestAttributes ())
308
+ .isInstanceOf (IllegalArgumentException .class );
309
+ }
310
+
262
311
private Map <String , Object > getDefaultRequestAttributes () {
263
312
this .function .defaultRequest ().accept (this .spec );
264
313
verify (this .spec ).attributes (this .attrs .capture ());
@@ -322,7 +371,8 @@ public void filterWhenRefreshRequiredThenRefresh() {
322
371
this .accessToken .getTokenValue (),
323
372
issuedAt ,
324
373
accessTokenExpiresAt );
325
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
374
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
375
+ this .authorizedClientRepository );
326
376
327
377
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ("refresh-token" , issuedAt , refreshTokenExpiresAt );
328
378
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient (this .registration ,
@@ -368,7 +418,8 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
368
418
this .accessToken .getTokenValue (),
369
419
issuedAt ,
370
420
accessTokenExpiresAt );
371
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
421
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
422
+ this .authorizedClientRepository );
372
423
373
424
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ("refresh-token" , issuedAt , refreshTokenExpiresAt );
374
425
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient (this .registration ,
@@ -400,7 +451,8 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
400
451
401
452
@ Test
402
453
public void filterWhenRefreshTokenNullThenShouldRefreshFalse () {
403
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
454
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
455
+ this .authorizedClientRepository );
404
456
405
457
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient (this .registration ,
406
458
"principalName" , this .accessToken );
@@ -422,7 +474,8 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
422
474
423
475
@ Test
424
476
public void filterWhenNotExpiredThenShouldRefreshFalse () {
425
- this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .authorizedClientRepository );
477
+ this .function = new ServletOAuth2AuthorizedClientExchangeFilterFunction (this .clientRegistrationRepository ,
478
+ this .authorizedClientRepository );
426
479
427
480
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken ("refresh-token" , this .accessToken .getIssuedAt (), this .accessToken .getExpiresAt ());
428
481
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient (this .registration ,
0 commit comments