diff --git a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java index 2713ee96b2d..df0239ebfec 100644 --- a/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java +++ b/oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,7 @@ package org.springframework.security.oauth2.jwt; -import java.io.IOException; -import java.net.MalformedURLException; -import java.net.URL; +import java.net.URI; import java.security.interfaces.RSAPublicKey; import java.text.ParseException; import java.util.Arrays; @@ -28,6 +26,7 @@ import java.util.LinkedHashMap; import java.util.Map; import java.util.Set; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import java.util.function.Function; @@ -35,17 +34,17 @@ import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.KeySourceException; import com.nimbusds.jose.RemoteKeySourceException; import com.nimbusds.jose.jwk.JWKSet; -import com.nimbusds.jose.jwk.source.JWKSetCache; +import com.nimbusds.jose.jwk.source.JWKSetCacheRefreshEvaluator; +import com.nimbusds.jose.jwk.source.JWKSetSource; import com.nimbusds.jose.jwk.source.JWKSource; -import com.nimbusds.jose.jwk.source.RemoteJWKSet; +import com.nimbusds.jose.jwk.source.JWKSourceBuilder; import com.nimbusds.jose.proc.JWSKeySelector; import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jose.proc.SingleKeyJWSKeySelector; -import com.nimbusds.jose.util.Resource; -import com.nimbusds.jose.util.ResourceRetriever; import com.nimbusds.jwt.JWT; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.JWTParser; @@ -57,6 +56,7 @@ import org.apache.commons.logging.LogFactory; import org.springframework.cache.Cache; +import org.springframework.cache.support.NoOpCache; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -80,6 +80,7 @@ * @author Josh Cummings * @author Joe Grandja * @author Mykyta Bezverkhyi + * @author Daeho Kwon * @since 5.2 */ public final class NimbusJwtDecoder implements JwtDecoder { @@ -273,7 +274,7 @@ public static final class JwkSetUriJwtDecoderBuilder { private RestOperations restOperations = new RestTemplate(); - private Cache cache; + private Cache cache = new NoOpCache("default"); private Consumer> jwtProcessorCustomizer; @@ -376,18 +377,17 @@ JWSKeySelector jwsKeySelector(JWKSource jwkSou return new JWSVerificationKeySelector<>(jwsAlgorithms, jwkSource); } - JWKSource jwkSource(ResourceRetriever jwkSetRetriever, String jwkSetUri) { - if (this.cache == null) { - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever); - } - JWKSetCache jwkSetCache = new SpringJWKSetCache(jwkSetUri, this.cache); - return new RemoteJWKSet<>(toURL(jwkSetUri), jwkSetRetriever, jwkSetCache); + JWKSource jwkSource() { + String jwkSetUri = this.jwkSetUri.apply(this.restOperations); + return JWKSourceBuilder.create(new SpringJWKSource<>(this.restOperations, this.cache, jwkSetUri)) + .refreshAheadCache(false) + .rateLimited(false) + .cache(this.cache instanceof NoOpCache) + .build(); } JWTProcessor processor() { - ResourceRetriever jwkSetRetriever = new RestOperationsResourceRetriever(this.restOperations); - String jwkSetUri = this.jwkSetUri.apply(this.restOperations); - JWKSource jwkSource = jwkSource(jwkSetRetriever, jwkSetUri); + JWKSource jwkSource = jwkSource(); ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); jwtProcessor.setJWSKeySelector(jwsKeySelector(jwkSource)); // Spring Security validates the claim set independent from Nimbus @@ -405,34 +405,29 @@ public NimbusJwtDecoder build() { return new NimbusJwtDecoder(processor()); } - private static URL toURL(String url) { - try { - return new URL(url); - } - catch (MalformedURLException ex) { - throw new IllegalArgumentException("Invalid JWK Set URL \"" + url + "\" : " + ex.getMessage(), ex); - } - } + private static final class SpringJWKSource implements JWKSetSource { - private static final class SpringJWKSetCache implements JWKSetCache { + private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); - private final String jwkSetUri; + private final ReentrantLock reentrantLock = new ReentrantLock(); + + private final RestOperations restOperations; private final Cache cache; + private final String jwkSetUri; + private JWKSet jwkSet; - SpringJWKSetCache(String jwkSetUri, Cache cache) { - this.jwkSetUri = jwkSetUri; + private SpringJWKSource(RestOperations restOperations, Cache cache, String jwkSetUri) { + Assert.notNull(restOperations, "restOperations cannot be null"); + this.restOperations = restOperations; this.cache = cache; - this.updateJwkSetFromCache(); - } - - private void updateJwkSetFromCache() { - String cachedJwkSet = this.cache.get(this.jwkSetUri, String.class); - if (cachedJwkSet != null) { + this.jwkSetUri = jwkSetUri; + String jwks = this.cache.get(this.jwkSetUri, String.class); + if (jwks != null) { try { - this.jwkSet = JWKSet.parse(cachedJwkSet); + this.jwkSet = JWKSet.parse(jwks); } catch (ParseException ignored) { // Ignore invalid cache value @@ -440,58 +435,43 @@ private void updateJwkSetFromCache() { } } - // Note: Only called from inside a synchronized block in RemoteJWKSet. - @Override - public void put(JWKSet jwkSet) { - this.jwkSet = jwkSet; - this.cache.put(this.jwkSetUri, jwkSet.toString(false)); - } - - @Override - public JWKSet get() { - return (!requiresRefresh()) ? this.jwkSet : null; - - } - - @Override - public boolean requiresRefresh() { - return this.cache.get(this.jwkSetUri) == null; - } - - } - - private static class RestOperationsResourceRetriever implements ResourceRetriever { - - private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json"); - - private final RestOperations restOperations; - - RestOperationsResourceRetriever(RestOperations restOperations) { - Assert.notNull(restOperations, "restOperations cannot be null"); - this.restOperations = restOperations; - } - - @Override - public Resource retrieveResource(URL url) throws IOException { + private String fetchJwks() throws Exception { HttpHeaders headers = new HttpHeaders(); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); - ResponseEntity response = getResponse(url, headers); - if (response.getStatusCode().value() != 200) { - throw new IOException(response.toString()); - } - return new Resource(response.getBody(), "UTF-8"); + RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, URI.create(this.jwkSetUri)); + ResponseEntity response = this.restOperations.exchange(request, String.class); + String jwks = response.getBody(); + this.jwkSet = JWKSet.parse(jwks); + return jwks; } - private ResponseEntity getResponse(URL url, HttpHeaders headers) throws IOException { + @Override + public JWKSet getJWKSet(JWKSetCacheRefreshEvaluator refreshEvaluator, long currentTime, C context) + throws KeySourceException { try { - RequestEntity request = new RequestEntity<>(headers, HttpMethod.GET, url.toURI()); - return this.restOperations.exchange(request, String.class); + this.reentrantLock.lock(); + if (refreshEvaluator.requiresRefresh(this.jwkSet)) { + this.cache.invalidate(); + } + this.cache.get(this.jwkSetUri, this::fetchJwks); + return this.jwkSet; } - catch (Exception ex) { - throw new IOException(ex); + catch (Cache.ValueRetrievalException ex) { + if (ex.getCause() instanceof RemoteKeySourceException keys) { + throw keys; + } + throw new RemoteKeySourceException(ex.getCause().getMessage(), ex.getCause()); + } + finally { + this.reentrantLock.unlock(); } } + @Override + public void close() { + + } + } } diff --git a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java index fb4535f240d..d638795df9a 100644 --- a/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java +++ b/oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -60,7 +60,6 @@ import org.springframework.cache.Cache; import org.springframework.cache.concurrent.ConcurrentMapCache; -import org.springframework.cache.support.SimpleValueWrapper; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.converter.Converter; import org.springframework.http.HttpStatus; @@ -702,9 +701,8 @@ public void decodeWhenCacheStoredThenAbleToRetrieveJwkSetFromCache() { @Test public void decodeWhenCacheThenRetrieveFromCache() throws Exception { RestOperations restOperations = mock(RestOperations.class); - Cache cache = mock(Cache.class); - given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET); - given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class)); + Cache cache = new ConcurrentMapCache("cache"); + cache.put(JWK_SET_URI, JWK_SET); // @formatter:off NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .cache(cache) @@ -712,9 +710,7 @@ public void decodeWhenCacheThenRetrieveFromCache() throws Exception { .build(); // @formatter:on jwtDecoder.decode(SIGNED_JWT); - verify(cache).get(eq(JWK_SET_URI), eq(String.class)); - verify(cache, times(2)).get(eq(JWK_SET_URI)); - verifyNoMoreInteractions(cache); + assertThat(cache.get(JWK_SET_URI, String.class)).isSameAs(JWK_SET); verifyNoInteractions(restOperations); } @@ -722,9 +718,8 @@ public void decodeWhenCacheThenRetrieveFromCache() throws Exception { @Test public void decodeWhenCacheAndUnknownKidShouldTriggerFetchOfJwkSet() throws JOSEException { RestOperations restOperations = mock(RestOperations.class); - Cache cache = mock(Cache.class); - given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET); - given(cache.get(eq(JWK_SET_URI))).willReturn(new SimpleValueWrapper(JWK_SET)); + Cache cache = new ConcurrentMapCache("cache"); + cache.put(JWK_SET_URI, JWK_SET); given(restOperations.exchange(any(RequestEntity.class), eq(String.class))) .willReturn(new ResponseEntity<>(NEW_KID_JWK_SET, HttpStatus.OK)); @@ -794,9 +789,8 @@ public void decodeWhenCacheIsConfiguredAndValueLoaderErrorsThenThrowsJwtExceptio @Test public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIgnored() { RestOperations restOperations = mock(RestOperations.class); - Cache cache = mock(Cache.class); - given(cache.get(eq(JWK_SET_URI), eq(String.class))).willReturn(JWK_SET); - given(cache.get(eq(JWK_SET_URI))).willReturn(mock(Cache.ValueWrapper.class)); + Cache cache = new ConcurrentMapCache("cache"); + cache.put(JWK_SET_URI, JWK_SET); // @formatter:off NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(JWK_SET_URI) .cache(cache) @@ -804,9 +798,7 @@ public void decodeWhenCacheIsConfiguredAndParseFailsOnCachedValueThenExceptionIg .build(); // @formatter:on jwtDecoder.decode(SIGNED_JWT); - verify(cache).get(eq(JWK_SET_URI), eq(String.class)); - verify(cache, times(2)).get(eq(JWK_SET_URI)); - verifyNoMoreInteractions(cache); + assertThat(cache.get(JWK_SET_URI, String.class)).isSameAs(JWK_SET); verifyNoInteractions(restOperations); }