Skip to content

Commit 5bcbb1c

Browse files
committed
ServerOAuth2AuthorizedClientExchangeFilterFunction uses ServerOAuth2AuthorizedClientRepository
Issue: gh-4921
1 parent 07b6699 commit 5bcbb1c

File tree

2 files changed

+53
-21
lines changed

2 files changed

+53
-21
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
2525
import org.springframework.security.core.context.SecurityContext;
2626
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
27-
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
2827
import org.springframework.security.oauth2.client.registration.ClientRegistration;
28+
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
2929
import org.springframework.security.oauth2.core.AuthorizationGrantType;
3030
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
3131
import org.springframework.util.Assert;
@@ -34,6 +34,7 @@
3434
import org.springframework.web.reactive.function.client.ClientResponse;
3535
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
3636
import org.springframework.web.reactive.function.client.ExchangeFunction;
37+
import org.springframework.web.server.ServerWebExchange;
3738
import reactor.core.publisher.Mono;
3839

3940
import java.net.URI;
@@ -60,16 +61,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
6061
*/
6162
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
6263

64+
/**
65+
* The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
66+
*/
67+
private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
68+
6369
private Clock clock = Clock.systemUTC();
6470

6571
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
6672

67-
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
73+
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
6874

6975
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
7076

71-
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
72-
this.authorizedClientService = authorizedClientService;
77+
public ServerOAuth2AuthorizedClientExchangeFilterFunction(
78+
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
79+
this.authorizedClientRepository = authorizedClientRepository;
7380
}
7481

7582
/**
@@ -78,7 +85,7 @@ public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2Authoriz
7885
*
7986
* <pre>
8087
* WebClient webClient = WebClient.builder()
81-
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService))
88+
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
8289
* .build();
8390
* Mono<String> response = webClient
8491
* .get()
@@ -110,6 +117,30 @@ public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2Authori
110117
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
111118
}
112119

120+
121+
/**
122+
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
123+
* providing the Bearer Token. Example usage:
124+
*
125+
* <pre>
126+
* WebClient webClient = WebClient.builder()
127+
* .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
128+
* .build();
129+
* Mono<String> response = webClient
130+
* .get()
131+
* .uri(uri)
132+
* .attributes(serverWebExchange(serverWebExchange))
133+
* // ...
134+
* .retrieve()
135+
* .bodyToMono(String.class);
136+
* </pre>
137+
* @param serverWebExchange the {@link ServerWebExchange} to use
138+
* @return the {@link Consumer} to populate the client request attributes
139+
*/
140+
public static Consumer<Map<String, Object>> serverWebExchange(ServerWebExchange serverWebExchange) {
141+
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
142+
}
143+
113144
/**
114145
* An access token will be considered expired by comparing its expiration to now +
115146
* this skewed Duration. The default is 1 minute.
@@ -124,22 +155,23 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
124155
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
125156
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
126157
.map(OAuth2AuthorizedClient.class::cast);
158+
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
127159
return Mono.justOrEmpty(attribute)
128-
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient))
160+
.flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
129161
.map(authorizedClient -> bearer(request, authorizedClient))
130162
.flatMap(next::exchange)
131163
.switchIfEmpty(next.exchange(request));
132164
}
133165

134-
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
166+
private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
135167
if (shouldRefresh(authorizedClient)) {
136-
return refreshAuthorizedClient(next, authorizedClient);
168+
return refreshAuthorizedClient(next, authorizedClient, exchange);
137169
}
138170
return Mono.just(authorizedClient);
139171
}
140172

141173
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
142-
OAuth2AuthorizedClient authorizedClient) {
174+
OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
143175
ClientRegistration clientRegistration = authorizedClient
144176
.getClientRegistration();
145177
String tokenUri = clientRegistration
@@ -155,12 +187,12 @@ private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction ne
155187
.flatMap(result -> ReactiveSecurityContextHolder.getContext()
156188
.map(SecurityContext::getAuthentication)
157189
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
158-
.flatMap(principal -> this.authorizedClientService.saveAuthorizedClient(result, principal))
190+
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
159191
.thenReturn(result));
160192
}
161193

162194
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
163-
if (this.authorizedClientService == null) {
195+
if (this.authorizedClientRepository == null) {
164196
return false;
165197
}
166198
OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
import org.springframework.security.authentication.TestingAuthenticationToken;
3737
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
3838
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
39-
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
4039
import org.springframework.security.oauth2.client.registration.ClientRegistration;
4140
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
41+
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
4242
import org.springframework.security.oauth2.core.OAuth2AccessToken;
4343
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
4444
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -70,7 +70,7 @@
7070
@RunWith(MockitoJUnitRunner.class)
7171
public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
7272
@Mock
73-
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
73+
private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
7474

7575
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
7676

@@ -124,7 +124,7 @@ public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
124124

125125
@Test
126126
public void filterWhenRefreshRequiredThenRefresh() {
127-
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
127+
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
128128
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
129129
.tokenType(OAuth2AccessToken.TokenType.BEARER)
130130
.expiresIn(3600)
@@ -139,7 +139,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
139139
this.accessToken.getTokenValue(),
140140
issuedAt,
141141
accessTokenExpiresAt);
142-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
142+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
143143

144144
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
145145
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -153,7 +153,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
153153
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
154154
.block();
155155

156-
verify(this.authorizedClientService).saveAuthorizedClient(any(), eq(authentication));
156+
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
157157

158158
List<ClientRequest> requests = this.exchange.getRequests();
159159
assertThat(requests).hasSize(2);
@@ -173,7 +173,7 @@ public void filterWhenRefreshRequiredThenRefresh() {
173173

174174
@Test
175175
public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
176-
when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
176+
when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
177177
OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
178178
.tokenType(OAuth2AccessToken.TokenType.BEARER)
179179
.expiresIn(3600)
@@ -188,7 +188,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
188188
this.accessToken.getTokenValue(),
189189
issuedAt,
190190
accessTokenExpiresAt);
191-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
191+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
192192

193193
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
194194
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -200,7 +200,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
200200
this.function.filter(request, this.exchange)
201201
.block();
202202

203-
verify(this.authorizedClientService).saveAuthorizedClient(any(), any());
203+
verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
204204

205205
List<ClientRequest> requests = this.exchange.getRequests();
206206
assertThat(requests).hasSize(2);
@@ -220,7 +220,7 @@ public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved()
220220

221221
@Test
222222
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
223-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
223+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
224224

225225
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
226226
"principalName", this.accessToken);
@@ -242,7 +242,7 @@ public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
242242

243243
@Test
244244
public void filterWhenNotExpiredThenShouldRefreshFalse() {
245-
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
245+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
246246

247247
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
248248
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,

0 commit comments

Comments
 (0)