Skip to content

Commit ae75db2

Browse files
Benjamin Faalrstoyanchev
authored andcommitted
Add allowedOriginPatterns to SockJS config
See gh-26108
1 parent 4cc8312 commit ae75db2

File tree

6 files changed

+126
-8
lines changed

6 files changed

+126
-8
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/SockJsServiceRegistration.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ public class SockJsServiceRegistration {
7373

7474
private final List<String> allowedOrigins = new ArrayList<>();
7575

76+
private final List<String> allowedOriginPatterns = new ArrayList<>();
77+
7678
@Nullable
7779
private Boolean suppressCors;
7880

@@ -232,6 +234,18 @@ protected SockJsServiceRegistration setAllowedOrigins(String... allowedOrigins)
232234
return this;
233235
}
234236

237+
/**
238+
* Configure allowed {@code Origin} pattern header values.
239+
* @since 5.3.2
240+
*/
241+
protected SockJsServiceRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
242+
this.allowedOriginPatterns.clear();
243+
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
244+
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
245+
}
246+
return this;
247+
}
248+
235249
/**
236250
* This option can be used to disable automatic addition of CORS headers for
237251
* SockJS requests.
@@ -284,6 +298,7 @@ protected SockJsService getSockJsService() {
284298
service.setSuppressCors(this.suppressCors);
285299
}
286300
service.setAllowedOrigins(this.allowedOrigins);
301+
service.setAllowedOriginPatterns(this.allowedOriginPatterns);
287302

288303
if (this.messageCodec != null) {
289304
service.setMessageCodec(this.messageCodec);

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/StompWebSocketEndpointRegistration.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,11 @@ public interface StompWebSocketEndpointRegistration {
6161
*/
6262
StompWebSocketEndpointRegistration setAllowedOrigins(String... origins);
6363

64+
/**
65+
* Configure allowed {@code Origin} header values.
66+
*
67+
* @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List)
68+
*/
69+
StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... originPatterns);
70+
6471
}

spring-websocket/src/main/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistration.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ public class WebMvcStompWebSocketEndpointRegistration implements StompWebSocketE
5858

5959
private final List<String> allowedOrigins = new ArrayList<>();
6060

61+
private final List<String> allowedOriginPatterns = new ArrayList<>();
62+
6163
@Nullable
6264
private SockJsServiceRegistration registration;
6365

@@ -97,6 +99,15 @@ public StompWebSocketEndpointRegistration setAllowedOrigins(String... allowedOri
9799
return this;
98100
}
99101

