24
24
import org .springframework .security .core .context .ReactiveSecurityContextHolder ;
25
25
import org .springframework .security .core .context .SecurityContext ;
26
26
import org .springframework .security .oauth2 .client .OAuth2AuthorizedClient ;
27
- import org .springframework .security .oauth2 .client .ReactiveOAuth2AuthorizedClientService ;
28
27
import org .springframework .security .oauth2 .client .registration .ClientRegistration ;
28
+ import org .springframework .security .oauth2 .client .web .server .ServerOAuth2AuthorizedClientRepository ;
29
29
import org .springframework .security .oauth2 .core .AuthorizationGrantType ;
30
30
import org .springframework .security .oauth2 .core .OAuth2RefreshToken ;
31
31
import org .springframework .util .Assert ;
34
34
import org .springframework .web .reactive .function .client .ClientResponse ;
35
35
import org .springframework .web .reactive .function .client .ExchangeFilterFunction ;
36
36
import org .springframework .web .reactive .function .client .ExchangeFunction ;
37
+ import org .springframework .web .server .ServerWebExchange ;
37
38
import reactor .core .publisher .Mono ;
38
39
39
40
import java .net .URI ;
@@ -60,16 +61,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
60
61
*/
61
62
private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient .class .getName ();
62
63
64
+ /**
65
+ * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
66
+ */
67
+ private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange .class .getName ();
68
+
63
69
private Clock clock = Clock .systemUTC ();
64
70
65
71
private Duration accessTokenExpiresSkew = Duration .ofMinutes (1 );
66
72
67
- private ReactiveOAuth2AuthorizedClientService authorizedClientService ;
73
+ private ServerOAuth2AuthorizedClientRepository authorizedClientRepository ;
68
74
69
75
public ServerOAuth2AuthorizedClientExchangeFilterFunction () {}
70
76
71
- public ServerOAuth2AuthorizedClientExchangeFilterFunction (ReactiveOAuth2AuthorizedClientService authorizedClientService ) {
72
- this .authorizedClientService = authorizedClientService ;
77
+ public ServerOAuth2AuthorizedClientExchangeFilterFunction (
78
+ ServerOAuth2AuthorizedClientRepository authorizedClientRepository ) {
79
+ this .authorizedClientRepository = authorizedClientRepository ;
73
80
}
74
81
75
82
/**
@@ -78,7 +85,7 @@ public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2Authoriz
78
85
*
79
86
* <pre>
80
87
* WebClient webClient = WebClient.builder()
81
- * .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService ))
88
+ * .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository ))
82
89
* .build();
83
90
* Mono<String> response = webClient
84
91
* .get()
@@ -110,6 +117,30 @@ public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2Authori
110
117
return attributes -> attributes .put (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME , authorizedClient );
111
118
}
112
119
120
+
121
+ /**
122
+ * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
123
+ * providing the Bearer Token. Example usage:
124
+ *
125
+ * <pre>
126
+ * WebClient webClient = WebClient.builder()
127
+ * .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
128
+ * .build();
129
+ * Mono<String> response = webClient
130
+ * .get()
131
+ * .uri(uri)
132
+ * .attributes(serverWebExchange(serverWebExchange))
133
+ * // ...
134
+ * .retrieve()
135
+ * .bodyToMono(String.class);
136
+ * </pre>
137
+ * @param serverWebExchange the {@link ServerWebExchange} to use
138
+ * @return the {@link Consumer} to populate the client request attributes
139
+ */
140
+ public static Consumer <Map <String , Object >> serverWebExchange (ServerWebExchange serverWebExchange ) {
141
+ return attributes -> attributes .put (SERVER_WEB_EXCHANGE_ATTR_NAME , serverWebExchange );
142
+ }
143
+
113
144
/**
114
145
* An access token will be considered expired by comparing its expiration to now +
115
146
* this skewed Duration. The default is 1 minute.
@@ -124,22 +155,23 @@ public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
124
155
public Mono <ClientResponse > filter (ClientRequest request , ExchangeFunction next ) {
125
156
Optional <OAuth2AuthorizedClient > attribute = request .attribute (OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME )
126
157
.map (OAuth2AuthorizedClient .class ::cast );
158
+ ServerWebExchange exchange = (ServerWebExchange ) request .attributes ().get (SERVER_WEB_EXCHANGE_ATTR_NAME );
127
159
return Mono .justOrEmpty (attribute )
128
- .flatMap (authorizedClient -> authorizedClient (next , authorizedClient ))
160
+ .flatMap (authorizedClient -> authorizedClient (next , authorizedClient , exchange ))
129
161
.map (authorizedClient -> bearer (request , authorizedClient ))
130
162
.flatMap (next ::exchange )
131
163
.switchIfEmpty (next .exchange (request ));
132
164
}
133
165
134
- private Mono <OAuth2AuthorizedClient > authorizedClient (ExchangeFunction next , OAuth2AuthorizedClient authorizedClient ) {
166
+ private Mono <OAuth2AuthorizedClient > authorizedClient (ExchangeFunction next , OAuth2AuthorizedClient authorizedClient , ServerWebExchange exchange ) {
135
167
if (shouldRefresh (authorizedClient )) {
136
- return refreshAuthorizedClient (next , authorizedClient );
168
+ return refreshAuthorizedClient (next , authorizedClient , exchange );
137
169
}
138
170
return Mono .just (authorizedClient );
139
171
}
140
172
141
173
private Mono <OAuth2AuthorizedClient > refreshAuthorizedClient (ExchangeFunction next ,
142
- OAuth2AuthorizedClient authorizedClient ) {
174
+ OAuth2AuthorizedClient authorizedClient , ServerWebExchange exchange ) {
143
175
ClientRegistration clientRegistration = authorizedClient
144
176
.getClientRegistration ();
145
177
String tokenUri = clientRegistration
@@ -155,12 +187,12 @@ private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction ne
155
187
.flatMap (result -> ReactiveSecurityContextHolder .getContext ()
156
188
.map (SecurityContext ::getAuthentication )
157
189
.defaultIfEmpty (new PrincipalNameAuthentication (authorizedClient .getPrincipalName ()))
158
- .flatMap (principal -> this .authorizedClientService .saveAuthorizedClient (result , principal ))
190
+ .flatMap (principal -> this .authorizedClientRepository .saveAuthorizedClient (result , principal , exchange ))
159
191
.thenReturn (result ));
160
192
}
161
193
162
194
private boolean shouldRefresh (OAuth2AuthorizedClient authorizedClient ) {
163
- if (this .authorizedClientService == null ) {
195
+ if (this .authorizedClientRepository == null ) {
164
196
return false ;
165
197
}
166
198
OAuth2RefreshToken refreshToken = authorizedClient .getRefreshToken ();
0 commit comments