Skip to content

Support symmetric key for JwtDecoder #6495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,30 @@
*/
package org.springframework.security.oauth2.client.oidc.authentication;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusJwtDecoder.withSecretKey;

/**
* A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder}
Expand All @@ -47,14 +54,45 @@
*/
public final class OidcIdTokenDecoderFactory implements JwtDecoderFactory<ClientRegistration> {
private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier";
private static Map<JwsAlgorithm, String> jcaAlgorithmMappings = new HashMap<JwsAlgorithm, String>() {
{
put(MacAlgorithm.HS256, "HmacSHA256");
put(MacAlgorithm.HS384, "HmacSHA384");
put(MacAlgorithm.HS512, "HmacSHA512");
}
};
private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = OidcIdTokenValidator::new;
private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256;

@Override
public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> {
if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) {
NimbusJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
}

private NimbusJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place of checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according to JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
// in the id_token_signed_response_alg parameter during Registration.

String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
Expand All @@ -64,12 +102,42 @@ public JwtDecoder createDecoder(ClientRegistration clientRegistration) {
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
NimbusJwtDecoder jwtDecoder = withJwkSetUri(jwkSetUri).build();
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
return withJwkSetUri(jwkSetUri).jwsAlgorithm(jwsAlgorithm).build();
} else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is multi-valued or
// if an azp value is present that is different than the aud value.

String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the client secret.",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(
clientSecret.getBytes(StandardCharsets.UTF_8), jcaAlgorithmMappings.get(jwsAlgorithm));
return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
}

OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured a valid JWS Algorithm: '" +
jwsAlgorithm + "'",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

/**
Expand All @@ -82,4 +150,17 @@ public final void setJwtValidatorFactory(Function<ClientRegistration, OAuth2Toke
Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
this.jwtValidatorFactory = jwtValidatorFactory;
}

/**
* Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* used for the signature or MAC on the {@link OidcIdToken ID Token}.
* The default resolves to {@link SignatureAlgorithm#RS256 RS256} for all {@link ClientRegistration clients}.
*
* @param jwsAlgorithmResolver the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* for a specific {@link ClientRegistration client}
*/
public final void setJwsAlgorithmResolver(Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver) {
Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null");
this.jwsAlgorithmResolver = jwsAlgorithmResolver;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,26 @@
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withJwkSetUri;
import static org.springframework.security.oauth2.jwt.NimbusReactiveJwtDecoder.withSecretKey;

/**
* A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder}
* used for {@link OidcIdToken} signature verification.
Expand All @@ -45,14 +54,45 @@
*/
public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
private static final String MISSING_SIGNATURE_VERIFIER_ERROR_CODE = "missing_signature_verifier";
private static Map<JwsAlgorithm, String> jcaAlgorithmMappings = new HashMap<JwsAlgorithm, String>() {
{
put(MacAlgorithm.HS256, "HmacSHA256");
put(MacAlgorithm.HS384, "HmacSHA384");
put(MacAlgorithm.HS512, "HmacSHA512");
}
};
private final Map<String, ReactiveJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
private Function<ClientRegistration, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = OidcIdTokenValidator::new;
private Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver = clientRegistration -> SignatureAlgorithm.RS256;

@Override
public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
return this.jwtDecoders.computeIfAbsent(clientRegistration.getRegistrationId(), key -> {
if (!StringUtils.hasText(clientRegistration.getProviderDetails().getJwkSetUri())) {
NimbusReactiveJwtDecoder jwtDecoder = buildDecoder(clientRegistration);
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
}

private NimbusReactiveJwtDecoder buildDecoder(ClientRegistration clientRegistration) {
JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(clientRegistration);
if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 6. If the ID Token is received via direct communication between the Client
// and the Token Endpoint (which it is in this flow),
// the TLS server validation MAY be used to validate the issuer in place of checking the token signature.
// The Client MUST validate the signature of all other ID Tokens according to JWS [JWS]
// using the algorithm specified in the JWT alg Header Parameter.
// The Client MUST use the keys provided by the Issuer.
//
// 7. The alg value SHOULD be the default of RS256 or the algorithm sent by the Client
// in the id_token_signed_response_alg parameter during Registration.

String jwkSetUri = clientRegistration.getProviderDetails().getJwkSetUri();
if (!StringUtils.hasText(jwkSetUri)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
Expand All @@ -62,12 +102,42 @@ public ReactiveJwtDecoder createDecoder(ClientRegistration clientRegistration) {
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
NimbusReactiveJwtDecoder jwtDecoder = new NimbusReactiveJwtDecoder(
clientRegistration.getProviderDetails().getJwkSetUri());
OAuth2TokenValidator<Jwt> jwtValidator = this.jwtValidatorFactory.apply(clientRegistration);
jwtDecoder.setJwtValidator(jwtValidator);
return jwtDecoder;
});
return withJwkSetUri(jwkSetUri).jwsAlgorithm(jwsAlgorithm).build();
} else if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
// https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
//
// 8. If the JWT alg Header Parameter uses a MAC based algorithm such as HS256, HS384, or HS512,
// the octets of the UTF-8 representation of the client_secret
// corresponding to the client_id contained in the aud (audience) Claim
// are used as the key to validate the signature.
// For MAC based algorithms, the behavior is unspecified if the aud is multi-valued or
// if an azp value is present that is different than the aud value.

String clientSecret = clientRegistration.getClientSecret();
if (!StringUtils.hasText(clientSecret)) {
OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured the client secret.",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
SecretKeySpec secretKeySpec = new SecretKeySpec(
clientSecret.getBytes(StandardCharsets.UTF_8), jcaAlgorithmMappings.get(jwsAlgorithm));
return withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
}

OAuth2Error oauth2Error = new OAuth2Error(
MISSING_SIGNATURE_VERIFIER_ERROR_CODE,
"Failed to find a Signature Verifier for Client Registration: '" +
clientRegistration.getRegistrationId() +
"'. Check to ensure you have configured a valid JWS Algorithm: '" +
jwsAlgorithm + "'",
null
);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}

/**
Expand All @@ -80,4 +150,17 @@ public final void setJwtValidatorFactory(Function<ClientRegistration, OAuth2Toke
Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
this.jwtValidatorFactory = jwtValidatorFactory;
}

/**
* Sets the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* used for the signature or MAC on the {@link OidcIdToken ID Token}.
* The default resolves to {@link SignatureAlgorithm#RS256 RS256} for all {@link ClientRegistration clients}.
*
* @param jwsAlgorithmResolver the resolver that provides the expected {@link JwsAlgorithm JWS algorithm}
* for a specific {@link ClientRegistration client}
*/
public final void setJwsAlgorithmResolver(Function<ClientRegistration, JwsAlgorithm> jwsAlgorithmResolver) {
Assert.notNull(jwsAlgorithmResolver, "jwsAlgorithmResolver cannot be null");
this.jwsAlgorithmResolver = jwsAlgorithmResolver;
}
}
Loading