Skip to content

Commit 214cfe8

Browse files
committed
Allow Jwt assertion to be resolved
Closes gh-9812
1 parent 1ab0705 commit 214cfe8

File tree

6 files changed

+131
-15
lines changed

6 files changed

+131
-15
lines changed

docs/modules/ROOT/pages/reactive/oauth2/client/authorization-grants.adoc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,3 +1092,9 @@ class OAuth2ResourceServerController {
10921092
}
10931093
----
10941094
====
1095+
1096+
[NOTE]
1097+
`JwtBearerReactiveOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.
1098+
1099+
[TIP]
1100+
If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerReactiveOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Mono<Jwt>>`.

docs/modules/ROOT/pages/servlet/oauth2/client/authorization-grants.adoc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1270,3 +1270,9 @@ class OAuth2ResourceServerController {
12701270
}
12711271
----
12721272
====
1273+
1274+
[NOTE]
1275+
`JwtBearerOAuth2AuthorizedClientProvider` resolves the `Jwt` assertion via `OAuth2AuthorizationContext.getPrincipal().getPrincipal()` by default, hence the use of `JwtAuthenticationToken` in the preceding example.
1276+
1277+
[TIP]
1278+
If you need to resolve the `Jwt` assertion from a different source, you can provide `JwtBearerOAuth2AuthorizedClientProvider.setJwtAssertionResolver()` with a custom `Function<OAuth2AuthorizationContext, Jwt>`.

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProvider.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import java.time.Clock;
2020
import java.time.Duration;
2121
import java.time.Instant;
22+
import java.util.function.Function;
2223

2324
import org.springframework.lang.Nullable;
2425
import org.springframework.security.oauth2.client.endpoint.DefaultJwtBearerTokenResponseClient;
@@ -45,6 +46,8 @@ public final class JwtBearerOAuth2AuthorizedClientProvider implements OAuth2Auth
4546

4647
private OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new DefaultJwtBearerTokenResponseClient();
4748

49+
private Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = this::resolveJwtAssertion;
50+
4851
private Duration clockSkew = Duration.ofSeconds(60);
4952

5053
private Clock clock = Clock.systemUTC();
@@ -75,10 +78,10 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
7578
// need for re-authorization
7679
return null;
7780
}
78-
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
81+
Jwt jwt = this.jwtAssertionResolver.apply(context);
82+
if (jwt == null) {
7983
return null;
8084
}
81-
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
8285
// As per spec, in section 4.1 Using Assertions as Authorization Grants
8386
// https://tools.ietf.org/html/rfc7521#section-4.1
8487
//
@@ -97,6 +100,13 @@ public OAuth2AuthorizedClient authorize(OAuth2AuthorizationContext context) {
97100
tokenResponse.getAccessToken());
98101
}
99102