102+
@Override
103+
public StompWebSocketEndpointRegistration setAllowedOriginPatterns(String... allowedOriginPatterns) {
104+
this.allowedOriginPatterns.clear();
105+
if (!ObjectUtils.isEmpty(allowedOriginPatterns)) {
106+
this.allowedOriginPatterns.addAll(Arrays.asList(allowedOriginPatterns));
107+
}
108+
return this;
109+
}
110+
100111
@Override
101112
public SockJsServiceRegistration withSockJS() {
102113
this.registration = new SockJsServiceRegistration();
@@ -112,13 +123,22 @@ public SockJsServiceRegistration withSockJS() {
112123
if (!this.allowedOrigins.isEmpty()) {
113124
this.registration.setAllowedOrigins(StringUtils.toStringArray(this.allowedOrigins));
114125
}
126+
if (!this.allowedOriginPatterns.isEmpty()) {
127+
this.registration.setAllowedOriginPatterns(StringUtils.toStringArray(this.allowedOriginPatterns));
128+
}
115129
return this.registration;
116130
}
117131

118132
protected HandshakeInterceptor[] getInterceptors() {
119133
List<HandshakeInterceptor> interceptors = new ArrayList<>(this.interceptors.size() + 1);
120134
interceptors.addAll(this.interceptors);
121-
interceptors.add(new OriginHandshakeInterceptor(this.allowedOrigins));
135+
OriginHandshakeInterceptor originHandshakeInterceptor = new OriginHandshakeInterceptor(this.allowedOrigins);
136+
interceptors.add(originHandshakeInterceptor);
137+
138+
if (!ObjectUtils.isEmpty(this.allowedOriginPatterns)) {
139+
originHandshakeInterceptor.setAllowedOriginPatterns(this.allowedOriginPatterns);
140+
}
141+
122142
return interceptors.toArray(new HandshakeInterceptor[0]);
123143
}
124144

spring-websocket/src/main/java/org/springframework/web/socket/server/support/OriginHandshakeInterceptor.java

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
package org.springframework.web.socket.server.support;
1818

19+
import java.util.ArrayList;
1920
import java.util.Collection;
2021
import java.util.Collections;
21-
import java.util.LinkedHashSet;
22+
import java.util.HashSet;
23+
import java.util.List;
2224
import java.util.Map;
23-
import java.util.Set;
2425

2526
import org.apache.commons.logging.Log;
2627
import org.apache.commons.logging.LogFactory;
@@ -30,6 +31,7 @@
3031
import org.springframework.http.server.ServerHttpResponse;
3132
import org.springframework.lang.Nullable;
3233
import org.springframework.util.Assert;
34+
import org.springframework.web.cors.CorsConfiguration;
3335
import org.springframework.web.socket.WebSocketHandler;
3436
import org.springframework.web.socket.server.HandshakeInterceptor;
3537
import org.springframework.web.util.WebUtils;
@@ -45,7 +47,7 @@ public class OriginHandshakeInterceptor implements HandshakeInterceptor {
4547

4648
protected final Log logger = LogFactory.getLog(getClass());
4749

48-
private final Set<String> allowedOrigins = new LinkedHashSet<>();
50+
private final CorsConfiguration corsConfiguration = new CorsConfiguration();
4951

5052

5153
/**
@@ -74,8 +76,7 @@ public OriginHandshakeInterceptor(Collection<String> allowedOrigins) {
7476
*/
7577
public void setAllowedOrigins(Collection<String> allowedOrigins) {
7678
Assert.notNull(allowedOrigins, "Allowed origins Collection must not be null");
77-
this.allowedOrigins.clear();
78-
this.allowedOrigins.addAll(allowedOrigins);
79+
this.corsConfiguration.setAllowedOrigins(new ArrayList<>(allowedOrigins));
7980
}
8081

8182
/**
@@ -84,15 +85,41 @@ public void setAllowedOrigins(Collection<String> allowedOrigins) {
8485
* @see #setAllowedOrigins
8586
*/
8687
public Collection<String> getAllowedOrigins() {
87-
return Collections.unmodifiableSet(this.allowedOrigins);
88+
if (this.corsConfiguration.getAllowedOrigins() == null) {
89+
return Collections.emptyList();
90+
}
91+
return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOrigins()));
92+
}
93+
94+
/**
95+
* Configure allowed {@code Origin} pattern header values.
96+
*
97+
* @see CorsConfiguration#setAllowedOriginPatterns(List)
98+
*/
99+
public void setAllowedOriginPatterns(Collection<String> allowedOriginPatterns) {
100+
Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null");
101+
this.corsConfiguration.setAllowedOriginPatterns(new ArrayList<>(allowedOriginPatterns));
102+
}
103+
104+
/**
105+
* Return the allowed {@code Origin} pattern header values.
106+
*
107+
* @since 5.3.2
108+
* @see CorsConfiguration#getAllowedOriginPatterns()
109+
*/
110+
public Collection<String> getAllowedOriginPatterns() {
111+
if (this.corsConfiguration.getAllowedOriginPatterns() == null) {
112+
return Collections.emptyList();
113+
}
114+
return Collections.unmodifiableSet(new HashSet<>(this.corsConfiguration.getAllowedOriginPatterns()));
88115
}
89116

90117

91118
@Override
92119
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
93120
WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
94121

