Skip to content

Commit 0209fba

Browse files
committed
Multiple JWS Algorithms
Fixes: gh-6883
1 parent 1a233a5 commit 0209fba

File tree

5 files changed

+263
-36
lines changed

5 files changed

+263
-36
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.jwt;
18+
19+
import java.security.Key;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Set;
23+
24+
import com.nimbusds.jose.JWSAlgorithm;
25+
import com.nimbusds.jose.JWSHeader;
26+
import com.nimbusds.jose.KeySourceException;
27+
import com.nimbusds.jose.proc.JWSKeySelector;
28+
import com.nimbusds.jose.proc.SecurityContext;
29+
30+
/**
31+
* Class for delegating to a Nimbus JWSKeySelector by the given JWSAlgorithm
32+
*
33+
* @author Josh Cummings
34+
*/
35+
class JWSAlgorithmMapJWSKeySelector<C extends SecurityContext> implements JWSKeySelector<C> {
36+
private Map<JWSAlgorithm, JWSKeySelector<C>> jwsKeySelectors;
37+
38+
JWSAlgorithmMapJWSKeySelector(Map<JWSAlgorithm, JWSKeySelector<C>> jwsKeySelectors) {
39+
this.jwsKeySelectors = jwsKeySelectors;
40+
}
41+
42+
@Override
43+
public List<? extends Key> selectJWSKeys(JWSHeader header, C context) throws KeySourceException {
44+
JWSKeySelector<C> keySelector = this.jwsKeySelectors.get(header.getAlgorithm());
45+
if (keySelector == null) {
46+
throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm());
47+
}
48+
return keySelector.selectJWSKeys(header, context);
49+
}
50+
51+
public Set<JWSAlgorithm> getExpectedJWSAlgorithms() {
52+
return this.jwsKeySelectors.keySet();
53+
}
54+
}

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@
2323
import java.text.ParseException;
2424
import java.time.Instant;
2525
import java.util.Collections;
26+
import java.util.HashMap;
27+
import java.util.HashSet;
2628
import java.util.LinkedHashMap;
2729
import java.util.Map;
30+
import java.util.Set;
31+
import java.util.function.Consumer;
2832
import javax.crypto.SecretKey;
2933

3034
import com.nimbusds.jose.JWSAlgorithm;
@@ -209,7 +213,7 @@ public static SecretKeyJwtDecoderBuilder withSecretKey(SecretKey secretKey) {
209213
*/
210214
public static final class JwkSetUriJwtDecoderBuilder {
211215
private String jwkSetUri;
212-
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
216+
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
213217
private RestOperations restOperations = new RestTemplate();
214218

215219
private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
@@ -218,15 +222,30 @@ private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
218222
}
219223

220224
/**
221-
* Use the given signing
222-
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
225+
* Append the given signing
226+
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>
227+
* to the set of algorithms to use.
223228
*
224229
* @param signatureAlgorithm the algorithm to use
225230
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
226231
*/
227232
public JwkSetUriJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) {
228233
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
229-
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
234+
this.signatureAlgorithms.add(signatureAlgorithm);
235+
return this;
236+
}
237+
238+
/**
239+
* Configure the list of
240+
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithms</a>
241+
* to use with the given {@link Consumer}.
242+
*
243+
* @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list
244+
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
245+
*/
246+
public JwkSetUriJwtDecoderBuilder jwsAlgorithms(Consumer<Set<SignatureAlgorithm>> signatureAlgorithmsConsumer) {
247+
Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null");
248+
signatureAlgorithmsConsumer.accept(this.signatureAlgorithms);
230249
return this;
231250
}
232251

@@ -245,13 +264,27 @@ public JwkSetUriJwtDecoderBuilder restOperations(RestOperations restOperations)
245264
return this;
246265
}
247266

267+
JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
268+
if (this.signatureAlgorithms.isEmpty()) {
269+
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
270+
} else if (this.signatureAlgorithms.size() == 1) {
271+
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
272+
return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
273+
} else {
274+
Map<JWSAlgorithm, JWSKeySelector<SecurityContext>> jwsKeySelectors = new HashMap<>();
275+
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
276+
JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
277+
jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
278+
}
279+
return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
280+
}
281+
}
282+
248283
JWTProcessor<SecurityContext> processor() {
249284
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
250285
JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
251-
JWSKeySelector<SecurityContext> jwsKeySelector =
252-
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
253286
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
254-
jwtProcessor.setJWSKeySelector(jwsKeySelector);
287+
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));
255288

256289
// Spring Security validates the claim set independent from Nimbus
257290
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@
1818
import java.security.interfaces.RSAPublicKey;
1919
import java.time.Instant;
2020
import java.util.Collections;
21+
import java.util.HashMap;
22+
import java.util.HashSet;
2123
import java.util.LinkedHashMap;
2224
import java.util.Map;
25+
import java.util.Set;
26+
import java.util.function.Consumer;
2327
import java.util.function.Function;
2428
import javax.crypto.SecretKey;
2529

