Skip to content

Commit 11805b6

Browse files
committed
ServerEndpointExporter can initialize itself based on a late-provided ServletContext as well (for Boot)
Also allows for direct setting of a ServerContainer and for custom triggering of endpoint registration. Issue: SPR-12109
1 parent 60e58a2 commit 11805b6

File tree

2 files changed

+134
-72
lines changed

2 files changed

+134
-72
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/server/standard/ServerEndpointExporter.java

+78-64
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,21 @@
1616

1717
package org.springframework.web.socket.server.standard;
1818

19-
import java.lang.reflect.Method;
20-
import java.util.ArrayList;
2119
import java.util.Arrays;
20+
import java.util.LinkedHashSet;
2221
import java.util.List;
23-
import java.util.Map;
22+
import java.util.Set;
23+
import javax.servlet.ServletContext;
2424
import javax.websocket.DeploymentException;
2525
import javax.websocket.server.ServerContainer;
2626
import javax.websocket.server.ServerEndpoint;
2727
import javax.websocket.server.ServerEndpointConfig;
2828

29-
import org.apache.commons.logging.Log;
30-
import org.apache.commons.logging.LogFactory;
31-
32-
import org.springframework.beans.BeansException;
3329
import org.springframework.beans.factory.InitializingBean;
3430
import org.springframework.beans.factory.config.BeanPostProcessor;
3531
import org.springframework.context.ApplicationContext;
36-
import org.springframework.context.ApplicationContextAware;
3732
import org.springframework.util.Assert;
38-
import org.springframework.util.ClassUtils;
39-
import org.springframework.util.ReflectionUtils;
33+
import org.springframework.web.context.support.WebApplicationObjectSupport;
4034

