Skip to content

Commit 23726ab

Browse files
committed
ServerOAuth2AuthorizedClientExchangeFilterFunction default ServerWebExchange
Leverage ServerWebExchange established by ServerWebExchangeReactorContextWebFilter Issue: gh-4921
1 parent ac78258 commit 23726ab

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.security.oauth2.client.web.reactive.function.client;
1818

19+
import com.sun.security.ntlm.Server;
1920
import org.springframework.http.HttpHeaders;
2021
import org.springframework.http.HttpMethod;
2122
import org.springframework.http.MediaType;
@@ -211,14 +212,25 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
211212

212213
@Override
213214
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
214-
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
215215
return authorizedClient(request)
216-
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
216+
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
217217
.map(authorizedClient -> bearer(request, authorizedClient))
218218
.flatMap(next::exchange)
219219
.switchIfEmpty(next.exchange(request));
220220
}
221221

222+
private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
223+
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
224+
return Mono.justOrEmpty(exchange)
225+
.switchIfEmpty(serverWebExchange());
226+
}
227+
228+
private Mono<ServerWebExchange> serverWebExchange() {
229+
return Mono.subscriberContext()
230+
.filter(c -> c.hasKey(ServerWebExchange.class))
231+
.map(c -> c.get(ServerWebExchange.class));
232+
}
233+
222234
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
223235
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
224236
.map(OAuth2AuthorizedClient.class::cast);
@@ -231,10 +243,9 @@ private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(Client
231243
return Mono.empty();
232244
}
233245

234-
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
235246
return currentAuthentication()
236247
.flatMap(principal -> clientRegistrationId(request, principal)
237-
.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
248+
.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
238249
);
239250
}
240251

@@ -289,9 +300,10 @@ private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistratio
289300
});
290301
}
291302

292-
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
303+
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
293304
if (shouldRefresh(authorizedClient)) {
294-
return refreshAuthorizedClient(next, authorizedClient, exchange);
305+
return serverWebExchange(request)
306+
.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
295307
}
296308
return Mono.just(authorizedClient);
297309
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@
4949
import org.springframework.security.oauth2.core.user.OAuth2User;
5050
import org.springframework.web.reactive.function.BodyInserter;
5151
import org.springframework.web.reactive.function.client.ClientRequest;
52+
import org.springframework.web.server.ServerWebExchange;
5253
import reactor.core.publisher.Mono;
54+
import reactor.util.context.Context;
5355

5456
import java.net.URI;
5557
import java.time.Duration;
@@ -83,6 +85,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
8385
@Mock
8486
private ReactiveClientRegistrationRepository clientRegistrationRepository;
8587

88+
@Mock
89+
private ServerWebExchange serverWebExchange;
90+
8691
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
8792

8893
private MockExchangeFunction exchange = new MockExchangeFunction();
@@ -352,6 +357,30 @@ public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
352357
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
353358
}
354359

360+
@Test
361+
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
362+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
363+
364+
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
365+
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
366+
"principalName", this.accessToken, refreshToken);
367+
when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
368+
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
369+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
370+
.attributes(clientRegistrationId(this.registration.getRegistrationId()))
371+
.build();
372+
373+
this.function.filter(request, this.exchange)
374+
.subscriberContext(serverWebExchange())
375+
.block();
376+
377+
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
378+
}
379+
380+
private Context serverWebExchange() {
381+
return Context.of(ServerWebExchange.class, this.serverWebExchange);
382+
}
383+
355384
private static String getBody(ClientRequest request) {
356385
final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
357386
messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));

0 commit comments

Comments
 (0)