16
16
17
17
package org .springframework .security .oauth2 .client .web .reactive .function .client ;
18
18
19
+ import com .sun .security .ntlm .Server ;
19
20
import org .springframework .http .HttpHeaders ;
20
21
import org .springframework .http .HttpMethod ;
21
22
import org .springframework .http .MediaType ;
@@ -211,14 +212,25 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
211
212
212
213
@ Override
213
214
public Mono <ClientResponse > filter (ClientRequest request , ExchangeFunction next ) {
214
- ServerWebExchange exchange = (ServerWebExchange ) request .attributes ().get (SERVER_WEB_EXCHANGE_ATTR_NAME );
215
215
return authorizedClient (request )
216
- .flatMap (authorizedClient -> refreshIfNecessary (next , authorizedClient , exchange ))
216
+ .flatMap (authorizedClient -> refreshIfNecessary (next , authorizedClient , request ))
217
217
.map (authorizedClient -> bearer (request , authorizedClient ))
218
218
.flatMap (next ::exchange )
219
219
.switchIfEmpty (next .exchange (request ));
220
220
}
221
221
222
+ private Mono <ServerWebExchange > serverWebExchange (ClientRequest request ) {
223
+ ServerWebExchange exchange = (ServerWebExchange ) request .attributes ().get (SERVER_WEB_EXCHANGE_ATTR_NAME );
224
+ return Mono .justOrEmpty (exchange )
225
+ .switchIfEmpty (serverWebExchange ());
226
+ }
227
+
228
+ private Mono <ServerWebExchange > serverWebExchange () {
229
+ return Mono .subscriberContext ()
230
+ .filter (c -> c .hasKey (ServerWebExchange .class ))
231
+ .map (c -> c .get (ServerWebExchange .class ));
232
+ }
233
+
222
234
private Mono <OAuth2AuthorizedClient > authorizedClient (ClientRequest request ) {
223
235
Optional <OAuth2AuthorizedClient > attribute = request .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME )
224
236
.map (OAuth2AuthorizedClient .class ::cast );
@@ -231,10 +243,9 @@ private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(Client
231
243
return Mono .empty ();
232
244
}
233
245
234
- ServerWebExchange exchange = (ServerWebExchange ) request .attributes ().get (SERVER_WEB_EXCHANGE_ATTR_NAME );
235
246
return currentAuthentication ()
236
247
.flatMap (principal -> clientRegistrationId (request , principal )
237
- .flatMap (clientRegistrationId -> loadAuthorizedClient (clientRegistrationId , exchange , principal ))
248
+ .flatMap (clientRegistrationId -> serverWebExchange ( request ). flatMap ( exchange -> loadAuthorizedClient (clientRegistrationId , exchange , principal ) ))
238
249
);
239
250
}
240
251
@@ -289,9 +300,10 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio
289
300
});
290
301
}
291
302
292
- private Mono <OAuth2AuthorizedClient > refreshIfNecessary (ExchangeFunction next , OAuth2AuthorizedClient authorizedClient , ServerWebExchange exchange ) {
303
+ private Mono <OAuth2AuthorizedClient > refreshIfNecessary (ExchangeFunction next , OAuth2AuthorizedClient authorizedClient , ClientRequest request ) {
293
304
if (shouldRefresh (authorizedClient )) {
294
- return refreshAuthorizedClient (next , authorizedClient , exchange );
305
+ return serverWebExchange (request )
306
+ .flatMap (exchange -> refreshAuthorizedClient (next , authorizedClient , exchange ));
295
307
}
296
308
return Mono .just (authorizedClient );
297
309
}
0 commit comments