1616
1717package org .springframework .security .oauth2 .jwt ;
1818
19- import com .nimbusds .jose .KeySourceException ;
20- import com .nimbusds .jose .jwk .JWK ;
21- import com .nimbusds .jose .jwk .JWKMatcher ;
22- import com .nimbusds .jose .jwk .JWKSelector ;
23- import com .nimbusds .jose .jwk .source .JWKSetParseException ;
24- import com .nimbusds .jose .jwk .source .JWKSetRetrievalException ;
25- import java .io .IOException ;
26- import java .net .MalformedURLException ;
27- import java .net .URL ;
19+ import java .net .URI ;
2820import java .security .interfaces .RSAPublicKey ;
2921import java .text .ParseException ;
3022import java .util .Arrays ;
3123import java .util .Collection ;
3224import java .util .Collections ;
3325import java .util .HashSet ;
3426import java .util .LinkedHashMap ;
35- import java .util .List ;
3627import java .util .Map ;
3728import java .util .Set ;
3829import java .util .concurrent .locks .ReentrantLock ;
4334
4435import com .nimbusds .jose .JOSEException ;
4536import com .nimbusds .jose .JWSAlgorithm ;
37+ import com .nimbusds .jose .KeySourceException ;
38+ import com .nimbusds .jose .RemoteKeySourceException ;
4639import com .nimbusds .jose .jwk .JWKSet ;
40+ import com .nimbusds .jose .jwk .source .JWKSetCacheRefreshEvaluator ;
41+ import com .nimbusds .jose .jwk .source .JWKSetSource ;
4742import com .nimbusds .jose .jwk .source .JWKSource ;
43+ import com .nimbusds .jose .jwk .source .JWKSourceBuilder ;
4844import com .nimbusds .jose .proc .JWSKeySelector ;
4945import com .nimbusds .jose .proc .JWSVerificationKeySelector ;
5046import com .nimbusds .jose .proc .SecurityContext ;
@@ -170,7 +166,7 @@ private Jwt createJwt(String token, JWT parsedJwt) {
170166 .build ();
171167 // @formatter:on
172168 }
173- catch (KeySourceException ex ) {
169+ catch (RemoteKeySourceException ex ) {
174170 this .logger .trace ("Failed to retrieve JWK set" , ex );
175171 if (ex .getCause () instanceof ParseException ) {
176172 throw new JwtException (String .format (DECODING_ERROR_MESSAGE_TEMPLATE , "Malformed Jwk set" ), ex );
@@ -383,7 +379,11 @@ JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSou
383379
384380 JWKSource <SecurityContext > jwkSource () {
385381 String jwkSetUri = this .jwkSetUri .apply (this .restOperations );
386- return new SpringJWKSource <>(this .restOperations , this .cache , toURL (jwkSetUri ), jwkSetUri );
382+ return JWKSourceBuilder .create (new SpringJWKSource <>(this .restOperations , this .cache , jwkSetUri ))
383+ .refreshAheadCache (false )
384+ .rateLimited (false )
385+ .cache (this .cache instanceof NoOpCache )
386+ .build ();
387387 }
388388
389389 JWTProcessor <SecurityContext > processor () {
@@ -405,16 +405,7 @@ public NimbusJwtDecoder build() {
405405 return new NimbusJwtDecoder (processor ());
406406 }
407407
408- private static URL toURL (String url ) {
409- try {
410- return new URL (url );
411- }
412- catch (MalformedURLException ex ) {
413- throw new IllegalArgumentException ("Invalid JWK Set URL \" " + url + "\" : " + ex .getMessage (), ex );
414- }
415- }
416-
417- private static final class SpringJWKSource <C extends SecurityContext > implements JWKSource <C > {
408+ private static final class SpringJWKSource <C extends SecurityContext > implements JWKSetSource <C > {
418409
419410 private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType ("application" , "jwk-set+json" );
420411
@@ -424,120 +415,63 @@ private static final class SpringJWKSource<C extends SecurityContext> implements
424415
425416 private final Cache cache ;
426417
427- private final URL url ;
428-
429418 private final String jwkSetUri ;
430419
431- private SpringJWKSource (RestOperations restOperations , Cache cache , URL url , String jwkSetUri ) {
420+ private JWKSet jwkSet ;
421+
422+ private SpringJWKSource (RestOperations restOperations , Cache cache , String jwkSetUri ) {
432423 Assert .notNull (restOperations , "restOperations cannot be null" );
433424 this .restOperations = restOperations ;
434425 this .cache = cache ;
435- this .url = url ;
436426 this .jwkSetUri = jwkSetUri ;
437- }
438-
439-
440- @ Override
441- public List <JWK > get (JWKSelector jwkSelector , SecurityContext context ) throws KeySourceException {
442- String cachedJwkSet = this .cache .get (this .jwkSetUri , String .class );
443- JWKSet jwkSet = null ;
444- if (cachedJwkSet != null ) {
445- jwkSet = parse (cachedJwkSet );
446- }
447- if (jwkSet == null ) {
448- if (reentrantLock .tryLock ()) {
449- try {
450- String cachedJwkSetAfterLock = this .cache .get (this .jwkSetUri , String .class );
451- if (cachedJwkSetAfterLock != null ) {
452- jwkSet = parse (cachedJwkSetAfterLock );
453- }
454- if (jwkSet == null ) {
455- try {
456- jwkSet = fetchJWKSet ();
457- } catch (IOException e ) {
458- throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
459- }
460- }
461- } finally {
462- reentrantLock .unlock ();
463- }
464- }
465- }
466- List <JWK > matches = jwkSelector .select (jwkSet );
467- if (!matches .isEmpty ()) {
468- return matches ;
469- }
470- String soughtKeyID = getFirstSpecifiedKeyID (jwkSelector .getMatcher ());
471- if (soughtKeyID == null ) {
472- return Collections .emptyList ();
473- }
474- if (jwkSet .getKeyByKeyId (soughtKeyID ) != null ) {
475- return Collections .emptyList ();
476- }
477-
478- if (reentrantLock .tryLock ()) {
427+ String jwks = this .cache .get (this .jwkSetUri , String .class );
428+ if (jwks != null ) {
479429 try {
480- String jwkSetUri = this .cache .get (this .jwkSetUri , String .class );
481- JWKSet cacheJwkSet = parse (jwkSetUri );
482- if (jwkSetUri != null && cacheJwkSet .toString ().equals (jwkSet .toString ())) {
483- try {
484- jwkSet = fetchJWKSet ();
485- } catch (IOException e ) {
486- throw new JWKSetRetrievalException ("Couldn't retrieve JWK set from URL: " + e .getMessage (), e );
487- }
488- } else if (jwkSetUri != null ) {
489- jwkSet = parse (jwkSetUri );
490- }
491- } finally {
492- reentrantLock .unlock ();
430+ this .jwkSet = JWKSet .parse (jwks );
431+ }
432+ catch (ParseException ignored ) {
433+ // Ignore invalid cache value
493434 }
494435 }
495- if (jwkSet == null ) {
496- return Collections .emptyList ();
497- }
498- return jwkSelector .select (jwkSet );
499436 }
500437
501- private JWKSet fetchJWKSet () throws IOException , KeySourceException {
438+ private String fetchJwks () throws Exception {
502439 HttpHeaders headers = new HttpHeaders ();
503440 headers .setAccept (Arrays .asList (MediaType .APPLICATION_JSON , APPLICATION_JWK_SET_JSON ));
504- ResponseEntity <String > response = getResponse (headers );
505- if (response .getStatusCode ().value () != 200 ) {
506- throw new IOException (response .toString ());
507- }
508- try {
509- String jwkSet = response .getBody ();
510- this .cache .put (this .jwkSetUri , jwkSet );
511- return JWKSet .parse (jwkSet );
512- } catch (ParseException e ) {
513- throw new JWKSetParseException ("Unable to parse JWK set" , e );
514- }
441+ RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , URI .create (this .jwkSetUri ));
442+ ResponseEntity <String > response = this .restOperations .exchange (request , String .class );
443+ String jwks = response .getBody ();
444+ this .jwkSet = JWKSet .parse (jwks );
445+ return jwks ;
515446 }
516447
517- private ResponseEntity <String > getResponse (HttpHeaders headers ) throws IOException {
448+ @ Override
449+ public JWKSet getJWKSet (JWKSetCacheRefreshEvaluator refreshEvaluator , long currentTime , C context )
450+ throws KeySourceException {
518451 try {
519- RequestEntity <Void > request = new RequestEntity <>(headers , HttpMethod .GET , this .url .toURI ());
520- return this .restOperations .exchange (request , String .class );
521- } catch (Exception ex ) {
522- throw new IOException (ex );
452+ this .reentrantLock .lock ();
453+ if (refreshEvaluator .requiresRefresh (this .jwkSet )) {
454+ this .cache .invalidate ();
455+ }
456+ this .cache .get (this .jwkSetUri , this ::fetchJwks );
457+ return this .jwkSet ;
523458 }
524- }
525-
526- private JWKSet parse ( String cachedJwkSet ) {
527- JWKSet jwkSet = null ;
528- try {
529- jwkSet = JWKSet . parse ( cachedJwkSet );
530- } catch ( ParseException ignored ) {
531- // Ignore invalid cache value
459+ catch ( Cache . ValueRetrievalException ex ) {
460+ if ( ex . getCause () instanceof RemoteKeySourceException keys ) {
461+ throw keys ;
462+ }
463+ throw new RemoteKeySourceException ( ex . getCause (). getMessage (), ex . getCause ());
464+ }
465+ finally {
466+ this . reentrantLock . unlock ();
532467 }
533- return jwkSet ;
534468 }
535469
536- private String getFirstSpecifiedKeyID (JWKMatcher jwkMatcher ) {
537- Set <String > keyIDs = jwkMatcher .getKeyIDs ();
538- return (keyIDs == null || keyIDs .isEmpty ()) ?
539- null : keyIDs .stream ().filter (id -> id != null ).findFirst ().orElse (null );
470+ @ Override
471+ public void close () {
472+
540473 }
474+
541475 }
542476
543477 }
0 commit comments