Skip to content

Multiple JWS Algorithms #7162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2002-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.oauth2.jwt;

import java.security.Key;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.SecurityContext;

/**
* Class for delegating to a Nimbus JWSKeySelector by the given JWSAlgorithm
*
* @author Josh Cummings
*/
class JWSAlgorithmMapJWSKeySelector<C extends SecurityContext> implements JWSKeySelector<C> {
private Map<JWSAlgorithm, JWSKeySelector<C>> jwsKeySelectors;

JWSAlgorithmMapJWSKeySelector(Map<JWSAlgorithm, JWSKeySelector<C>> jwsKeySelectors) {
this.jwsKeySelectors = jwsKeySelectors;
}

@Override
public List<? extends Key> selectJWSKeys(JWSHeader header, C context) throws KeySourceException {
JWSKeySelector<C> keySelector = this.jwsKeySelectors.get(header.getAlgorithm());
if (keySelector == null) {
throw new IllegalArgumentException("Unsupported algorithm of " + header.getAlgorithm());
}
return keySelector.selectJWSKeys(header, context);
}

public Set<JWSAlgorithm> getExpectedJWSAlgorithms() {
return this.jwsKeySelectors.keySet();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@
import java.text.ParseException;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import javax.crypto.SecretKey;

import com.nimbusds.jose.JWSAlgorithm;
Expand Down Expand Up @@ -209,7 +213,7 @@ public static SecretKeyJwtDecoderBuilder withSecretKey(SecretKey secretKey) {
*/
public static final class JwkSetUriJwtDecoderBuilder {
private String jwkSetUri;
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private RestOperations restOperations = new RestTemplate();

private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
Expand All @@ -218,15 +222,30 @@ private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
}

/**
* Use the given signing
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
* Append the given signing
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>
* to the set of algorithms to use.
*
* @param signatureAlgorithm the algorithm to use
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
*/
public JwkSetUriJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) {
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
this.signatureAlgorithms.add(signatureAlgorithm);
return this;
}

/**
* Configure the list of
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithms</a>
* to use with the given {@link Consumer}.
*
* @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list
* @return a {@link JwkSetUriJwtDecoderBuilder} for further configurations
*/
public JwkSetUriJwtDecoderBuilder jwsAlgorithms(Consumer<Set<SignatureAlgorithm>> signatureAlgorithmsConsumer) {
Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null");
signatureAlgorithmsConsumer.accept(this.signatureAlgorithms);
return this;
}

Expand All @@ -245,13 +264,27 @@ public JwkSetUriJwtDecoderBuilder restOperations(RestOperations restOperations)
return this;
}

JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
} else if (this.signatureAlgorithms.size() == 1) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
} else {
Map<JWSAlgorithm, JWSKeySelector<SecurityContext>> jwsKeySelectors = new HashMap<>();
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
}
return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
}
}

JWTProcessor<SecurityContext> processor() {
ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations);
JWKSource<SecurityContext> jwkSource = new RemoteJWKSet<>(toURL(this.jwkSetUri), jwkSetRetriever);
JWSKeySelector<SecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource));

// Spring Security validates the claim set independent from Nimbus
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import java.security.interfaces.RSAPublicKey;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.crypto.SecretKey;

Expand All @@ -31,6 +35,7 @@
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.source.JWKSecurityContextJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWKSecurityContext;
import com.nimbusds.jose.proc.JWSKeySelector;
Expand Down Expand Up @@ -233,7 +238,7 @@ public static JwkSourceReactiveJwtDecoderBuilder withJwkSource(Function<SignedJW
*/
public static final class JwkSetUriReactiveJwtDecoderBuilder {
private final String jwkSetUri;
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.RS256;
private Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
private WebClient webClient = WebClient.create();

private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
Expand All @@ -242,15 +247,30 @@ private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
}

/**
* Use the given signing
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
* Append the given signing
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>
* to the set of algorithms to use.
*
* @param signatureAlgorithm the algorithm to use
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
*/
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm(SignatureAlgorithm signatureAlgorithm) {
Assert.notNull(signatureAlgorithm, "sig cannot be null");
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
this.signatureAlgorithms.add(signatureAlgorithm);
return this;
}

/**
* Configure the list of
* <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithms</a>
* to use with the given {@link Consumer}.
*
* @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
*/
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms(Consumer<Set<SignatureAlgorithm>> signatureAlgorithmsConsumer) {
Assert.notNull(signatureAlgorithmsConsumer, "signatureAlgorithmsConsumer cannot be null");
signatureAlgorithmsConsumer.accept(this.signatureAlgorithms);
return this;
}

Expand Down Expand Up @@ -278,28 +298,53 @@ public NimbusReactiveJwtDecoder build() {
return new NimbusReactiveJwtDecoder(processor());
}

JWSKeySelector<JWKSecurityContext> jwsKeySelector(JWKSource<JWKSecurityContext> jwkSource) {
if (this.signatureAlgorithms.isEmpty()) {
return new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, jwkSource);
} else if (this.signatureAlgorithms.size() == 1) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(this.signatureAlgorithms.iterator().next().getName());
return new JWSVerificationKeySelector<>(jwsAlgorithm, jwkSource);
} else {
Map<JWSAlgorithm, JWSKeySelector<JWKSecurityContext>> jwsKeySelectors = new HashMap<>();
for (SignatureAlgorithm signatureAlgorithm : this.signatureAlgorithms) {
JWSAlgorithm jwsAlg = JWSAlgorithm.parse(signatureAlgorithm.getName());
jwsKeySelectors.put(jwsAlg, new JWSVerificationKeySelector<>(jwsAlg, jwkSource));
}
return new JWSAlgorithmMapJWSKeySelector<>(jwsKeySelectors);
}
}

