Skip to content

Commit e16d753

Browse files
committed
Allow athentication at the STOMP level
This commit makes it possible for a ChannelInterceptor to override the user header in a Spring Message that contains a STOMP CONNECT frame. After the message is sent, the updated user header is observed and saved to be associated with session thereafter. Issue: SPR-14690
1 parent b14d189 commit e16d753

File tree

3 files changed

+213
-46
lines changed

3 files changed

+213
-46
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

+32-16
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
105105

106106
private MessageHeaderInitializer headerInitializer;
107107

108+
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<String, Principal>();
109+
108110
private Boolean immutableMessageInterceptorPresent;
109111

110112
private ApplicationEventPublisher eventPublisher;
@@ -272,11 +274,10 @@ else if (webSocketMessage instanceof BinaryMessage) {
272274
try {
273275
StompHeaderAccessor headerAccessor =
274276
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
275-
Principal user = session.getPrincipal();
276277

277278
headerAccessor.setSessionId(session.getId());
278279
headerAccessor.setSessionAttributes(session.getAttributes());
279-
headerAccessor.setUser(user);
280+
headerAccessor.setUser(getUser(session));
280281
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
281282
if (!detectImmutableMessageInterceptor(outputChannel)) {
282283
headerAccessor.setImmutable();
@@ -286,7 +287,8 @@ else if (webSocketMessage instanceof BinaryMessage) {
286287
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
287288
}
288289

289-
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
290+
boolean isConnect = StompCommand.CONNECT.equals(headerAccessor.getCommand());
291+
if (isConnect) {
290292
this.stats.incrementConnectCount();
291293
}
292294
else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
@@ -297,15 +299,23 @@ else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
297299
SimpAttributesContextHolder.setAttributesFromMessage(message);
298300
boolean sent = outputChannel.send(message);
299301

300-
if (sent && this.eventPublisher != null) {
301-
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
302-
publishEvent(new SessionConnectEvent(this, message, user));
303-
}
304-
else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
305-
publishEvent(new SessionSubscribeEvent(this, message, user));
302+
if (sent) {
303+
if (isConnect) {
304+
Principal user = headerAccessor.getUser();
305+
if (user != null && user != session.getPrincipal()) {
306+
this.stompAuthentications.put(session.getId(), user);
307+
}
306308
}
307-
else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
308-
publishEvent(new SessionUnsubscribeEvent(this, message, user));
309+
if (this.eventPublisher != null) {
310+
if (isConnect) {
311+
publishEvent(new SessionConnectEvent(this, message, getUser(session)));
312+
}
313+
else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
314+
publishEvent(new SessionSubscribeEvent(this, message, getUser(session)));
315+
}
316+
else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
317+
publishEvent(new SessionUnsubscribeEvent(this, message, getUser(session)));
318+
}
309319
}
310320
}
311321
}
@@ -323,6 +333,11 @@ else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
323333
}
324334
}
325335

