Skip to content

Commit 2907864

Browse files
hitchanjzheaux
authored andcommitted
Add Support for JWK Signature Algorithm Discovery
Issue gh-7160
1 parent 4ffc3d6 commit 2907864

File tree

13 files changed

+229
-12
lines changed

13 files changed

+229
-12
lines changed

config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/resource/OAuth2ResourceServerConfigurerTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ public void getWhenUsingDefaultsInLambdaWithValidBearerTokenThenAcceptsRequest()
223223
public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception {
224224
this.spring.register(WebServerConfig.class, JwkSetUriConfig.class, BasicController.class).autowire();
225225
mockWebServer(jwks("Default"));
226+
mockWebServer(jwks("Default"));
226227
String token = this.token("ValidNoScopes");
227228
// @formatter:off
228229
this.mvc.perform(get("/").with(bearerToken(token)))
@@ -235,6 +236,7 @@ public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception {
235236
public void getWhenUsingJwkSetUriInLambdaThenAcceptsRequest() throws Exception {
236237
this.spring.register(WebServerConfig.class, JwkSetUriInLambdaConfig.class, BasicController.class).autowire();
237238
mockWebServer(jwks("Default"));
239+
mockWebServer(jwks("Default"));
238240
String token = this.token("ValidNoScopes");
239241
// @formatter:off
240242
this.mvc.perform(get("/").with(bearerToken(token)))
@@ -1185,13 +1187,15 @@ public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Ex
11851187
String jwtThree = jwtFromIssuer(issuerThree);
11861188
mockWebServer(String.format(metadata, issuerOne, issuerOne));
11871189
mockWebServer(jwkSet);
1190+
mockWebServer(jwkSet);
11881191
// @formatter:off
11891192
this.mvc.perform(get("/authenticated").with(bearerToken(jwtOne)))
11901193
.andExpect(status().isOk())
11911194
.andExpect(content().string("test-subject"));
11921195
// @formatter:on
11931196
mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
11941197
mockWebServer(jwkSet);
1198+
mockWebServer(jwkSet);
11951199
// @formatter:off
11961200
this.mvc.perform(get("/authenticated").with(bearerToken(jwtTwo)))
11971201
.andExpect(status().isOk())

config/src/test/java/org/springframework/security/config/http/OAuth2ResourceServerBeanDefinitionParserTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ public void getWhenValidBearerTokenThenAcceptsRequest() throws Exception {
148148
public void getWhenUsingJwkSetUriThenAcceptsRequest() throws Exception {
149149
this.spring.configLocations(xml("WebServer"), xml("JwkSetUri")).autowire();
150150
mockWebServer(jwks("Default"));
151+
mockWebServer(jwks("Default"));
151152
String token = this.token("ValidNoScopes");
152153
// @formatter:off
153154
this.mvc.perform(get("/").header("Authorization", "Bearer " + token))
@@ -709,18 +710,21 @@ public void getWhenMultipleIssuersThenUsesIssuerClaimToDifferentiate() throws Ex
709710
String jwtThree = jwtFromIssuer(issuerThree);
710711
mockWebServer(String.format(metadata, issuerOne, issuerOne));
711712
mockWebServer(jwkSet);
713+
mockWebServer(jwkSet);
712714
// @formatter:off
713715
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtOne))
714716
.andExpect(status().isNotFound());
715717
// @formatter:on
716718
mockWebServer(String.format(metadata, issuerTwo, issuerTwo));
717719
mockWebServer(jwkSet);
720+
mockWebServer(jwkSet);
718721
// @formatter:off
719722
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtTwo))
720723
.andExpect(status().isNotFound());
721724
// @formatter:on
722725
mockWebServer(String.format(metadata, issuerThree, issuerThree));
723726
mockWebServer(jwkSet);
727+
mockWebServer(jwkSet);
724728
// @formatter:off
725729
this.mvc.perform(get("/authenticated").header("Authorization", "Bearer " + jwtThree))
726730
.andExpect(status().isUnauthorized())

config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ public void getWhenUsingJwkSetUriThenConsultsAccordingly() {
261261
this.spring.register(JwkSetUriConfig.class, RootController.class).autowire();
262262
MockWebServer mockWebServer = this.spring.getContext().getBean(MockWebServer.class);
263263
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));
264+
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));
264265
// @formatter:off
265266
this.client.get()
266267
.headers((headers) -> headers
@@ -276,6 +277,7 @@ public void getWhenUsingJwkSetUriInLambdaThenConsultsAccordingly() {
276277
this.spring.register(JwkSetUriInLambdaConfig.class, RootController.class).autowire();
277278
MockWebServer mockWebServer = this.spring.getContext().getBean(MockWebServer.class);
278279
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));
280+
mockWebServer.enqueue(new MockResponse().setBody(this.jwkSet));
279281
// @formatter:off
280282
this.client.get()
281283
.headers((headers) -> headers

config/src/test/kotlin/org/springframework/security/config/web/server/ServerJwtDslTests.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class ServerJwtDslTests {
160160
fun `jwt when using custom JWK Set URI then custom URI used`() {
161161
this.spring.register(CustomJwkSetUriConfig::class.java).autowire()
162162

163+
CustomJwkSetUriConfig.MOCK_WEB_SERVER.enqueue(MockResponse().setBody(jwkSet))
163164
CustomJwkSetUriConfig.MOCK_WEB_SERVER.enqueue(MockResponse().setBody(jwkSet))
164165

165166
this.client.get()

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,26 @@
2121
import java.net.URL;
2222
import java.security.interfaces.RSAPublicKey;
2323
import java.text.ParseException;
24+
import java.util.ArrayList;
2425
import java.util.Arrays;
2526
import java.util.Collection;
2627
import java.util.Collections;
2728
import java.util.HashSet;
2829
import java.util.LinkedHashMap;
30+
import java.util.List;
2931
import java.util.Map;
3032
import java.util.Set;
3133
import java.util.function.Consumer;
3234

3335
import javax.crypto.SecretKey;
3436

37+
import com.nimbusds.jose.Algorithm;
3538
import com.nimbusds.jose.JOSEException;
3639
import com.nimbusds.jose.JWSAlgorithm;
3740
import com.nimbusds.jose.RemoteKeySourceException;
41+
import com.nimbusds.jose.jwk.JWK;
3842
import com.nimbusds.jose.jwk.JWKSet;
43+
import com.nimbusds.jose.jwk.KeyUse;
3944
import com.nimbusds.jose.jwk.source.JWKSetCache;
4045
import com.nimbusds.jose.jwk.source.JWKSource;
4146
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
@@ -234,6 +239,8 @@ public static SecretKeyJwtDecoderBuilder withSecretKey(SecretKey secretKey) {
234239
*/
235240
public static final class JwkSetUriJwtDecoderBuilder {
236241

242+
private static final Log log = LogFactory.getLog(JwkSetUriJwtDecoderBuilder.class);
243+
237244
private String jwkSetUri;
238245

239246
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
@@ -322,17 +329,60 @@ public JwkSetUriJwtDecoderBuilder jwtProcessorCustomizer(
322329
}
323330

324331
JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
325-
if (this.signatureAlgorithms.isEmpty()) {
326-
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
332+
Set<SignatureAlgorithm> algorithms = new HashSet<>();
333+
if (!this.signatureAlgorithms.isEmpty()) {
334+
algorithms.addAll(this.signatureAlgorithms);
335+
} else {
336+
algorithms.addAll(fetchSignatureAlgorithms());
337+
}
338+
339+
if (algorithms.isEmpty()) {
340+
algorithms.add(SignatureAlgorithm.RS256);
327341
}
342+
328343
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
329-
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
330-
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
331-
jwsAlgorithms.add(jwsAlgorithm);
344+
for (SignatureAlgorithm signatureAlgorithm : algorithms) {
345+
jwsAlgorithms.add(JWSAlgorithm.parse(signatureAlgorithm.getName()));
332346
}
347+
333348
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
334349
}
335350

351+
private Set<SignatureAlgorithm> fetchSignatureAlgorithms() {
352+
try {
353+
return parseAlgorithms(JWKSet.load(toURL(jwkSetUri), 5000, 5000, 0));
354+
} catch (Exception ex) {
355+
throw new IllegalArgumentException("Failed to load Signature Algorithms from remote JWK source.", ex);
356+
}
357+
}
358+
359+
private Set<SignatureAlgorithm> parseAlgorithms(JWKSet jwkSet) {
360+
if (jwkSet == null) {
361+
throw new IllegalArgumentException(String.format("No JWKs received from %s", jwkSetUri));
362+
}
363+
364+
List<JWK> jwks = new ArrayList<>();
365+
for (JWK jwk : jwkSet.getKeys()) {
366+
KeyUse keyUse = jwk.getKeyUse();
367+
if (keyUse != null && keyUse.equals(KeyUse.SIGNATURE)) {
368+
jwks.add(jwk);
369+
}
370+
}
371+
372+
Set<SignatureAlgorithm> algorithms = new HashSet<>();
373+
for (JWK jwk : jwks) {
374+
Algorithm algorithm = jwk.getAlgorithm();
375+
if (algorithm != null) {
376+
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(algorithm.getName());
377+
if (signatureAlgorithm != null) {
378+
algorithms.add(signatureAlgorithm);
379+
}
380+
}
381+
}
382+
383+
return algorithms;
384+
}
385+
336386
JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
337387
if (this.cache == null) {
338388
return new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,31 @@
1616

1717
package org.springframework.security.oauth2.jwt;
1818

19+
import java.net.MalformedURLException;
20+
import java.net.URL;
1921
import java.security.interfaces.RSAPublicKey;
2022
import java.util.Collection;
2123
import java.util.Collections;
2224
import java.util.HashSet;
2325
import java.util.LinkedHashMap;
26+
import java.util.List;
2427
import java.util.Map;
2528
import java.util.Set;
2629
import java.util.function.Consumer;
2730
import java.util.function.Function;
2831

2932
import javax.crypto.SecretKey;
3033

34+
import com.nimbusds.jose.Algorithm;
3135
import com.nimbusds.jose.Header;
3236
import com.nimbusds.jose.JOSEException;
3337
import com.nimbusds.jose.JWSAlgorithm;
3438
import com.nimbusds.jose.JWSHeader;
3539
import com.nimbusds.jose.jwk.JWK;
3640
import com.nimbusds.jose.jwk.JWKMatcher;
3741
import com.nimbusds.jose.jwk.JWKSelector;
42+
import com.nimbusds.jose.jwk.JWKSet;
43+
import com.nimbusds.jose.jwk.KeyUse;
3844
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
3945
import com.nimbusds.jose.jwk.source.JWKSource;
4046
import com.nimbusds.jose.proc.BadJOSEException;
@@ -50,6 +56,8 @@
5056
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
5157
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
5258
import com.nimbusds.jwt.proc.JWTProcessor;
59+
import org.apache.commons.logging.Log;
60+
import org.apache.commons.logging.LogFactory;
5361
import reactor.core.publisher.Flux;
5462
import reactor.core.publisher.Mono;
5563

@@ -273,6 +281,8 @@ private static <C extends SecurityContext> JWTClaimsSet createClaimsSet(JWTProce
273281
*/
274282
public static final class JwkSetUriReactiveJwtDecoderBuilder {
275283

284+
private static final Log log = LogFactory.getLog(JwkSetUriReactiveJwtDecoderBuilder.class);
285+
276286
private final String jwkSetUri;
277287

278288
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
@@ -354,17 +364,63 @@ public NimbusReactiveJwtDecoder build() {
354364
}
355365

356366
JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
357-
if (this.signatureAlgorithms.isEmpty()) {
358-
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
367+
Set<SignatureAlgorithm> algorithms = new HashSet<>();
368+
if (!this.signatureAlgorithms.isEmpty()) {
369+
algorithms.addAll(this.signatureAlgorithms);
370+
} else {
371+
algorithms.addAll(fetchSignatureAlgorithms());
372+
}
373+
374+
if (algorithms.isEmpty()) {
375+
algorithms.add(SignatureAlgorithm.RS256);
359376
}
377+
360378
Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
361-
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
362-
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
363-
jwsAlgorithms.add(jwsAlgorithm);
379+
for (SignatureAlgorithm signatureAlgorithm : algorithms) {
380+
jwsAlgorithms.add(JWSAlgorithm.parse(signatureAlgorithm.getName()));
364381
}
382+
365383
return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource);
366384
}
367385

386+
private Set<SignatureAlgorithm> fetchSignatureAlgorithms() {
387+
if (StringUtils.isEmpty(jwkSetUri)) {
388+
return Collections.emptySet();
389+
}
390+
try {
391+
return parseAlgorithms(JWKSet.load(toURL(jwkSetUri), 5000, 5000, 0));
392+
} catch (Exception ex) {
393+
throw new IllegalArgumentException("Failed to load Signature Algorithms from remote JWK source.", ex);
394+
}
395+
}
396+
397+
private Set<SignatureAlgorithm> parseAlgorithms(JWKSet jwkSet) {
398+
if (jwkSet == null) {
399+
throw new IllegalArgumentException(String.format("No JWKs received from %s", jwkSetUri));
400+
}
401+
402+
List<JWK> jwks = new ArrayList<>();
403+
for (JWK jwk : jwkSet.getKeys()) {
404+
KeyUse keyUse = jwk.getKeyUse();
405+
if (keyUse != null && keyUse.equals(KeyUse.SIGNATURE)) {
406+
jwks.add(jwk);
407+
}
408+
}
409+
410+
Set<SignatureAlgorithm> algorithms = new HashSet<>();
411+
for (JWK jwk : jwks) {
412+
Algorithm algorithm = jwk.getAlgorithm();
413+
if (algorithm != null) {
414+
SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(algorithm.getName());
415+
if (signatureAlgorithm != null) {
416+
algorithms.add(signatureAlgorithm);
417+
}
418+
}
419+
}
420+
421+
return algorithms;
422+
}
423+
368424
Converter<JWT, Mono<JWTClaimsSet>> processor() {
369425
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
370426
DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
@@ -399,6 +455,13 @@ private JWKSelector createSelector(Function<JWSAlgorithm, Boolean> expectedJwsAl
399455
return new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
400456
}
401457

458+
private static URL toURL(String url) {
459+
try {
460+
return new URL(url);
461+
} catch (MalformedURLException ex) {
462+
throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex);
463+
}
464+
}
402465
}
403466

404467
/**

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/JwtDecodersTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ private void prepareConfigurationResponse() {
305305
private void prepareConfigurationResponse(String body) {
306306
this.server.enqueue(response(body));
307307
this.server.enqueue(response(JWK_SET));
308+
this.server.enqueue(response(JWK_SET));
308309
}
309310

310311
private void prepareConfigurationResponseOidc() {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderJwkSupportTests.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ public void decodeWhenJwtIsMalformedThenReturnsStockException() throws Exception
136136
@Test
137137
public void decodeWhenJwkResponseIsMalformedThenReturnsStockException() throws Exception {
138138
try (MockWebServer server = new MockWebServer()) {
139+
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
139140
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
140141
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
141142
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
@@ -151,6 +152,7 @@ public void decodeWhenJwkResponseIsMalformedThenReturnsStockException() throws E
151152
@Test
152153
public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws Exception {
153154
try (MockWebServer server = new MockWebServer()) {
155+
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
154156
server.enqueue(new MockResponse().setBody(MALFORMED_JWK_SET));
155157
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
156158
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);
@@ -167,6 +169,7 @@ public void decodeWhenJwkEndpointIsUnresponsiveThenReturnsJwtException() throws
167169
@Test
168170
public void decodeWhenCustomRestOperationsSetThenUsed() throws Exception {
169171
try (MockWebServer server = new MockWebServer()) {
172+
server.enqueue(new MockResponse().setBody(JWK_SET));
170173
server.enqueue(new MockResponse().setBody(JWK_SET));
171174
String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
172175
NimbusJwtDecoderJwkSupport jwtDecoder = new NimbusJwtDecoderJwkSupport(jwkSetUrl);

0 commit comments

Comments
 (0)