Converter<JWT, Mono<JWTClaimsSet>> processor() {
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet();

JWSKeySelector<JWKSecurityContext> jwsKeySelector =
new JWSVerificationKeySelector<>(this.jwsAlgorithm, jwkSource);
DefaultJWTProcessor<JWKSecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
JWSKeySelector<JWKSecurityContext> jwsKeySelector = jwsKeySelector(jwkSource);
jwtProcessor.setJWSKeySelector(jwsKeySelector);
jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> {});

ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource(this.jwkSetUri);
source.setWebClient(this.webClient);

Set<JWSAlgorithm> expectedJwsAlgorithms = getExpectedJwsAlgorithms(jwsKeySelector);
return jwt -> {
JWKSelector selector = createSelector(jwt.getHeader());
JWKSelector selector = createSelector(expectedJwsAlgorithms, jwt.getHeader());
return source.get(selector)
.onErrorMap(e -> new IllegalStateException("Could not obtain the keys", e))
.map(jwkList -> createClaimsSet(jwtProcessor, jwt, new JWKSecurityContext(jwkList)));
};
}

private JWKSelector createSelector(Header header) {
if (!this.jwsAlgorithm.equals(header.getAlgorithm())) {
private Set<JWSAlgorithm> getExpectedJwsAlgorithms(JWSKeySelector<?> jwsKeySelector) {
if (jwsKeySelector instanceof JWSVerificationKeySelector) {
return Collections.singleton(((JWSVerificationKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithm());
}
if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector) {
return ((JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector).getExpectedJWSAlgorithms();
}
throw new IllegalArgumentException("Unsupported key selector type " + jwsKeySelector.getClass());
}

private JWKSelector createSelector(Set<JWSAlgorithm> expectedJwsAlgorithms, Header header) {
if (!expectedJwsAlgorithms.contains(header.getAlgorithm())) {
throw new JwtException("Unsupported algorithm of " + header.getAlgorithm());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.MACSigner;
import com.nimbusds.jose.crypto.RSASSASigner;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
Expand Down Expand Up @@ -357,6 +360,46 @@ public void decodeWhenUsingSecertKeyWithKidThenStillUsesKey() throws Exception {
.isEqualTo("test-subject");
}

@Test
public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() {
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
JWSKeySelector<SecurityContext> jwsKeySelector =
withJwkSetUri(JWK_SET_URI).jwsKeySelector(jwkSource);
assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
JWSVerificationKeySelector<?> jwsVerificationKeySelector =
(JWSVerificationKeySelector<?>) jwsKeySelector;
assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
.isEqualTo(JWSAlgorithm.RS256);
}

@Test
public void jwsKeySelectorWhenOneAlgorithmThenReturnsSingleSelector() {
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
JWSKeySelector<SecurityContext> jwsKeySelector =
withJwkSetUri(JWK_SET_URI).jwsAlgorithm(SignatureAlgorithm.RS512)
.jwsKeySelector(jwkSource);
assertThat(jwsKeySelector instanceof JWSVerificationKeySelector);
JWSVerificationKeySelector<?> jwsVerificationKeySelector =
(JWSVerificationKeySelector<?>) jwsKeySelector;
assertThat(jwsVerificationKeySelector.getExpectedJWSAlgorithm())
.isEqualTo(JWSAlgorithm.RS512);
}

@Test
public void jwsKeySelectorWhenMultipleAlgorithmThenReturnsCompositeSelector() {
JWKSource<SecurityContext> jwkSource = mock(JWKSource.class);
JWSKeySelector<SecurityContext> jwsKeySelector =
withJwkSetUri(JWK_SET_URI)
.jwsAlgorithm(SignatureAlgorithm.RS256)
.jwsAlgorithm(SignatureAlgorithm.RS512)
.jwsKeySelector(jwkSource);
assertThat(jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector);
JWSAlgorithmMapJWSKeySelector<?> jwsAlgorithmMapKeySelector =
(JWSAlgorithmMapJWSKeySelector<?>) jwsKeySelector;
assertThat(jwsAlgorithmMapKeySelector.getExpectedJWSAlgorithms())
.containsExactlyInAnyOrder(JWSAlgorithm.RS256, JWSAlgorithm.RS512);
}

private RSAPublicKey key() throws InvalidKeySpecException {
byte[] decoded = Base64.getDecoder().decode(VERIFY_KEY.getBytes());
EncodedKeySpec spec = new X509EncodedKeySpec(decoded);
Expand Down
Loading