@@ -31,6 +35,7 @@
3135
import com.nimbusds.jose.jwk.JWKMatcher;
3236
import com.nimbusds.jose.jwk.JWKSelector;
3337
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
38+
import com.nimbusds.jose.jwk.source.JWKSource;
3439
import com.nimbusds.jose.proc.BadJOSEException;
3540
import com.nimbusds.jose.proc.JWKSecurityContext;
3641
import com.nimbusds.jose.proc.JWSKeySelector;
@@ -233,7 +238,7 @@ public static JwkSourceReactiveJwtDecoderBuilder withJwkSource(Function<SignedJW
233238
*/
234239
public static final class JwkSetUriReactiveJwtDecoderBuilder {
235240
private final String jwkSetUri;
236-
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
241+
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
237242
private WebClient webClient = WebClient.create();
238243

239244
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
@@ -242,15 +247,30 @@ private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
242247
}
243248

244249
/**
245-
* Use the given signing
246-
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
250+
* Append the given signing
251+
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>
252+
* to the set of algorithms to use.
247253
*
248254
* @param signatureAlgorithm the algorithm to use
249255
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
250256
*/
251257
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) {
252258
Assert.notNull(signatureAlgorithm, "sig cannot be null");
253-
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
259+
this.signatureAlgorithms.add(signatureAlgorithm);
260+
return this;
261+
}
262+
263+
/**
264+
* Configure the list of
265+
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithms</a>
266+
* to use with the given {@link Consumer}.
267+
*
268+
* @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list
269+
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
270+
*/
271+
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms(Consumer<Set<SignatureAlgorithm>> signatureAlgorithmsConsumer) {
272+
Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null");
273+
signatureAlgorithmsConsumer.accept(this.signatureAlgorithms);
254274
return this;
255275
}
256276

@@ -278,28 +298,53 @@ public NimbusReactiveJwtDecoder build() {
278298
return new NimbusReactiveJwtDecoder(processor());
279299
}
280300

301+
JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
302+
if (this.signatureAlgorithms.isEmpty()) {
303+
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
304+
} else if (this.signatureAlgorithms.size() == 1) {
305+
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
306+
return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
307+
} else {
308+
Map<JWSAlgorithm, JWSKeySelector<JWKSecurityContext>> jwsKeySelectors = new HashMap<>();
309+
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
310+
JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
311+
jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
312+
}
313+
return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
314+
}
315+
}
316+
281317
Converter<JWT, Mono<JWTClaimsSet>> processor() {
282318
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();
283-
284-
JWSKeySelector<JWKSecurityContext> jwsKeySelector =
285-
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
286319
DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
320+
JWSKeySelector<JWKSecurityContext> jwsKeySelector = jwsKeySelector(jwkSource);
287321
jwtProcessor.setJWSKeySelector(jwsKeySelector);
288322
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});
289323

290324
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
291325
source.setWebClient(this.webClient);
292326

327+
Set<JWSAlgorithm> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
293328
return jwt -> {
294-
JWKSelector selector = createSelector(jwt.getHeader());
329+
JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
295330
return source.get(selector)
296331
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
297332
.map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
298333
};
299334
}
300335

301-
private JWKSelector createSelector(Header header) {
302-
if (!this.jwsAlgorithm.equals(header.getAlgorithm())) {
336+
private Set<JWSAlgorithm> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
337+
if (jwsKeySelector instanceof JWSVerificationKeySelector) {
338+
return Collections.singleton(((JWSVerificationKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithm());
339+
}
340+
if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) {
341+
return ((JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithms();
342+
}
343+
throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass());
344+
}
345+
346+
private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) {
347+
if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) {
303348
throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
304349
}
305350

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@
3939
import com.nimbusds.jose.JWSSigner;
4040
import com.nimbusds.jose.crypto.MACSigner;
4141
import com.nimbusds.jose.crypto.RSASSASigner;
42+
import com.nimbusds.jose.jwk.source.JWKSource;
4243
import com.nimbusds.jose.proc.BadJOSEException;
44+
import com.nimbusds.jose.proc.JWSKeySelector;
45+
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
4346
import com.nimbusds.jose.proc.SecurityContext;
4447
import com.nimbusds.jwt.JWTClaimsSet;
4548
import com.nimbusds.jwt.SignedJWT;
@@ -357,6 +360,46 @@ public void decodeWhenUsingSecertKeyWithKidThenStillUsesKey() throws Exception {
357360
.isEqualTo("test-subject");
358361
}
359362

363+
@Test
364+
public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() {
365+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
366+
JWSKeySelector<SecurityContext> jwsKeySelector =
367+
withJwkSetUri(JWK_SET_URI).jwsKeySelector(jwkSource);
368+
assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
369+
JWSVerificationKeySelector<?> jwsVerificationKeySelector =
370+
(JWSVerificationKeySelector<?>) jwsKeySelector;
371+
assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
372+
.isEqualTo(JWSAlgorithm.RS256);
373+
}
374+
375+
@Test
376+
public void jwsKeySelectorWhenOneAlgorithmThenReturnsSingleSelector() {
377+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
378+
JWSKeySelector<SecurityContext> jwsKeySelector =
379+
withJwkSetUri(JWK_SET_URI).jwsAlgorithm(SignatureAlgorithm.RS512)
380+
.jwsKeySelector(jwkSource);
381+
assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
382+
JWSVerificationKeySelector<?> jwsVerificationKeySelector =
383+
(JWSVerificationKeySelector<?>) jwsKeySelector;
384+
assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
385+
.isEqualTo(JWSAlgorithm.RS512);
386+
}
387+
388+
@Test
389+
public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() {
390+
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
391+
JWSKeySelector<SecurityContext> jwsKeySelector =
392+
withJwkSetUri(JWK_SET_URI)
393+
.jwsAlgorithm(SignatureAlgorithm.RS256)
394+
.jwsAlgorithm(SignatureAlgorithm.RS512)
395+
.jwsKeySelector(jwkSource);
396+
assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector);
397+
JWSAlgorithmMapJWSKeySelector<?> jwsAlgorithmMapKeySelector =
398+
(JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector;
399+
assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms())
400+
.containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512);
401+
}
402+
360403
private RSAPublicKey key() throws InvalidKeySpecException {
361404
byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes());
362405
EncodedKeySpec spec = new X509EncodedKeySpec(decoded);

0 commit comments

Comments
 (0)