95-
if (!WebUtils.isSameOrigin(request) && !WebUtils.isValidOrigin(request, this.allowedOrigins)) {
122+
if (!WebUtils.isSameOrigin(request) && this.corsConfiguration.checkOrigin(request.getHeaders().getOrigin()) == null) {
96123
response.setStatusCode(HttpStatus.FORBIDDEN);
97124
if (logger.isDebugEnabled()) {
98125
logger.debug("Handshake request rejected, Origin header value " +

spring-websocket/src/main/java/org/springframework/web/socket/sockjs/support/AbstractSockJsService.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ public abstract class AbstractSockJsService implements SockJsService, CorsConfig
9999

100100
protected final Set<String> allowedOrigins = new LinkedHashSet<>();
101101

102+
protected final Set<String> allowedOriginPatterns = new LinkedHashSet<>();
103+
102104
private final SockJsRequestHandler infoHandler = new InfoHandler();
103105

104106
private final SockJsRequestHandler iframeHandler = new IframeHandler();
@@ -319,6 +321,17 @@ public void setAllowedOrigins(Collection<String> allowedOrigins) {
319321
this.allowedOrigins.addAll(allowedOrigins);
320322
}
321323

324+
/**
325+
* Configure allowed {@code Origin} header values.
326+
*
327+
* @see org.springframework.web.cors.CorsConfiguration#setAllowedOriginPatterns(java.util.List)
328+
*/
329+
public void setAllowedOriginPatterns(Collection<String> allowedOriginPatterns) {
330+
Assert.notNull(allowedOriginPatterns, "Allowed origin patterns Collection must not be null");
331+
this.allowedOriginPatterns.clear();
332+
this.allowedOriginPatterns.addAll(allowedOriginPatterns);
333+
}
334+
322335
/**
323336
* Return configure allowed {@code Origin} header values.
324337
* @since 4.1.2
@@ -328,6 +341,15 @@ public Collection<String> getAllowedOrigins() {
328341
return Collections.unmodifiableSet(this.allowedOrigins);
329342
}
330343

344+
/**
345+
* Return configure allowed {@code Origin} pattern header values.
346+
* @since 5.3.2
347+
* @see #setAllowedOriginPatterns
348+
*/
349+
public Collection<String> getAllowedOriginPatterns() {
350+
return Collections.unmodifiableSet(this.allowedOriginPatterns);
351+
}
352+
331353

332354
/**
333355
* This method determines the SockJS path and handles SockJS static URLs.
@@ -498,6 +520,7 @@ public CorsConfiguration getCorsConfiguration(HttpServletRequest request) {
498520
if (!this.suppressCors && (request.getHeader(HttpHeaders.ORIGIN) != null)) {
499521
CorsConfiguration config = new CorsConfiguration();
500522
config.setAllowedOrigins(new ArrayList<>(this.allowedOrigins));
523+
config.setAllowedOriginPatterns(new ArrayList<>(this.allowedOriginPatterns));
501524
config.addAllowedMethod("*");
502525
config.setAllowCredentials(true);
503526
config.setMaxAge(ONE_YEAR);

spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebMvcStompWebSocketEndpointRegistrationTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,32 @@ public void allowedOriginsWithSockJsService() {
135135
assertThat(sockJsService.shouldSuppressCors()).isFalse();
136136
}
137137

138+
@Test
139+
public void allowedOriginPatterns() {
140+
WebMvcStompWebSocketEndpointRegistration registration =
141+
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
142+
143+
String origin = "https://*.mydomain.com";
144+
registration.setAllowedOriginPatterns(origin).withSockJS();
145+
146+
MultiValueMap<HttpRequestHandler, String> mappings = registration.getMappings();
147+
assertThat(mappings.size()).isEqualTo(1);
148+
SockJsHttpRequestHandler requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
149+
assertThat(requestHandler.getSockJsService()).isNotNull();
150+
DefaultSockJsService sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
151+
assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue();
152+
153+
registration =
154+
new WebMvcStompWebSocketEndpointRegistration(new String[] {"/foo"}, this.handler, this.scheduler);
155+
registration.withSockJS().setAllowedOriginPatterns(origin);
156+
mappings = registration.getMappings();
157+
assertThat(mappings.size()).isEqualTo(1);
158+
requestHandler = (SockJsHttpRequestHandler)mappings.entrySet().iterator().next().getKey();
159+
assertThat(requestHandler.getSockJsService()).isNotNull();
160+
sockJsService = (DefaultSockJsService)requestHandler.getSockJsService();
161+
assertThat(sockJsService.getAllowedOriginPatterns().contains(origin)).isTrue();
162+
}
163+
138164
@Test // SPR-12283
139165
public void disableCorsWithSockJsService() {
140166
WebMvcStompWebSocketEndpointRegistration registration =

0 commit comments

Comments
 (0)