25
25
import java .util .LinkedHashSet ;
26
26
import java .util .List ;
27
27
import java .util .Random ;
28
- import java .util .Set ;
29
28
import java .util .concurrent .TimeUnit ;
30
29
31
30
import javax .servlet .http .HttpServletRequest ;
@@ -97,9 +96,7 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
97
96
98
97
private boolean suppressCors = false ;
99
98
100
- protected final Set <String > allowedOrigins = new LinkedHashSet <>();
101
-
102
- protected final Set <String > allowedOriginPatterns = new LinkedHashSet <>();
99
+ protected final CorsConfiguration corsConfiguration ;
103
100
104
101
private final SockJsRequestHandler infoHandler = new InfoHandler ();
105
102
@@ -109,6 +106,18 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
109
106
public AbstractSockJsService (TaskScheduler scheduler ) {
110
107
Assert .notNull (scheduler , "TaskScheduler must not be null" );
111
108
this .taskScheduler = scheduler ;
109
+ this .corsConfiguration = initCorsConfiguration ();
110
+ }
111
+
112
+ private static CorsConfiguration initCorsConfiguration () {
113
+ CorsConfiguration config = new CorsConfiguration ();
114
+ config .addAllowedMethod ("*" );
115
+ config .setAllowedOrigins (Collections .emptyList ());
116
+ config .setAllowedOriginPatterns (Collections .emptyList ());
117
+ config .setAllowCredentials (true );
118
+ config .setMaxAge (ONE_YEAR );
119
+ config .addAllowedHeader ("*" );
120
+ return config ;
112
121
}
113
122
114
123
@@ -317,10 +326,18 @@ public boolean shouldSuppressCors() {
317
326
*/
318
327
public void setAllowedOrigins (Collection <String > allowedOrigins ) {
319
328
Assert .notNull (allowedOrigins , "Allowed origins Collection must not be null" );
320
- this .allowedOrigins .clear ();
321
- this .allowedOrigins .addAll (allowedOrigins );
329
+ this .corsConfiguration .setAllowedOrigins (new ArrayList <>(allowedOrigins ));
322
330
}
323
331
332
+ /**
333
+ * Return configure allowed {@code Origin} header values.
334
+ * @since 4.1.2
335
+ * @see #setAllowedOrigins
336
+ */
337
+ @ SuppressWarnings ("ConstantConditions" )
338
+ public Collection <String > getAllowedOrigins () {
339
+ return this .corsConfiguration .getAllowedOrigins ();
340
+ }
324
341
/**
325
342
* A variant of {@link #setAllowedOrigins(Collection)} that accepts flexible
326
343
* domain patterns, e.g. {@code "https://*.domain1.com"}. Furthermore it
@@ -331,26 +348,17 @@ public void setAllowedOrigins(Collection<String> allowedOrigins) {
331
348
*/
332
349
public void setAllowedOriginPatterns (Collection <String > allowedOriginPatterns ) {
333
350
Assert .notNull (allowedOriginPatterns , "Allowed origin patterns Collection must not be null" );
334
- this .allowedOriginPatterns .clear ();
335
- this .allowedOriginPatterns .addAll (allowedOriginPatterns );
336
- }
337
-
338
- /**
339
- * Return configure allowed {@code Origin} header values.
340
- * @since 4.1.2
341
- * @see #setAllowedOrigins
342
- */
343
- public Collection <String > getAllowedOrigins () {
344
- return Collections .unmodifiableSet (this .allowedOrigins );
351
+ this .corsConfiguration .setAllowedOriginPatterns (new ArrayList <>(allowedOriginPatterns ));
345
352
}
346
353
347
354
/**
348
- * Return configure allowed {@code Origin} pattern header values .
355
+ * Return {@link #setAllowedOriginPatterns(Collection) configured} origin patterns .
349
356
* @since 5.3.2
350
357
* @see #setAllowedOriginPatterns
351
358
*/
359
+ @ SuppressWarnings ("ConstantConditions" )
352
360
public Collection <String > getAllowedOriginPatterns () {
353
- return Collections . unmodifiableSet ( this .allowedOriginPatterns );
361
+ return this .corsConfiguration . getAllowedOriginPatterns ( );
354
362
}
355
363
356
364
@@ -396,15 +404,16 @@ else if (sockJsPath.equals("/info")) {
396
404
}
397
405
398
406
else if (sockJsPath .matches ("/iframe[0-9-.a-z_]*.html" )) {
399
- if (!this .allowedOrigins .isEmpty () && !this .allowedOrigins .contains ("*" )) {
407
+ if (!getAllowedOrigins ().isEmpty () && !getAllowedOrigins ().contains ("*" ) ||
408
+ !getAllowedOriginPatterns ().isEmpty ()) {
400
409
if (requestInfo != null ) {
401
410
logger .debug ("Iframe support is disabled when an origin check is required. " +
402
411
"Ignoring transport request: " + requestInfo );
403
412
}
404
413
response .setStatusCode (HttpStatus .NOT_FOUND );
405
414
return ;
406
415
}
407
- if (this . allowedOrigins .isEmpty ()) {
416
+ if (getAllowedOrigins () .isEmpty ()) {
408
417
response .getHeaders ().add (XFRAME_OPTIONS_HEADER , "SAMEORIGIN" );
409
418
}
410
419
if (requestInfo != null ) {
@@ -506,7 +515,7 @@ protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse resp
506
515
return true ;
507
516
}
508
517
509
- if (! WebUtils . isValidOrigin (request , this . allowedOrigins ) ) {
518
+ if (this . corsConfiguration . checkOrigin (request . getHeaders (). getOrigin ()) == null ) {
510
519
if (logger .isWarnEnabled ()) {
511
520
logger .warn ("Origin header value '" + request .getHeaders ().getOrigin () + "' not allowed." );
512
521
}
@@ -521,14 +530,7 @@ protected boolean checkOrigin(ServerHttpRequest request, ServerHttpResponse resp
521
530
@ Nullable
522
531
public CorsConfiguration getCorsConfiguration (HttpServletRequest request ) {
523
532
if (!this .suppressCors && (request .getHeader (HttpHeaders .ORIGIN ) != null )) {
524
- CorsConfiguration config = new CorsConfiguration ();
525
- config .setAllowedOrigins (new ArrayList <>(this .allowedOrigins ));
526
- config .setAllowedOriginPatterns (new ArrayList <>(this .allowedOriginPatterns ));
527
- config .addAllowedMethod ("*" );
528
- config .setAllowCredentials (true );
529
- config .setMaxAge (ONE_YEAR );
530
- config .addAllowedHeader ("*" );
531
- return config ;
533
+ return this .corsConfiguration ;
532
534
}
533
535
return null ;
534
536
}
0 commit comments