103+
private Jwt resolveJwtAssertion(OAuth2AuthorizationContext context) {
104+
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
105+
return null;
106+
}
107+
return (Jwt) context.getPrincipal().getPrincipal();
108+
}
109+
100110
private OAuth2AccessTokenResponse getTokenResponse(ClientRegistration clientRegistration,
101111
JwtBearerGrantRequest jwtBearerGrantRequest) {
102112
try {
@@ -123,6 +133,17 @@ public void setAccessTokenResponseClient(
123133
this.accessTokenResponseClient = accessTokenResponseClient;
124134
}
125135

136+
/**
137+
* Sets the resolver used for resolving the {@link Jwt} assertion.
138+
* @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
139+
* assertion
140+
* @since 5.7
141+
*/
142+
public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver) {
143+
Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
144+
this.jwtAssertionResolver = jwtAssertionResolver;
145+
}
146+
126147
/**
127148
* Sets the maximum acceptable clock skew, which is used when checking the
128149
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProvider.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import java.time.Clock;
2020
import java.time.Duration;
2121
import java.time.Instant;
22+
import java.util.function.Function;
2223

2324
import reactor.core.publisher.Mono;
2425

@@ -45,6 +46,8 @@ public final class JwtBearerReactiveOAuth2AuthorizedClientProvider implements Re
4546

4647
private ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = new WebClientReactiveJwtBearerTokenResponseClient();
4748

49+
private Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = this::resolveJwtAssertion;
50+
4851
private Duration clockSkew = Duration.ofSeconds(60);
4952

5053
private Clock clock = Clock.systemUTC();
@@ -74,10 +77,7 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
7477
// need for re-authorization
7578
return Mono.empty();
7679
}
77-
if (!(context.getPrincipal().getPrincipal() instanceof Jwt)) {
78-
return Mono.empty();
79-
}
80-
Jwt jwt = (Jwt) context.getPrincipal().getPrincipal();
80+
8181
// As per spec, in section 4.1 Using Assertions as Authorization Grants
8282
// https://tools.ietf.org/html/rfc7521#section-4.1
8383
//
@@ -90,13 +90,26 @@ public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext context
9090
// issued with a reasonably short lifetime. Clients can refresh an
9191
// expired access token by requesting a new one using the same
9292
// assertion, if it is still valid, or with a new assertion.
93-
return Mono.just(new JwtBearerGrantRequest(clientRegistration, jwt))
93+
94+
// @formatter:off
95+
return this.jwtAssertionResolver.apply(context)
96+
.map((jwt) -> new JwtBearerGrantRequest(clientRegistration, jwt))
9497
.flatMap(this.accessTokenResponseClient::getTokenResponse)
9598
.onErrorMap(OAuth2AuthorizationException.class,
9699
(ex) -> new ClientAuthorizationException(ex.getError(), clientRegistration.getRegistrationId(),
97100
ex))
98101
.map((tokenResponse) -> new OAuth2AuthorizedClient(clientRegistration, context.getPrincipal().getName(),
99102
tokenResponse.getAccessToken()));
103+
// @formatter:on
104+
}
105+
106+
private Mono<Jwt> resolveJwtAssertion(OAuth2AuthorizationContext context) {
107+
// @formatter:off
108+
return Mono.just(context)
109+
.map((ctx) -> ctx.getPrincipal().getPrincipal())
110+
.filter((principal) -> principal instanceof Jwt)
111+
.cast(Jwt.class);
112+
// @formatter:on
100113
}
101114

102115
private boolean hasTokenExpired(OAuth2Token token) {
@@ -115,6 +128,17 @@ public void setAccessTokenResponseClient(
115128
this.accessTokenResponseClient = accessTokenResponseClient;
116129
}
117130

131+
/**
132+
* Sets the resolver used for resolving the {@link Jwt} assertion.
133+
* @param jwtAssertionResolver the resolver used for resolving the {@link Jwt}
134+
* assertion
135+
* @since 5.7
136+
*/
137+
public void setJwtAssertionResolver(Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver) {
138+
Assert.notNull(jwtAssertionResolver, "jwtAssertionResolver cannot be null");
139+
this.jwtAssertionResolver = jwtAssertionResolver;
140+
}
141+
118142
/**
119143
* Sets the maximum acceptable clock skew, which is used when checking the
120144
* {@link OAuth2AuthorizedClient#getAccessToken() access token} expiry. The default is

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerOAuth2AuthorizedClientProviderTests.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
1818

1919
import java.time.Duration;
2020
import java.time.Instant;
21+
import java.util.function.Function;
2122

2223
import org.junit.jupiter.api.BeforeEach;
2324
import org.junit.jupiter.api.Test;
@@ -42,6 +43,7 @@
4243
import static org.mockito.ArgumentMatchers.any;
4344
import static org.mockito.BDDMockito.given;
4445
import static org.mockito.Mockito.mock;
46+
import static org.mockito.Mockito.verify;
4547

4648
/**
4749
* Tests for {@link JwtBearerOAuth2AuthorizedClientProvider}.
@@ -87,6 +89,13 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument
8789
.withMessage("accessTokenResponseClient cannot be null");
8890
}
8991

92+
@Test
93+
public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
94+
assertThatIllegalArgumentException()
95+
.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
96+
.withMessage("jwtAssertionResolver cannot be null");
97+
}
98+
9099
@Test
91100
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
92101
// @formatter:off
@@ -198,7 +207,7 @@ public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThen
198207
}
199208

200209
@Test
201-
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
210+
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
202211
// @formatter:off
203212
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
204213
.withClientRegistration(this.clientRegistration)
@@ -209,7 +218,7 @@ public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableTo
209218
}
210219

211220
@Test
212-
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
221+
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
213222
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
214223
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
215224
// @formatter:off
@@ -224,4 +233,25 @@ public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize()
224233
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
225234
}
226235

236+
@Test
237+
public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
238+
Function<OAuth2AuthorizationContext, Jwt> jwtAssertionResolver = mock(Function.class);
239+
given(jwtAssertionResolver.apply(any())).willReturn(this.jwtAssertion);
240+
this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
241+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
242+
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
243+
// @formatter:off
244+
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
245+
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
246+
.withClientRegistration(this.clientRegistration)
247+
.principal(principal)
248+
.build();
249+
// @formatter:on
250+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
251+
verify(jwtAssertionResolver).apply(any());
252+
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
253+
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
254+
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
255+
}
256+
227257
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JwtBearerReactiveOAuth2AuthorizedClientProviderTests.java

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2022 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
1919
import java.time.Clock;
2020
import java.time.Duration;
2121
import java.time.Instant;
22+
import java.util.function.Function;
2223

2324
import org.junit.jupiter.api.BeforeEach;
2425
import org.junit.jupiter.api.Test;
@@ -93,6 +94,13 @@ public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgument
9394
.withMessage("accessTokenResponseClient cannot be null");
9495
}
9596

97+
@Test
98+
public void setJwtAssertionResolverWhenNullThenThrowIllegalArgumentException() {
99+
assertThatIllegalArgumentException()
100+
.isThrownBy(() -> this.authorizedClientProvider.setJwtAssertionResolver(null))
101+
.withMessage("jwtAssertionResolver cannot be null");
102+
}
103+
96104
@Test
97105
public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
98106
// @formatter:off
@@ -222,7 +230,7 @@ public void authorizeWhenJwtBearerAndTokenNotExpiredButClockSkewForcesExpiryThen
222230
}
223231

224232
@Test
225-
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalNotJwtThenUnableToAuthorize() {
233+
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtDoesNotResolveThenUnableToAuthorize() {
226234
// @formatter:off
227235
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
228236
.withClientRegistration(this.clientRegistration)
@@ -251,7 +259,7 @@ public void authorizeWhenInvalidRequestThenThrowClientAuthorizationException() {
251259
}
252260

253261
@Test
254-
public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize() {
262+
public void authorizeWhenJwtBearerAndNotAuthorizedAndJwtResolvesThenAuthorize() {
255263
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
256264
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
257265
// @formatter:off
@@ -266,4 +274,25 @@ public void authorizeWhenJwtBearerAndNotAuthorizedAndPrincipalJwtThenAuthorize()
266274
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
267275
}
268276

277+
@Test
278+
public void authorizeWhenCustomJwtAssertionResolverSetThenUsed() {
279+
Function<OAuth2AuthorizationContext, Mono<Jwt>> jwtAssertionResolver = mock(Function.class);
280+
given(jwtAssertionResolver.apply(any())).willReturn(Mono.just(this.jwtAssertion));
281+
this.authorizedClientProvider.setJwtAssertionResolver(jwtAssertionResolver);
282+
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
283+
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
284+
// @formatter:off
285+
TestingAuthenticationToken principal = new TestingAuthenticationToken("user", "password");
286+
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
287+
.withClientRegistration(this.clientRegistration)
288+
.principal(principal)
289+
.build();
290+
// @formatter:on
291+
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
292+
verify(jwtAssertionResolver).apply(any());
293+
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
294+
assertThat(authorizedClient.getPrincipalName()).isEqualTo(principal.getName());
295+
assertThat(authorizedClient.getAccessToken()).isEqualTo(accessTokenResponse.getAccessToken());
296+
}
297+
269298
}

0 commit comments

Comments
 (0)