Skip to content

Commit 7a659e4

Browse files
mbhavesnicoll
authored andcommitted
Polish "Add support for aud claim in resource server"
See gh-29084
1 parent ee65627 commit 7a659e4

File tree

5 files changed

+292
-55
lines changed

5 files changed

+292
-55
lines changed

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/OAuth2ResourceServerProperties.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import java.io.IOException;
2020
import java.io.InputStream;
2121
import java.nio.charset.StandardCharsets;
22+
import java.util.ArrayList;
23+
import java.util.List;
2224

2325
import org.springframework.boot.context.properties.ConfigurationProperties;
2426
import org.springframework.boot.context.properties.source.InvalidConfigurationPropertyValueException;
@@ -75,7 +77,7 @@ public static class Jwt {
7577
/**
7678
* Identifies the recipients that the JWT is intended for.
7779
*/
78-
private String audience;
80+
private List<String> audiences = new ArrayList<>();
7981

8082
public String getJwkSetUri() {
8183
return this.jwkSetUri;
@@ -109,12 +111,12 @@ public void setPublicKeyLocation(Resource publicKeyLocation) {
109111
this.publicKeyLocation = publicKeyLocation;
110112
}
111113

112-
public String getAudience() {
113-
return this.audience;
114+
public List<String> getAudiences() {
115+
return this.audiences;
114116
}
115117

116-
public void setAudience(String audience) {
117-
this.audience = audience;
118+
public void setAudiences(List<String> audiences) {
119+
this.audiences = audiences;
118120
}
119121

120122
public String readPublicKey() throws IOException {

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerJwkConfiguration.java

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import java.security.spec.X509EncodedKeySpec;
2222
import java.util.ArrayList;
2323
import java.util.Base64;
24+
import java.util.Collections;
2425
import java.util.List;
26+
import java.util.function.Supplier;
2527

2628
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
2729
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -40,13 +42,13 @@
4042
import org.springframework.security.oauth2.jwt.Jwt;
4143
import org.springframework.security.oauth2.jwt.JwtClaimNames;
4244
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
43-
import org.springframework.security.oauth2.jwt.JwtIssuerValidator;
44-
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
45+
import org.springframework.security.oauth2.jwt.JwtValidators;
4546
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
4647
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
4748
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoders;
4849
import org.springframework.security.oauth2.jwt.SupplierReactiveJwtDecoder;
4950
import org.springframework.security.web.server.SecurityWebFilterChain;
51+
import org.springframework.util.CollectionUtils;
5052

5153
/**
5254
* Configures a {@link ReactiveJwtDecoder} when a JWK Set URI, OpenID Connect Issuer URI
@@ -78,28 +80,35 @@ ReactiveJwtDecoder jwtDecoder() {
7880
NimbusReactiveJwtDecoder nimbusReactiveJwtDecoder = NimbusReactiveJwtDecoder
7981
.withJwkSetUri(this.properties.getJwkSetUri())
8082
.jwsAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
81-
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
82-
validators.add(new JwtTimestampValidator());
8383
String issuerUri = this.properties.getIssuerUri();
84-
if (issuerUri != null) {
85-
validators.add(new JwtIssuerValidator(issuerUri));
86-
}
87-
String audience = this.properties.getAudience();
88-
if (audience != null) {
89-
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
90-
(aud) -> aud != null && aud.contains(audience)));
91-
}
92-
nimbusReactiveJwtDecoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(validators));
84+
Supplier<OAuth2TokenValidator<Jwt>> defaultValidator = (issuerUri != null)
85+
? () -> JwtValidators.createDefaultWithIssuer(issuerUri) : JwtValidators::createDefault;
86+
nimbusReactiveJwtDecoder.setJwtValidator(getValidators(defaultValidator));
9387
return nimbusReactiveJwtDecoder;
9488
}
9589

90+
private OAuth2TokenValidator<Jwt> getValidators(Supplier<OAuth2TokenValidator<Jwt>> defaultValidator) {
91+
OAuth2TokenValidator<Jwt> defaultValidators = defaultValidator.get();
92+
List<String> audiences = this.properties.getAudiences();
93+
if (CollectionUtils.isEmpty(audiences)) {
94+
return defaultValidators;
95+
}
96+
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
97+
validators.add(defaultValidators);
98+
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
99+
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
100+
return new DelegatingOAuth2TokenValidator<>(validators);
101+
}
102+
96103
@Bean
97104
@Conditional(KeyValueCondition.class)
98105
NimbusReactiveJwtDecoder jwtDecoderByPublicKeyValue() throws Exception {
99106
RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
100107
.generatePublic(new X509EncodedKeySpec(getKeySpec(this.properties.readPublicKey())));
101-
return NimbusReactiveJwtDecoder.withPublicKey(publicKey)
108+
NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withPublicKey(publicKey)
102109
.signatureAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
110+
jwtDecoder.setJwtValidator(getValidators(JwtValidators::createDefault));
111+
return jwtDecoder;
103112
}
104113

105114
private byte[] getKeySpec(String keyValue) {
@@ -110,8 +119,13 @@ private byte[] getKeySpec(String keyValue) {
110119
@Bean
111120
@Conditional(IssuerUriCondition.class)
112121
SupplierReactiveJwtDecoder jwtDecoderByIssuerUri() {
113-
return new SupplierReactiveJwtDecoder(
114-
() -> ReactiveJwtDecoders.fromIssuerLocation(this.properties.getIssuerUri()));
122+
return new SupplierReactiveJwtDecoder(() -> {
123+
NimbusReactiveJwtDecoder jwtDecoder = (NimbusReactiveJwtDecoder) ReactiveJwtDecoders
124+
.fromIssuerLocation(this.properties.getIssuerUri());
125+
jwtDecoder.setJwtValidator(
126+
getValidators(() -> JwtValidators.createDefaultWithIssuer(this.properties.getIssuerUri())));
127+
return jwtDecoder;
128+
});
115129
}
116130

117131
}

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/resource/servlet/OAuth2ResourceServerJwtConfiguration.java

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import java.security.spec.X509EncodedKeySpec;
2222
import java.util.ArrayList;
2323
import java.util.Base64;
24+
import java.util.Collections;
2425
import java.util.List;
26+
import java.util.function.Supplier;
2527

2628
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
2729
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -43,11 +45,11 @@
4345
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
4446
import org.springframework.security.oauth2.jwt.JwtDecoder;
4547
import org.springframework.security.oauth2.jwt.JwtDecoders;
46-
import org.springframework.security.oauth2.jwt.JwtIssuerValidator;
47-
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
48+
import org.springframework.security.oauth2.jwt.JwtValidators;
4849
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
4950
import org.springframework.security.oauth2.jwt.SupplierJwtDecoder;
5051
import org.springframework.security.web.SecurityFilterChain;
52+
import org.springframework.util.CollectionUtils;
5153

5254
/**
5355
* Configures a {@link JwtDecoder} when a JWK Set URI, OpenID Connect Issuer URI or Public
@@ -77,28 +79,35 @@ static class JwtDecoderConfiguration {
7779
JwtDecoder jwtDecoderByJwkKeySetUri() {
7880
NimbusJwtDecoder nimbusJwtDecoder = NimbusJwtDecoder.withJwkSetUri(this.properties.getJwkSetUri())
7981
.jwsAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
80-
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
81-
validators.add(new JwtTimestampValidator());
8282
String issuerUri = this.properties.getIssuerUri();
83-
if (issuerUri != null) {
84-
validators.add(new JwtIssuerValidator(issuerUri));
85-
}
86-
String audience = this.properties.getAudience();
87-
if (audience != null) {
88-
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
89-
(aud) -> aud != null && aud.contains(audience)));
90-
}
91-
nimbusJwtDecoder.setJwtValidator(new DelegatingOAuth2TokenValidator<>(validators));
83+
Supplier<OAuth2TokenValidator<Jwt>> defaultValidator = (issuerUri != null)
84+
? () -> JwtValidators.createDefaultWithIssuer(issuerUri) : JwtValidators::createDefault;
85+
nimbusJwtDecoder.setJwtValidator(getValidators(defaultValidator));
9286
return nimbusJwtDecoder;
9387
}
9488

89+
private OAuth2TokenValidator<Jwt> getValidators(Supplier<OAuth2TokenValidator<Jwt>> defaultValidator) {
90+
OAuth2TokenValidator<Jwt> defaultValidators = defaultValidator.get();
91+
List<String> audiences = this.properties.getAudiences();
92+
if (CollectionUtils.isEmpty(audiences)) {
93+
return defaultValidators;
94+
}
95+
List<OAuth2TokenValidator<Jwt>> validators = new ArrayList<>();
96+
validators.add(defaultValidators);
97+
validators.add(new JwtClaimValidator<List<String>>(JwtClaimNames.AUD,
98+
(aud) -> aud != null && !Collections.disjoint(aud, audiences)));
99+
return new DelegatingOAuth2TokenValidator<>(validators);
100+
}
101+
95102
@Bean
96103
@Conditional(KeyValueCondition.class)
97104
JwtDecoder jwtDecoderByPublicKeyValue() throws Exception {
98105
RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA")
99106
.generatePublic(new X509EncodedKeySpec(getKeySpec(this.properties.readPublicKey())));
100-
return NimbusJwtDecoder.withPublicKey(publicKey)
107+
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(publicKey)
101108
.signatureAlgorithm(SignatureAlgorithm.from(this.properties.getJwsAlgorithm())).build();
109+
jwtDecoder.setJwtValidator(getValidators(JwtValidators::createDefault));
110+
return jwtDecoder;
102111
}
103112

104113
private byte[] getKeySpec(String keyValue) {
@@ -109,7 +118,12 @@ private byte[] getKeySpec(String keyValue) {
109118
@Bean
110119
@Conditional(IssuerUriCondition.class)
111120
SupplierJwtDecoder jwtDecoderByIssuerUri() {
112-
return new SupplierJwtDecoder(() -> JwtDecoders.fromIssuerLocation(this.properties.getIssuerUri()));
121+
return new SupplierJwtDecoder(() -> {
122+
String issuerUri = this.properties.getIssuerUri();
123+
NimbusJwtDecoder jwtDecoder = JwtDecoders.fromIssuerLocation(issuerUri);
124+
jwtDecoder.setJwtValidator(getValidators(() -> JwtValidators.createDefaultWithIssuer(issuerUri)));
125+
return jwtDecoder;
126+
});
113127
}
114128

115129
}

spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/security/oauth2/resource/reactive/ReactiveOAuth2ResourceServerAutoConfigurationTests.java

Lines changed: 114 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive;
1818

1919
import java.io.IOException;
20+
import java.net.MalformedURLException;
21+
import java.net.URL;
22+
import java.time.Instant;
2023
import java.util.Collection;
2124
import java.util.Collections;
2225
import java.util.HashMap;
@@ -423,20 +426,108 @@ void autoConfigurationShouldConfigureIssuerAndAudienceJwtValidatorIfPropertyProv
423426
String issuer = this.server.url(path).toString();
424427
String cleanIssuerPath = cleanIssuerPath(issuer);
425428
setupMockResponse(cleanIssuerPath);
426-
this.contextRunner
427-
.withPropertyValues("spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
428-
"spring.security.oauth2.resourceserver.jwt.issuer-uri=http://" + this.server.getHostName() + ":"
429-
+ this.server.getPort() + "/" + path,
430-
"spring.security.oauth2.resourceserver.jwt.audience=http://test-audience.com")
429+
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
430+
this.contextRunner.withPropertyValues(
431+
"spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
432+
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
433+
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
431434
.run((context) -> {
432435
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
433436
ReactiveJwtDecoder reactiveJwtDecoder = context.getBean(ReactiveJwtDecoder.class);
437+
validate(issuerUri, reactiveJwtDecoder);
438+
});
439+
}
440+
441+
@SuppressWarnings("unchecked")
442+
private void validate(String issuerUri, ReactiveJwtDecoder jwtDecoder) throws MalformedURLException {
443+
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
444+
.getField(jwtDecoder, "jwtValidator");
445+
Jwt.Builder builder = jwt().claim("aud", Collections.singletonList("https://test-audience.com"));
446+
if (issuerUri != null) {
447+
builder.claim("iss", new URL(issuerUri));
448+
}
449+
Jwt jwt = builder.build();
450+
assertThat(jwtValidator.validate(jwt).hasErrors()).isFalse();
451+
Collection<OAuth2TokenValidator<Jwt>> delegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
452+
.getField(jwtValidator, "tokenValidators");
453+
validateDelegates(issuerUri, delegates);
454+
}
455+
456+
@SuppressWarnings("unchecked")
457+
private void validateDelegates(String issuerUri, Collection<OAuth2TokenValidator<Jwt>> delegates) {
458+
assertThat(delegates).hasAtLeastOneElementOfType(JwtClaimValidator.class);
459+
OAuth2TokenValidator<Jwt> delegatingValidator = delegates.stream()
460+
.filter((v) -> v instanceof DelegatingOAuth2TokenValidator).findFirst().get();
461+
Collection<OAuth2TokenValidator<Jwt>> nestedDelegates = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
462+
.getField(delegatingValidator, "tokenValidators");
463+
if (issuerUri != null) {
464+
assertThat(nestedDelegates).hasAtLeastOneElementOfType(JwtIssuerValidator.class);
465+
}
466+
}
467+
468+
@SuppressWarnings("unchecked")
469+
@Test
470+
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndIssuerUri() throws Exception {
471+
this.server = new MockWebServer();
472+
this.server.start();
473+
String path = "test";
474+
String issuer = this.server.url(path).toString();
475+
String cleanIssuerPath = cleanIssuerPath(issuer);
476+
setupMockResponse(cleanIssuerPath);
477+
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
478+
this.contextRunner.withPropertyValues("spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
479+
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
480+
.run((context) -> {
481+
SupplierReactiveJwtDecoder supplierJwtDecoderBean = context
482+
.getBean(SupplierReactiveJwtDecoder.class);
483+
Mono<ReactiveJwtDecoder> jwtDecoderSupplier = (Mono<ReactiveJwtDecoder>) ReflectionTestUtils
484+
.getField(supplierJwtDecoderBean, "jwtDecoderMono");
485+
ReactiveJwtDecoder jwtDecoder = jwtDecoderSupplier.block();
486+
validate(issuerUri, jwtDecoder);
487+
});
488+
}
489+
490+
@SuppressWarnings("unchecked")
491+
@Test
492+
void autoConfigurationShouldConfigureAudienceValidatorIfPropertyProvidedAndPublicKey() throws Exception {
493+
this.server = new MockWebServer();
494+
this.server.start();
495+
String path = "test";
496+
String issuer = this.server.url(path).toString();
497+
String cleanIssuerPath = cleanIssuerPath(issuer);
498+
setupMockResponse(cleanIssuerPath);
499+
this.contextRunner.withPropertyValues(
500+
"spring.security.oauth2.resourceserver.jwt.public-key-location=classpath:public-key-location",
501+
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
502+
.run((context) -> {
503+
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
504+
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
505+
validate(null, jwtDecoder);
506+
});
507+
}
508+
509+
@SuppressWarnings("unchecked")
510+
@Test
511+
void audienceValidatorWhenAudienceInvalid() throws Exception {
512+
this.server = new MockWebServer();
513+
this.server.start();
514+
String path = "test";
515+
String issuer = this.server.url(path).toString();
516+
String cleanIssuerPath = cleanIssuerPath(issuer);
517+
setupMockResponse(cleanIssuerPath);
518+
String issuerUri = "http://" + this.server.getHostName() + ":" + this.server.getPort() + "/" + path;
519+
this.contextRunner.withPropertyValues(
520+
"spring.security.oauth2.resourceserver.jwt.jwk-set-uri=https://jwk-set-uri.com",
521+
"spring.security.oauth2.resourceserver.jwt.issuer-uri=" + issuerUri,
522+
"spring.security.oauth2.resourceserver.jwt.audiences=https://test-audience.com,https://test-audience1.com")
523+
.run((context) -> {
524+
assertThat(context).hasSingleBean(ReactiveJwtDecoder.class);
525+
ReactiveJwtDecoder jwtDecoder = context.getBean(ReactiveJwtDecoder.class);
434526
DelegatingOAuth2TokenValidator<Jwt> jwtValidator = (DelegatingOAuth2TokenValidator<Jwt>) ReflectionTestUtils
435-
.getField(reactiveJwtDecoder, "jwtValidator");
436-
Collection<OAuth2TokenValidator<Jwt>> tokenValidators = (Collection<OAuth2TokenValidator<Jwt>>) ReflectionTestUtils
437-
.getField(jwtValidator, "tokenValidators");
438-
assertThat(tokenValidators).hasAtLeastOneElementOfType(JwtIssuerValidator.class);
439-
assertThat(tokenValidators).hasAtLeastOneElementOfType(JwtClaimValidator.class);
527+
.getField(jwtDecoder, "jwtValidator");
528+
Jwt jwt = jwt().claim("iss", new URL(issuerUri))
529+
.claim("aud", Collections.singletonList("https://other-audience.com")).build();
530+
assertThat(jwtValidator.validate(jwt).hasErrors()).isTrue();
440531
});
441532
}
442533

@@ -508,6 +599,19 @@ private Map<String, Object> getResponse(String issuer) {
508599
return response;
509600
}
510601

602+
static Jwt.Builder jwt() {
603+
// @formatter:off
604+
return Jwt.withTokenValue("token")
605+
.header("alg", "none")
606+
.expiresAt(Instant.MAX)
607+
.issuedAt(Instant.MIN)
608+
.issuer("https://issuer.example.org")
609+
.jti("jti")
610+
.notBefore(Instant.MIN)
611+
.subject("mock-test-subject");
612+
// @formatter:on
613+
}
614+
511615
@EnableWebFluxSecurity
512616
static class TestConfig {
513617

0 commit comments

Comments
 (0)