diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java index 2de3e64a815..fb0468fa9bc 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,7 @@ import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -86,6 +87,14 @@ public final class NimbusJwtEncoder implements JwtEncoder { private final JWKSource jwkSource; + private Converter, JWK> jwkSelector = (jwks) -> { + throw new JwtEncodingException( + String.format( + "Failed to select a key since there are multiple for the signing algorithm [%s]; " + + "please specify a selector in NimbusJwsEncoder#setJwkSelector", + jwks.get(0).getAlgorithm())); + }; + /** * Constructs a {@code NimbusJwtEncoder} using the provided parameters. * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource} @@ -95,6 +104,21 @@ public NimbusJwtEncoder(JWKSource jwkSource) { this.jwkSource = jwkSource; } + /** + * Use this strategy to reduce the list of matching JWKs when there is more than one. + *

+ * For example, you can call {@code setJwkSelector(List::getFirst)} in order to have + * this encoder select the first match. + * + *

+ * By default, the class with throw an exception. + * @since 6.5 + */ + public void setJwkSelector(Converter, JWK> jwkSelector) { + Assert.notNull(jwkSelector, "jwkSelector cannot be null"); + this.jwkSelector = jwkSelector; + } + @Override public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException { Assert.notNull(parameters, "parameters cannot be null"); @@ -123,18 +147,14 @@ private JWK selectJwk(JwsHeader headers) { throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key -> " + ex.getMessage()), ex); } - - if (jwks.size() > 1) { - throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, - "Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'")); - } - if (jwks.isEmpty()) { throw new JwtEncodingException( String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key")); } - - return jwks.get(0); + if (jwks.size() == 1) { + return jwks.get(0); + } + return this.jwkSelector.convert(jwks); } private String serialize(JwsHeader headers, JwtClaimsSet claims, JWK jwk) { diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java index 412adbfd4d2..d0426a4533f 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jose/TestJwks.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,10 @@ public final class TestJwks { private TestJwks() { } + public static RSAKey.Builder rsa() { + return jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY); + } + public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) { // @formatter:off return new RSAKey.Builder(publicKey) diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java index e9825f0a359..ab17156eacb 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.Collections; import java.util.List; +import com.nimbusds.jose.JWSAlgorithm; import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.jwk.ECKey; import com.nimbusds.jose.jwk.JWK; @@ -39,6 +40,7 @@ import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import org.springframework.core.convert.converter.Converter; import org.springframework.security.oauth2.jose.TestJwks; import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; @@ -51,6 +53,8 @@ import static org.mockito.BDDMockito.willAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; /** * Tests for {@link NimbusJwtEncoder}. @@ -109,7 +113,7 @@ public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exce @Test public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception { - RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK; + RSAKey rsaJwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); this.jwkList.add(rsaJwk); this.jwkList.add(rsaJwk); @@ -118,7 +122,7 @@ public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws assertThatExceptionOfType(JwtEncodingException.class) .isThrownBy(() -> this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, jwtClaimsSet))) - .withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'"); + .withMessageContaining("Failed to select a key since there are multiple for the signing algorithm [RS256]"); } @Test @@ -291,6 +295,55 @@ public List get(JWKSelector jwkSelector, SecurityContext context) { assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID()); } + @Test + public void encodeWhenMultipleKeysThenJwkSelectorUsed() throws Exception { + JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of(jwk, jwk)); + Converter, JWK> selector = mock(Converter.class); + given(selector.convert(any())).willReturn(TestJwks.DEFAULT_RSA_JWK); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + jwtEncoder.encode(JwtEncoderParameters.from(claims)); + + verify(selector).convert(any()); + } + + @Test + public void encodeWhenSingleKeyThenJwkSelectorIsNotUsed() throws Exception { + JWK jwk = TestJwks.rsa().algorithm(JWSAlgorithm.RS256).build(); + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of(jwk)); + Converter, JWK> selector = mock(Converter.class); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + jwtEncoder.encode(JwtEncoderParameters.from(claims)); + + verifyNoInteractions(selector); + } + + @Test + public void encodeWhenNoKeysThenJwkSelectorIsNotUsed() throws Exception { + JWKSource jwkSource = mock(JWKSource.class); + given(jwkSource.get(any(), any())).willReturn(List.of()); + Converter, JWK> selector = mock(Converter.class); + + NimbusJwtEncoder jwtEncoder = new NimbusJwtEncoder(jwkSource); + jwtEncoder.setJwkSelector(selector); + + JwtClaimsSet claims = JwtClaimsSet.builder().subject("sub").build(); + assertThatExceptionOfType(JwtEncodingException.class) + .isThrownBy(() -> jwtEncoder.encode(JwtEncoderParameters.from(claims))); + + verifyNoInteractions(selector); + } + private static final class JwkListResultCaptor implements Answer> { private List result;