@@ -105,6 +105,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
105
105
106
106
private MessageHeaderInitializer headerInitializer ;
107
107
108
+ private final Map <String , Principal > stompAuthentications = new ConcurrentHashMap <String , Principal >();
109
+
108
110
private Boolean immutableMessageInterceptorPresent ;
109
111
110
112
private ApplicationEventPublisher eventPublisher ;
@@ -272,11 +274,10 @@ else if (webSocketMessage instanceof BinaryMessage) {
272
274
try {
273
275
StompHeaderAccessor headerAccessor =
274
276
MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
275
- Principal user = session .getPrincipal ();
276
277
277
278
headerAccessor .setSessionId (session .getId ());
278
279
headerAccessor .setSessionAttributes (session .getAttributes ());
279
- headerAccessor .setUser (user );
280
+ headerAccessor .setUser (getUser ( session ) );
280
281
headerAccessor .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , headerAccessor .getHeartbeat ());
281
282
if (!detectImmutableMessageInterceptor (outputChannel )) {
282
283
headerAccessor .setImmutable ();
@@ -286,7 +287,8 @@ else if (webSocketMessage instanceof BinaryMessage) {
286
287
logger .trace ("From client: " + headerAccessor .getShortLogMessage (message .getPayload ()));
287
288
}
288
289
289
- if (StompCommand .CONNECT .equals (headerAccessor .getCommand ())) {
290
+ boolean isConnect = StompCommand .CONNECT .equals (headerAccessor .getCommand ());
291
+ if (isConnect ) {
290
292
this .stats .incrementConnectCount ();
291
293
}
292
294
else if (StompCommand .DISCONNECT .equals (headerAccessor .getCommand ())) {
@@ -297,15 +299,23 @@ else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
297
299
SimpAttributesContextHolder .setAttributesFromMessage (message );
298
300
boolean sent = outputChannel .send (message );
299
301
300
- if (sent && this . eventPublisher != null ) {
301
- if (StompCommand . CONNECT . equals ( headerAccessor . getCommand ()) ) {
302
- publishEvent ( new SessionConnectEvent ( this , message , user ) );
303
- }
304
- else if ( StompCommand . SUBSCRIBE . equals ( headerAccessor . getCommand ())) {
305
- publishEvent ( new SessionSubscribeEvent ( this , message , user ));
302
+ if (sent ) {
303
+ if (isConnect ) {
304
+ Principal user = headerAccessor . getUser ( );
305
+ if ( user != null && user != session . getPrincipal ()) {
306
+ this . stompAuthentications . put ( session . getId (), user );
307
+ }
306
308
}
307
- else if (StompCommand .UNSUBSCRIBE .equals (headerAccessor .getCommand ())) {
308
- publishEvent (new SessionUnsubscribeEvent (this , message , user ));
309
+ if (this .eventPublisher != null ) {
310
+ if (isConnect ) {
311
+ publishEvent (new SessionConnectEvent (this , message , getUser (session )));
312
+ }
313
+ else if (StompCommand .SUBSCRIBE .equals (headerAccessor .getCommand ())) {
314
+ publishEvent (new SessionSubscribeEvent (this , message , getUser (session )));
315
+ }
316
+ else if (StompCommand .UNSUBSCRIBE .equals (headerAccessor .getCommand ())) {
317
+ publishEvent (new SessionUnsubscribeEvent (this , message , getUser (session )));
318
+ }
309
319
}
310
320
}
311
321
}
@@ -323,6 +333,11 @@ else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
323
333
}
324
334
}
325
335
336
+ private Principal getUser (WebSocketSession session ) {
337
+ Principal user = this .stompAuthentications .get (session .getId ());
338
+ return user != null ? user : session .getPrincipal ();
339
+ }
340
+
326
341
@ SuppressWarnings ("deprecation" )
327
342
private void handleError (WebSocketSession session , Throwable ex , Message <byte []> clientMessage ) {
328
343
if (getErrorHandler () == null ) {
@@ -425,7 +440,7 @@ else if (StompCommand.CONNECTED.equals(command)) {
425
440
try {
426
441
SimpAttributes simpAttributes = new SimpAttributes (session .getId (), session .getAttributes ());
427
442
SimpAttributesContextHolder .setAttributes (simpAttributes );
428
- Principal user = session . getPrincipal ( );
443
+ Principal user = getUser ( session );
429
444
publishEvent (new SessionConnectedEvent (this , (Message <byte []>) message , user ));
430
445
}
431
446
finally {
@@ -566,7 +581,7 @@ protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccess
566
581
private StompHeaderAccessor afterStompSessionConnected (Message <?> message , StompHeaderAccessor accessor ,
567
582
WebSocketSession session ) {
568
583
569
- Principal principal = session . getPrincipal ( );
584
+ Principal principal = getUser ( session );
570
585
if (principal != null ) {
571
586
accessor = toMutableAccessor (accessor , message );
572
587
accessor .setNativeHeader (CONNECTED_USER_HEADER , principal .getName ());
@@ -613,7 +628,7 @@ public void afterSessionStarted(WebSocketSession session, MessageChannel outputC
613
628
public void afterSessionEnded (WebSocketSession session , CloseStatus closeStatus , MessageChannel outputChannel ) {
614
629
this .decoders .remove (session .getId ());
615
630
616
- Principal principal = session . getPrincipal ( );
631
+ Principal principal = getUser ( session );
617
632
if (principal != null && this .userSessionRegistry != null ) {
618
633
String userName = getSessionRegistryUserName (principal );
619
634
this .userSessionRegistry .unregisterSessionId (userName , session .getId ());
@@ -624,12 +639,13 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
624
639
try {
625
640
SimpAttributesContextHolder .setAttributes (simpAttributes );
626
641
if (this .eventPublisher != null ) {
627
- Principal user = session . getPrincipal ( );
642
+ Principal user = getUser ( session );
628
643
publishEvent (new SessionDisconnectEvent (this , message , session .getId (), closeStatus , user ));
629
644
}
630
645
outputChannel .send (message );
631
646
}
632
647
finally {
648
+ this .stompAuthentications .remove (session .getId ());
633
649
SimpAttributesContextHolder .resetAttributes ();
634
650
simpAttributes .sessionCompleted ();
635
651
}
@@ -642,7 +658,7 @@ private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
642
658
}
643
659
headerAccessor .setSessionId (session .getId ());
644
660
headerAccessor .setSessionAttributes (session .getAttributes ());
645
- headerAccessor .setUser (session . getPrincipal ( ));
661
+ headerAccessor .setUser (getUser ( session ));
646
662
return MessageBuilder .createMessage (EMPTY_PAYLOAD , headerAccessor .getMessageHeaders ());
647
663
}
648
664
0 commit comments