336+
private Principal getUser(WebSocketSession session) {
337+
Principal user = this.stompAuthentications.get(session.getId());
338+
return user != null ? user : session.getPrincipal();
339+
}
340+
326341
@SuppressWarnings("deprecation")
327342
private void handleError(WebSocketSession session, Throwable ex, Message<byte[]> clientMessage) {
328343
if (getErrorHandler() == null) {
@@ -425,7 +440,7 @@ else if (StompCommand.CONNECTED.equals(command)) {
425440
try {
426441
SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
427442
SimpAttributesContextHolder.setAttributes(simpAttributes);
428-
Principal user = session.getPrincipal();
443+
Principal user = getUser(session);
429444
publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message, user));
430445
}
431446
finally {
@@ -566,7 +581,7 @@ protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccess
566581
private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor accessor,
567582
WebSocketSession session) {
568583

569-
Principal principal = session.getPrincipal();
584+
Principal principal = getUser(session);
570585
if (principal != null) {
571586
accessor = toMutableAccessor(accessor, message);
572587
accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
@@ -613,7 +628,7 @@ public void afterSessionStarted(WebSocketSession session, MessageChannel outputC
613628
public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus, MessageChannel outputChannel) {
614629
this.decoders.remove(session.getId());
615630

616-
Principal principal = session.getPrincipal();
631+
Principal principal = getUser(session);
617632
if (principal != null && this.userSessionRegistry != null) {
618633
String userName = getSessionRegistryUserName(principal);
619634
this.userSessionRegistry.unregisterSessionId(userName, session.getId());
@@ -624,12 +639,13 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
624639
try {
625640
SimpAttributesContextHolder.setAttributes(simpAttributes);
626641
if (this.eventPublisher != null) {
627-
Principal user = session.getPrincipal();
642+
Principal user = getUser(session);
628643
publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus, user));
629644
}
630645
outputChannel.send(message);
631646
}
632647
finally {
648+
this.stompAuthentications.remove(session.getId());
633649
SimpAttributesContextHolder.resetAttributes();
634650
simpAttributes.sessionCompleted();
635651
}
@@ -642,7 +658,7 @@ private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
642658
}
643659
headerAccessor.setSessionId(session.getId());
644660
headerAccessor.setSessionAttributes(session.getAttributes());
645-
headerAccessor.setUser(session.getPrincipal());
661+
headerAccessor.setUser(getUser(session));
646662
return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
647663
}
648664

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

+60-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.web.socket.messaging;
1818

1919
import java.io.IOException;
20+
import java.security.Principal;
2021
import java.util.ArrayList;
2122
import java.util.Arrays;
2223
import java.util.Collections;
@@ -34,6 +35,8 @@
3435
import org.springframework.context.PayloadApplicationEvent;
3536
import org.springframework.messaging.Message;
3637
import org.springframework.messaging.MessageChannel;
38+
import org.springframework.messaging.MessageHandler;
39+
import org.springframework.messaging.MessagingException;
3740
import org.springframework.messaging.simp.SimpAttributes;
3841
import org.springframework.messaging.simp.SimpAttributesContextHolder;
3942
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@@ -68,7 +71,7 @@
6871
*/
6972
public class StompSubProtocolHandlerTests {
7073

71-
public static final byte[] EMPTY_PAYLOAD = new byte[0];
74+
private static final byte[] EMPTY_PAYLOAD = new byte[0];
7275

7376
private StompSubProtocolHandler protocolHandler;
7477

@@ -210,22 +213,26 @@ public void handleMessageToClientWithSimpHeartbeat() {
210213
public void handleMessageToClientWithHeartbeatSuppressingSockJsHeartbeat() throws IOException {
211214

212215
SockJsSession sockJsSession = Mockito.mock(SockJsSession.class);
216+
when(sockJsSession.getId()).thenReturn("s1");
213217
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
214218
accessor.setHeartbeat(0, 10);
215219
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
216220
this.protocolHandler.handleMessageToClient(sockJsSession, message);
217221

222+
verify(sockJsSession).getId();
218223
verify(sockJsSession).getPrincipal();
219224
verify(sockJsSession).disableHeartbeat();
220225
verify(sockJsSession).sendMessage(any(WebSocketMessage.class));
221226
verifyNoMoreInteractions(sockJsSession);
222227

223228
sockJsSession = Mockito.mock(SockJsSession.class);
229+
when(sockJsSession.getId()).thenReturn("s1");
224230
accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
225231
accessor.setHeartbeat(0, 0);
226232
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
227233
this.protocolHandler.handleMessageToClient(sockJsSession, message);
228234

235+
verify(sockJsSession).getId();
229236
verify(sockJsSession).getPrincipal();
230237
verify(sockJsSession).sendMessage(any(WebSocketMessage.class));
231238
verifyNoMoreInteractions(sockJsSession);
@@ -352,6 +359,28 @@ public Message<?> preSend(Message<?> message, MessageChannel channel) {
352359
assertFalse(mutable.get());
353360
}
354361

362+
@Test // SPR-14690
363+
public void handleMessageFromClientWithTokenAuthentication() {
364+
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
365+
channel.addInterceptor(new AuthenticationInterceptor("[email protected]"));
366+
channel.addInterceptor(new ImmutableMessageChannelInterceptor());
367+
368+
TestMessageHandler messageHandler = new TestMessageHandler();
369+
channel.subscribe(messageHandler);
370+
371+
StompSubProtocolHandler handler = new StompSubProtocolHandler();
372+
handler.afterSessionStarted(this.session, channel);
373+
374+
TextMessage wsMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
375+
handler.handleMessageFromClient(this.session, wsMessage, channel);
376+
377+
assertEquals(1, messageHandler.getMessages().size());
378+
Message<?> message = messageHandler.getMessages().get(0);
379+
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
380+
assertNotNull(user);
381+
assertEquals("[email protected]", user.getName());
382+
}
383+
355384
@Test
356385
public void handleMessageFromClientWithInvalidStompCommand() {
357386

@@ -504,4 +533,34 @@ public void publishEvent(Object event) {
504533
}
505534
}
506535

536+
private static class TestMessageHandler implements MessageHandler {
537+
538+
private final List<Message> messages = new ArrayList<>();
539+
540+
public List<Message> getMessages() {
541+
return this.messages;
542+
}
543+
544+
@Override
545+
public void handleMessage(Message<?> message) throws MessagingException {
546+
this.messages.add(message);
547+
}
548+
}
549+
550+
private static class AuthenticationInterceptor extends ChannelInterceptorAdapter {
551+
552+
private final String name;
553+
554+
555+
public AuthenticationInterceptor(String name) {
556+
this.name = name;
557+
}
558+
559+
@Override
560+
public Message<?> preSend(Message<?> message, MessageChannel channel) {
561+
TestPrincipal user = new TestPrincipal(name);
562+
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class).setUser(user);
563+
return message;
564+
}
565+
}
507566
}

0 commit comments

Comments
 (0)