18
18
import java .security .interfaces .RSAPublicKey ;
19
19
import java .time .Instant ;
20
20
import java .util .Collections ;
21
+ import java .util .HashMap ;
22
+ import java .util .HashSet ;
21
23
import java .util .LinkedHashMap ;
22
24
import java .util .Map ;
25
+ import java .util .Set ;
26
+ import java .util .function .Consumer ;
23
27
import java .util .function .Function ;
24
28
import javax .crypto .SecretKey ;
25
29
31
35
import com .nimbusds .jose .jwk .JWKMatcher ;
32
36
import com .nimbusds .jose .jwk .JWKSelector ;
33
37
import com .nimbusds .jose .jwk .source .JWKSecurityContextJWKSet ;
38
+ import com .nimbusds .jose .jwk .source .JWKSource ;
34
39
import com .nimbusds .jose .proc .BadJOSEException ;
35
40
import com .nimbusds .jose .proc .JWKSecurityContext ;
36
41
import com .nimbusds .jose .proc .JWSKeySelector ;
@@ -233,7 +238,7 @@ public static JwkSourceReactiveJwtDecoderBuilder withJwkSource(Function<SignedJW
233
238
*/
234
239
public static final class JwkSetUriReactiveJwtDecoderBuilder {
235
240
private final String jwkSetUri ;
236
- private JWSAlgorithm jwsAlgorithm = JWSAlgorithm . RS256 ;
241
+ private Set < SignatureAlgorithm > signatureAlgorithms = new HashSet <>() ;
237
242
private WebClient webClient = WebClient .create ();
238
243
239
244
private JwkSetUriReactiveJwtDecoderBuilder (String jwkSetUri ) {
@@ -242,15 +247,30 @@ private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
242
247
}
243
248
244
249
/**
245
- * Use the given signing
246
- * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>.
250
+ * Append the given signing
251
+ * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithm</a>
252
+ * to the set of algorithms to use.
247
253
*
248
254
* @param signatureAlgorithm the algorithm to use
249
255
* @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
250
256
*/
251
257
public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithm (SignatureAlgorithm signatureAlgorithm ) {
252
258
Assert .notNull (signatureAlgorithm , "sig cannot be null" );
253
- this .jwsAlgorithm = JWSAlgorithm .parse (signatureAlgorithm .getName ());
259
+ this .signatureAlgorithms .add (signatureAlgorithm );
260
+ return this ;
261
+ }
262
+
263
+ /**
264
+ * Configure the list of
265
+ * <a href="https://tools.ietf.org/html/rfc7515#section-4.1.1" target="_blank">algorithms</a>
266
+ * to use with the given {@link Consumer}.
267
+ *
268
+ * @param signatureAlgorithmsConsumer a {@link Consumer} for further configuring the algorithm list
269
+ * @return a {@link JwkSetUriReactiveJwtDecoderBuilder} for further configurations
270
+ */
271
+ public JwkSetUriReactiveJwtDecoderBuilder jwsAlgorithms (Consumer <Set <SignatureAlgorithm >> signatureAlgorithmsConsumer ) {
272
+ Assert .notNull (signatureAlgorithmsConsumer , "signatureAlgorithmsConsumer cannot be null" );
273
+ signatureAlgorithmsConsumer .accept (this .signatureAlgorithms );
254
274
return this ;
255
275
}
256
276
@@ -278,28 +298,53 @@ public NimbusReactiveJwtDecoder build() {
278
298
return new NimbusReactiveJwtDecoder (processor ());
279
299
}
280
300
301
+ JWSKeySelector <JWKSecurityContext > jwsKeySelector (JWKSource <JWKSecurityContext > jwkSource ) {
302
+ if (this .signatureAlgorithms .isEmpty ()) {
303
+ return new JWSVerificationKeySelector <>(JWSAlgorithm .RS256 , jwkSource );
304
+ } else if (this .signatureAlgorithms .size () == 1 ) {
305
+ JWSAlgorithm jwsAlgorithm = JWSAlgorithm .parse (this .signatureAlgorithms .iterator ().next ().getName ());
306
+ return new JWSVerificationKeySelector <>(jwsAlgorithm , jwkSource );
307
+ } else {
308
+ Map <JWSAlgorithm , JWSKeySelector <JWKSecurityContext >> jwsKeySelectors = new HashMap <>();
309
+ for (SignatureAlgorithm signatureAlgorithm : this .signatureAlgorithms ) {
310
+ JWSAlgorithm jwsAlg = JWSAlgorithm .parse (signatureAlgorithm .getName ());
311
+ jwsKeySelectors .put (jwsAlg , new JWSVerificationKeySelector <>(jwsAlg , jwkSource ));
312
+ }
313
+ return new JWSAlgorithmMapJWSKeySelector <>(jwsKeySelectors );
314
+ }
315
+ }
316
+
281
317
Converter <JWT , Mono <JWTClaimsSet >> processor () {
282
318
JWKSecurityContextJWKSet jwkSource = new JWKSecurityContextJWKSet ();
283
-
284
- JWSKeySelector <JWKSecurityContext > jwsKeySelector =
285
- new JWSVerificationKeySelector <>(this .jwsAlgorithm , jwkSource );
286
319
DefaultJWTProcessor <JWKSecurityContext > jwtProcessor = new DefaultJWTProcessor <>();
320
+ JWSKeySelector <JWKSecurityContext > jwsKeySelector = jwsKeySelector (jwkSource );
287
321
jwtProcessor .setJWSKeySelector (jwsKeySelector );
288
322
jwtProcessor .setJWTClaimsSetVerifier ((claims , context ) -> {});
289
323
290
324
ReactiveRemoteJWKSource source = new ReactiveRemoteJWKSource (this .jwkSetUri );
291
325
source .setWebClient (this .webClient );
292
326
327
+ Set <JWSAlgorithm > expectedJwsAlgorithms = getExpectedJwsAlgorithms (jwsKeySelector );
293
328
return jwt -> {
294
- JWKSelector selector = createSelector (jwt .getHeader ());
329
+ JWKSelector selector = createSelector (expectedJwsAlgorithms , jwt .getHeader ());
295
330
return source .get (selector )
296
331
.onErrorMap (e -> new IllegalStateException ("Could not obtain the keys" , e ))
297
332
.map (jwkList -> createClaimsSet (jwtProcessor , jwt , new JWKSecurityContext (jwkList )));
298
333
};
299
334
}
300
335
301
- private JWKSelector createSelector (Header header ) {
302
- if (!this .jwsAlgorithm .equals (header .getAlgorithm ())) {
336
+ private Set <JWSAlgorithm > getExpectedJwsAlgorithms (JWSKeySelector <?> jwsKeySelector ) {
337
+ if (jwsKeySelector instanceof JWSVerificationKeySelector ) {
338
+ return Collections .singleton (((JWSVerificationKeySelector <?>) jwsKeySelector ).getExpectedJWSAlgorithm ());
339
+ }
340
+ if (jwsKeySelector instanceof JWSAlgorithmMapJWSKeySelector ) {
341
+ return ((JWSAlgorithmMapJWSKeySelector <?>) jwsKeySelector ).getExpectedJWSAlgorithms ();
342
+ }
343
+ throw new IllegalArgumentException ("Unsupported key selector type " + jwsKeySelector .getClass ());
344
+ }
345
+
346
+ private JWKSelector createSelector (Set <JWSAlgorithm > expectedJwsAlgorithms , Header header ) {
347
+ if (!expectedJwsAlgorithms .contains (header .getAlgorithm ())) {
303
348
throw new JwtException ("Unsupported algorithm of " + header .getAlgorithm ());
304
349
}
305
350
0 commit comments