4135
/**
4236
* Detects beans of type {@link javax.websocket.server.ServerEndpointConfig} and registers
@@ -50,24 +44,36 @@
5044
* done with the help of the {@code <absolute-ordering>} element in web.xml.
5145
*
5246
* @author Rossen Stoyanchev
47+
* @author Juergen Hoeller
5348
* @since 4.0
5449
* @see ServerEndpointRegistration
5550
* @see SpringConfigurator
5651
* @see ServletServerContainerFactoryBean
5752
*/
58-
public class ServerEndpointExporter implements InitializingBean, BeanPostProcessor, ApplicationContextAware {
59-
60-
private static final Log logger = LogFactory.getLog(ServerEndpointExporter.class);
53+
public class ServerEndpointExporter extends WebApplicationObjectSupport implements BeanPostProcessor, InitializingBean {
6154

55+
private ServerContainer serverContainer;
6256

63-
private final List<Class<?>> annotatedEndpointClasses = new ArrayList<Class<?>>();
57+
private List<Class<?>> annotatedEndpointClasses;
6458

65-
private final List<Class<?>> annotatedEndpointBeanTypes = new ArrayList<Class<?>>();
59+
private Set<Class<?>> annotatedEndpointBeanTypes;
6660

67-
private ApplicationContext applicationContext;
6861

69-
private ServerContainer serverContainer;
62+
/**
63+
* Set the JSR-356 {@link ServerContainer} to use for endpoint registration.
64+
* If not set, the container is going to be retrieved via the {@code ServletContext}.
65+
* @since 4.1
66+
*/
67+
public void setServerContainer(ServerContainer serverContainer) {
68+
this.serverContainer = serverContainer;
69+
}
7070

71+
/**
72+
* Return the JSR-356 {@link ServerContainer} to use for endpoint registration.
73+
*/
74+
protected ServerContainer getServerContainer() {
75+
return this.serverContainer;
76+
}
7177

7278
/**
7379
* Explicitly list annotated endpoint types that should be registered on startup. This
@@ -76,84 +82,92 @@ public class ServerEndpointExporter implements InitializingBean, BeanPostProcess
7682
* @param annotatedEndpointClasses {@link ServerEndpoint}-annotated types
7783
*/
7884
public void setAnnotatedEndpointClasses(Class<?>... annotatedEndpointClasses) {
79-
this.annotatedEndpointClasses.clear();
80-
this.annotatedEndpointClasses.addAll(Arrays.asList(annotatedEndpointClasses));
85+
this.annotatedEndpointClasses = Arrays.asList(annotatedEndpointClasses);
8186
}
8287

8388
@Override
84-
public void setApplicationContext(ApplicationContext applicationContext) {
85-
this.applicationContext = applicationContext;
86-
this.serverContainer = getServerContainer();
87-
Map<String, Object> beans = applicationContext.getBeansWithAnnotation(ServerEndpoint.class);
88-
for (String beanName : beans.keySet()) {
89-
Class<?> beanType = applicationContext.getType(beanName);
89+
protected void initApplicationContext(ApplicationContext context) {
90+
// Initializes ServletContext given a WebApplicationContext
91+
super.initApplicationContext(context);
92+
93+
// Retrieve beans which are annotated with @ServerEndpoint
94+
this.annotatedEndpointBeanTypes = new LinkedHashSet<Class<?>>();
95+
String[] beanNames = context.getBeanNamesForAnnotation(ServerEndpoint.class);
96+
for (String beanName : beanNames) {
97+
Class<?> beanType = context.getType(beanName);
9098
if (logger.isInfoEnabled()) {
9199
logger.info("Detected @ServerEndpoint bean '" + beanName + "', registering it as an endpoint by type");
92100
}
93101
this.annotatedEndpointBeanTypes.add(beanType);
94102
}
95103
}
96104

97-
protected ServerContainer getServerContainer() {
98-
Class<?> servletContextClass;
99-
try {
100-
servletContextClass = ClassUtils.forName("javax.servlet.ServletContext", getClass().getClassLoader());
105+
@Override
106+
protected void initServletContext(ServletContext servletContext) {
107+
if (this.serverContainer == null) {
108+
this.serverContainer =
109+
(ServerContainer) servletContext.getAttribute("javax.websocket.server.ServerContainer");
110+
}
111+
}
112+
113+
114+
@Override
115+
public void afterPropertiesSet() {
116+
Assert.state(getServerContainer() != null, "javax.websocket.server.ServerContainer not available");
117+
registerEndpoints();
118+
}
119+
120+
/**
121+
* Actually register the endpoints. Called by {@link #afterPropertiesSet()}.
122+
* @since 4.1
123+
*/
124+
protected void registerEndpoints() {
125+
Set<Class<?>> endpointClasses = new LinkedHashSet<Class<?>>();
126+
if (this.annotatedEndpointClasses != null) {
127+
endpointClasses.addAll(this.annotatedEndpointClasses);
128+
}
129+
if (this.annotatedEndpointBeanTypes != null) {
130+
endpointClasses.addAll(this.annotatedEndpointBeanTypes);
101131
}
102-
catch (Throwable ex) {
103-
return null;
132+
for (Class<?> endpointClass : endpointClasses) {
133+
registerEndpoint(endpointClass);
104134
}
135+
}
105136

137+
private void registerEndpoint(Class<?> endpointClass) {
106138
try {
107-
Method getter = ReflectionUtils.findMethod(this.applicationContext.getClass(), "getServletContext");
108-
Object servletContext = getter.invoke(this.applicationContext);
109-
Method attrMethod = ReflectionUtils.findMethod(servletContextClass, "getAttribute", String.class);
110-
return (ServerContainer) attrMethod.invoke(servletContext, "javax.websocket.server.ServerContainer");
139+
if (logger.isInfoEnabled()) {
140+
logger.info("Registering @ServerEndpoint type: " + endpointClass);
141+
}
142+
getServerContainer().addEndpoint(endpointClass);
111143
}
112-
catch (Exception ex) {
113-
throw new IllegalStateException(
114-
"Failed to get javax.websocket.server.ServerContainer via ServletContext attribute", ex);
144+
catch (DeploymentException ex) {
145+
throw new IllegalStateException("Failed to register @ServerEndpoint type " + endpointClass, ex);
115146
}
116147
}
117148

118-
@Override
119-
public void afterPropertiesSet() throws Exception {
120-
Assert.state(this.serverContainer != null, "javax.websocket.server.ServerContainer not available");
121-
122-
List<Class<?>> allClasses = new ArrayList<Class<?>>(this.annotatedEndpointClasses);
123-
allClasses.addAll(this.annotatedEndpointBeanTypes);
124149

125-
for (Class<?> clazz : allClasses) {
126-
try {
127-
logger.info("Registering @ServerEndpoint type " + clazz);
128-
this.serverContainer.addEndpoint(clazz);
129-
}
130-
catch (DeploymentException e) {
131-
throw new IllegalStateException("Failed to register @ServerEndpoint type " + clazz, e);
132-
}
133-
}
150+
@Override
151+
public Object postProcessBeforeInitialization(Object bean, String beanName) {
152+
return bean;
134153
}
135154

136155
@Override
137-
public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
156+
public Object postProcessAfterInitialization(Object bean, String beanName) {
138157
if (bean instanceof ServerEndpointConfig) {
139-
ServerEndpointConfig sec = (ServerEndpointConfig) bean;
158+
ServerEndpointConfig endpointConfig = (ServerEndpointConfig) bean;
140159
try {
141160
if (logger.isInfoEnabled()) {
142161
logger.info("Registering bean '" + beanName +
143-
"' as javax.websocket.Endpoint under path " + sec.getPath());
162+
"' as javax.websocket.Endpoint under path " + endpointConfig.getPath());
144163
}
145-
getServerContainer().addEndpoint(sec);
164+
getServerContainer().addEndpoint(endpointConfig);
146165
}
147-
catch (DeploymentException e) {
148-
throw new IllegalStateException("Failed to deploy Endpoint bean " + bean, e);
166+
catch (DeploymentException ex) {
167+
throw new IllegalStateException("Failed to deploy Endpoint bean with name '" + bean + "'", ex);
149168
}
150169
}
151170
return bean;
152171
}
153172

154-
@Override
155-
public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
156-
return bean;
157-
}
158-
159173
}

spring-websocket/src/test/java/org/springframework/web/socket/server/standard/ServerEndpointExporterTests.java

+56-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2013 the original author or authors.
2+
* Copyright 2002-2014 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@
1616

1717
package org.springframework.web.socket.server.standard;
1818

19+
import javax.servlet.ServletContext;
1920
import javax.websocket.Endpoint;
2021
import javax.websocket.EndpointConfig;
2122
import javax.websocket.Session;
@@ -24,6 +25,7 @@
2425

2526
import org.junit.Before;
2627
import org.junit.Test;
28+
2729
import org.springframework.context.annotation.Bean;
2830
import org.springframework.context.annotation.Configuration;
2931
import org.springframework.mock.web.test.MockServletContext;
@@ -35,37 +37,59 @@
3537
* Test fixture for {@link ServerEndpointExporter}.
3638
*
3739
* @author Rossen Stoyanchev
40+
* @author Juergen Hoeller
3841
*/
3942
public class ServerEndpointExporterTests {
4043

4144
private ServerContainer serverContainer;
4245

43-
private ServerEndpointExporter exporter;
46+
private ServletContext servletContext;
4447

4548
private AnnotationConfigWebApplicationContext webAppContext;
4649

50+
private ServerEndpointExporter exporter;
51+
4752

4853
@Before
4954
public void setup() {
5055
this.serverContainer = mock(ServerContainer.class);
5156

52-
MockServletContext servletContext = new MockServletContext();
53-
servletContext.setAttribute("javax.websocket.server.ServerContainer", this.serverContainer);
57+
this.servletContext = new MockServletContext();
58+
this.servletContext.setAttribute("javax.websocket.server.ServerContainer", this.serverContainer);
5459

5560
this.webAppContext = new AnnotationConfigWebApplicationContext();
5661
this.webAppContext.register(Config.class);
57-
this.webAppContext.setServletContext(servletContext);
62+
this.webAppContext.setServletContext(this.servletContext);
5863
this.webAppContext.refresh();
5964

6065
this.exporter = new ServerEndpointExporter();
61-
this.exporter.setApplicationContext(this.webAppContext);
6266
}
6367

6468

6569
@Test
66-
public void addAnnotatedEndpointBean() throws Exception {
67-
70+
public void addAnnotatedEndpointBeans() throws Exception {
6871
this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class);
72+
this.exporter.setApplicationContext(this.webAppContext);
73+
this.exporter.afterPropertiesSet();
74+
75+
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class);
76+
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class);
77+
}
78+
79+
@Test
80+
public void addAnnotatedEndpointBeansWithServletContextOnly() throws Exception {
81+
this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class);
82+
this.exporter.setServletContext(this.servletContext);
83+
this.exporter.afterPropertiesSet();
84+
85+
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class);
86+
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpointBean.class);
87+
}
88+
89+
@Test
90+
public void addAnnotatedEndpointBeansWithServerContainerOnly() throws Exception {
91+
this.exporter.setAnnotatedEndpointClasses(AnnotatedDummyEndpoint.class, AnnotatedDummyEndpointBean.class);
92+
this.exporter.setServerContainer(this.serverContainer);
6993
this.exporter.afterPropertiesSet();
7094

7195
verify(this.serverContainer).addEndpoint(AnnotatedDummyEndpoint.class);
@@ -74,10 +98,31 @@ public void addAnnotatedEndpointBean() throws Exception {
7498

7599
@Test
76100
public void addServerEndpointConfigBean() throws Exception {
101+
this.exporter.setApplicationContext(this.webAppContext);
102+
this.exporter.afterPropertiesSet();
103+
104+
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint());
105+
this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint");
106+
verify(this.serverContainer).addEndpoint(endpointRegistration);
107+
}
108+
109+
@Test
110+
public void addServerEndpointConfigBeanWithServletContextOnly() throws Exception {
111+
this.exporter.setServletContext(this.servletContext);
112+
this.exporter.afterPropertiesSet();
77113

78114
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint());
79115
this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint");
116+
verify(this.serverContainer).addEndpoint(endpointRegistration);
117+
}
118+
119+
@Test
120+
public void addServerEndpointConfigBeanWithServerContainerOnly() throws Exception {
121+
this.exporter.setServerContainer(this.serverContainer);
122+
this.exporter.afterPropertiesSet();
80123

124+
ServerEndpointRegistration endpointRegistration = new ServerEndpointRegistration("/dummy", new DummyEndpoint());
125+
this.exporter.postProcessAfterInitialization(endpointRegistration, "dummyEndpoint");
81126
verify(this.serverContainer).addEndpoint(endpointRegistration);
82127
}
83128

@@ -89,14 +134,17 @@ public void onOpen(Session session, EndpointConfig config) {
89134
}
90135
}
91136

137+
92138
@ServerEndpoint("/path")
93139
private static class AnnotatedDummyEndpoint {
94140
}
95141

142+
96143
@ServerEndpoint("/path")
97144
private static class AnnotatedDummyEndpointBean {
98145
}
99146

147+
100148
@Configuration
101149
static class Config {
102150

0 commit comments

Comments
 (0)