19
19
import org .springframework .http .HttpHeaders ;
20
20
import org .springframework .http .HttpMethod ;
21
21
import org .springframework .http .MediaType ;
22
+ import org .springframework .security .authentication .AnonymousAuthenticationToken ;
22
23
import org .springframework .security .core .Authentication ;
23
24
import org .springframework .security .core .GrantedAuthority ;
25
+ import org .springframework .security .core .authority .AuthorityUtils ;
24
26
import org .springframework .security .core .context .ReactiveSecurityContextHolder ;
25
27
import org .springframework .security .core .context .SecurityContext ;
28
+ import org .springframework .security .oauth2 .client .ClientAuthorizationRequiredException ;
26
29
import org .springframework .security .oauth2 .client .OAuth2AuthorizedClient ;
30
+ import org .springframework .security .oauth2 .client .OAuth2ClientException ;
27
31
import org .springframework .security .oauth2 .client .registration .ClientRegistration ;
28
32
import org .springframework .security .oauth2 .client .web .server .ServerOAuth2AuthorizedClientRepository ;
29
33
import org .springframework .security .oauth2 .core .AuthorizationGrantType ;
34
+ import org .springframework .security .oauth2 .core .OAuth2AuthenticationException ;
30
35
import org .springframework .security .oauth2 .core .OAuth2RefreshToken ;
31
36
import org .springframework .util .Assert ;
32
37
import org .springframework .web .reactive .function .BodyInserters ;
@@ -61,10 +66,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
61
66
*/
62
67
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient .class .getName ();
63
68
69
+ /**
70
+ * The client request attribute name used to locate the {@link ClientRegistration#getRegistrationId()}
71
+ */
72
+ private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient .class .getName ().concat (".CLIENT_REGISTRATION_ID" );
73
+
64
74
/**
65
75
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
66
76
*/
67
77
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange .class .getName ();
78
+ public static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken ("anonymous" , "anonymousUser" ,
79
+ AuthorityUtils .createAuthorityList ("ROLE_USER" ));
68
80
69
81
private Clock clock = Clock .systemUTC ();
70
82
@@ -74,8 +86,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
74
86
75
87
public ServerOAuth2AuthorizedClientExchangeFilterFunction () {}
76
88
77
- public ServerOAuth2AuthorizedClientExchangeFilterFunction (
78
- ServerOAuth2AuthorizedClientRepository authorizedClientRepository ) {
89
+ public ServerOAuth2AuthorizedClientExchangeFilterFunction (ServerOAuth2AuthorizedClientRepository authorizedClientRepository ) {
79
90
this .authorizedClientRepository = authorizedClientRepository ;
80
91
}
81
92
@@ -141,6 +152,18 @@ public static Consumer<Map<String, Object>> serverWebExchange(ServerWebExchange
141
152
return attributes -> attributes .put (SERVER_WEB_EXCHANGE_ATTR_NAME , serverWebExchange );
142
153
}
143
154
155
+ /**
156
+ * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
157
+ * be used to look up the {@link OAuth2AuthorizedClient}.
158
+ *
159
+ * @param clientRegistrationId the {@link ClientRegistration#getRegistrationId()} to
160
+ * be used to look up the {@link OAuth2AuthorizedClient}.
161
+ * @return the {@link Consumer} to populate the attributes
162
+ */
163
+ public static Consumer <Map <String , Object >> clientRegistrationId (String clientRegistrationId ) {
164
+ return attributes -> attributes .put (CLIENT_REGISTRATION_ID_ATTR_NAME , clientRegistrationId );
165
+ }
166
+
144
167
/**
145
168
* An access token will be considered expired by comparing its expiration to now +
146
169
* this skewed Duration. The default is 1 minute.
@@ -153,17 +176,42 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
153
176
154
177
@ Override
155
178
public Mono <ClientResponse > filter (ClientRequest request , ExchangeFunction next ) {
156
- Optional <OAuth2AuthorizedClient > attribute = request .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME )
157
- .map (OAuth2AuthorizedClient .class ::cast );
158
179
ServerWebExchange exchange = (ServerWebExchange ) request .attributes ().get (SERVER_WEB_EXCHANGE_ATTR_NAME );
159
- return Mono . justOrEmpty ( attribute )
160
- .flatMap (authorizedClient -> authorizedClient (next , authorizedClient , exchange ))
180
+ return authorizedClient ( request )
181
+ .flatMap (authorizedClient -> refreshIfNecessary (next , authorizedClient , exchange ))
161
182
.map (authorizedClient -> bearer (request , authorizedClient ))
162
183
.flatMap (next ::exchange )
163
184
.switchIfEmpty (next .exchange (request ));
164
185
}
165
186
166
- private Mono <OAuth2AuthorizedClient > authorizedClient (ExchangeFunction next , OAuth2AuthorizedClient authorizedClient , ServerWebExchange exchange ) {
187
+ private Mono <OAuth2AuthorizedClient > authorizedClient (ClientRequest request ) {
188
+ Optional <OAuth2AuthorizedClient > attribute = request .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME )
189
+ .map (OAuth2AuthorizedClient .class ::cast );
190
+ return Mono .justOrEmpty (attribute )
191
+ .switchIfEmpty (findAuthorizedClientByRegistrationId (request ));
192
+ }
193
+
194
+ private Mono <OAuth2AuthorizedClient > findAuthorizedClientByRegistrationId (ClientRequest request ) {
195
+ if (this .authorizedClientRepository == null ) {
196
+ return Mono .empty ();
197
+ }
198
+ String clientRegistrationId = (String ) request .attributes ().get (CLIENT_REGISTRATION_ID_ATTR_NAME );
199
+ if (clientRegistrationId == null ) {
200
+ return Mono .empty ();
201
+ }
202
+ ServerWebExchange exchange = (ServerWebExchange ) request .attributes ().get (SERVER_WEB_EXCHANGE_ATTR_NAME );
203
+ return currentAuthentication ()
204
+ .flatMap (principal -> loadAuthorizedClient (clientRegistrationId , exchange , principal )
205
+ );
206
+ }
207
+
208
+ private Mono <OAuth2AuthorizedClient > loadAuthorizedClient (String clientRegistrationId ,
209
+ ServerWebExchange exchange , Authentication principal ) {
210
+ return this .authorizedClientRepository .loadAuthorizedClient (clientRegistrationId , principal , exchange )
211
+ .switchIfEmpty (Mono .error (() -> new ClientAuthorizationRequiredException (clientRegistrationId )));
212
+ }
213
+
214
+ private Mono <OAuth2AuthorizedClient > refreshIfNecessary (ExchangeFunction next , OAuth2AuthorizedClient authorizedClient , ServerWebExchange exchange ) {
167
215
if (shouldRefresh (authorizedClient )) {
168
216
return refreshAuthorizedClient (next , authorizedClient , exchange );
169
217
}
@@ -184,13 +232,18 @@ private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction ne
184
232
return next .exchange (request )
185
233
.flatMap (response -> response .body (oauth2AccessTokenResponse ()))
186
234
.map (accessTokenResponse -> new OAuth2AuthorizedClient (authorizedClient .getClientRegistration (), authorizedClient .getPrincipalName (), accessTokenResponse .getAccessToken (), accessTokenResponse .getRefreshToken ()))
187
- .flatMap (result -> ReactiveSecurityContextHolder .getContext ()
188
- .map (SecurityContext ::getAuthentication )
235
+ .flatMap (result -> currentAuthentication ()
189
236
.defaultIfEmpty (new PrincipalNameAuthentication (authorizedClient .getPrincipalName ()))
190
237
.flatMap (principal -> this .authorizedClientRepository .saveAuthorizedClient (result , principal , exchange ))
191
238
.thenReturn (result ));
192
239
}
193
240
241
+ private Mono <Authentication > currentAuthentication () {
242
+ return ReactiveSecurityContextHolder .getContext ()
243
+ .map (SecurityContext ::getAuthentication )
244
+ .defaultIfEmpty (ANONYMOUS_USER_TOKEN );
245
+ }
246
+
194
247
private boolean shouldRefresh (OAuth2AuthorizedClient authorizedClient ) {
195
248
if (this .authorizedClientRepository == null ) {
196
249
return false ;
0 commit comments