Skip to content

Commit cc33bfa

Browse files
committed
Support receipt on DISCONNECT with simple broker
Issue: SPR-14568
1 parent 91387a5 commit cc33bfa

File tree

5 files changed

+81
-11
lines changed

5 files changed

+81
-11
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/SimpMessageHeaderAccessor.java

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
6565

6666
public static final String CONNECT_MESSAGE_HEADER = "simpConnectMessage";
6767

68+
public static final String DISCONNECT_MESSAGE_HEADER = "simpDisconnectMessage";
69+
6870
public static final String HEART_BEAT_HEADER = "simpHeartbeat";
6971

7072

spring-messaging/src/main/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandler.java

+6-3
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ else if (SimpMessageType.CONNECT.equals(messageType)) {
250250
}
251251
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
252252
logMessage(message);
253-
handleDisconnect(sessionId, user);
253+
handleDisconnect(sessionId, user, message);
254254
}
255255
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
256256
logMessage(message);
@@ -285,12 +285,15 @@ private void initHeaders(SimpMessageHeaderAccessor accessor) {
285285
}
286286
}
287287

288-
private void handleDisconnect(String sessionId, Principal user) {
288+
private void handleDisconnect(String sessionId, Principal user, Message<?> origMessage) {
289289
this.sessions.remove(sessionId);
290290
this.subscriptionRegistry.unregisterAllSubscriptions(sessionId);
291291
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
292292
accessor.setSessionId(sessionId);
293293
accessor.setUser(user);
294+
if (origMessage != null) {
295+
accessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, origMessage);
296+
}
294297
initHeaders(accessor);
295298
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
296299
getClientOutboundChannel().send(message);
@@ -407,7 +410,7 @@ public void run() {
407410
long now = System.currentTimeMillis();
408411
for (SessionInfo info : sessions.values()) {
409412
if (info.getReadInterval() > 0 && (now - info.getLastReadTime()) > info.getReadInterval()) {
410-
handleDisconnect(info.getSessiondId(), info.getUser());
413+
handleDisconnect(info.getSessiondId(), info.getUser(), null);
411414
}
412415
if (info.getWriteInterval() > 0 && (now - info.getLastWriteTime()) > info.getWriteInterval()) {
413416
SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);

spring-messaging/src/test/java/org/springframework/messaging/simp/broker/SimpleBrokerMessageHandlerTests.java

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2015 the original author or authors.
2+
* Copyright 2002-2016 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,9 +16,6 @@
1616

1717
package org.springframework.messaging.simp.broker;
1818

19-
import static org.junit.Assert.*;
20-
import static org.mockito.Mockito.*;
21-
2219
import java.security.Principal;
2320
import java.util.Collections;
2421
import java.util.List;
@@ -41,6 +38,21 @@
4138
import org.springframework.messaging.support.MessageBuilder;
4239
import org.springframework.scheduling.TaskScheduler;
4340

41+
import static org.junit.Assert.assertArrayEquals;
42+
import static org.junit.Assert.assertEquals;
43+
import static org.junit.Assert.assertNotNull;
44+
import static org.junit.Assert.assertNull;
45+
import static org.junit.Assert.assertSame;
46+
import static org.junit.Assert.assertTrue;
47+
import static org.mockito.Mockito.any;
48+
import static org.mockito.Mockito.atLeast;
49+
import static org.mockito.Mockito.eq;
50+
import static org.mockito.Mockito.mock;
51+
import static org.mockito.Mockito.times;
52+
import static org.mockito.Mockito.verify;
53+
import static org.mockito.Mockito.verifyNoMoreInteractions;
54+
import static org.mockito.Mockito.when;
55+
4456
/**
4557
* Unit tests for SimpleBrokerMessageHandler.
4658
*
@@ -72,7 +84,7 @@ public class SimpleBrokerMessageHandlerTests {
7284
public void setup() {
7385
MockitoAnnotations.initMocks(this);
7486
this.messageHandler = new SimpleBrokerMessageHandler(this.clientInboundChannel,
75-
this.clientOutboundChannel, this.brokerChannel, Collections.<String>emptyList());
87+
this.clientOutboundChannel, this.brokerChannel, Collections.emptyList());
7688
}
7789

7890

@@ -130,6 +142,7 @@ public void subcribeDisconnectPublish() {
130142

131143
Message<?> captured = this.messageCaptor.getAllValues().get(0);
132144
assertEquals(SimpMessageType.DISCONNECT_ACK, SimpMessageHeaderAccessor.getMessageType(captured.getHeaders()));
145+
assertSame(message, captured.getHeaders().get(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER));
133146
assertEquals(sess1, SimpMessageHeaderAccessor.getSessionId(captured.getHeaders()));
134147
assertEquals("joe", SimpMessageHeaderAccessor.getUser(captured.getHeaders()).getName());
135148

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

+20-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.context.ApplicationEventPublisherAware;
3535
import org.springframework.messaging.Message;
3636
import org.springframework.messaging.MessageChannel;
37+
import org.springframework.messaging.MessageHeaders;
3738
import org.springframework.messaging.simp.SimpAttributes;
3839
import org.springframework.messaging.simp.SimpAttributesContextHolder;
3940
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@@ -479,8 +480,15 @@ else if (accessor instanceof SimpMessageHeaderAccessor) {
479480
stompAccessor = convertConnectAcktoStompConnected(stompAccessor);
480481
}
481482
else if (SimpMessageType.DISCONNECT_ACK.equals(messageType)) {
482-
stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
483-
stompAccessor.setMessage("Session closed.");
483+
String receipt = getDisconnectReceipt(stompAccessor);
484+
if (receipt != null) {
485+
stompAccessor = StompHeaderAccessor.create(StompCommand.RECEIPT);
486+
stompAccessor.setReceiptId(receipt);
487+
}
488+
else {
489+
stompAccessor = StompHeaderAccessor.create(StompCommand.ERROR);
490+
stompAccessor.setMessage("Session closed.");
491+
}
484492
}
485493
else if (SimpMessageType.HEARTBEAT.equals(messageType)) {
486494
stompAccessor = StompHeaderAccessor.createForHeartbeat();
@@ -533,6 +541,16 @@ else if (!acceptVersions.isEmpty()) {
533541
return connectedHeaders;
534542
}
535543

544+
private String getDisconnectReceipt(SimpMessageHeaderAccessor simpHeaders) {
545+
String name = StompHeaderAccessor.DISCONNECT_MESSAGE_HEADER;
546+
Message<?> message = (Message<?>) simpHeaders.getHeader(name);
547+
if (message != null) {
548+
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
549+
return accessor.getReceipt();
550+
}
551+
return null;
552+
}
553+
536554
protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccessor, Message<?> message) {
537555
return (headerAccessor.isMutable() ? headerAccessor : StompHeaderAccessor.wrap(message));
538556
}

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2015 the original author or authors.
2+
* Copyright 2002-2016 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -169,6 +169,40 @@ public void handleMessageToClientWithSimpConnectAckDefaultHeartBeat() {
169169
"user-name:joe\n" + "\n" + "\u0000", actual.getPayload());
170170
}
171171

172+
@Test
173+
public void handleMessageToClientWithSimpDisconnectAck() {
174+
175+
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
176+
Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
177+
178+
SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
179+
ackAccessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, connectMessage);
180+
Message<byte[]> ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders());
181+
this.protocolHandler.handleMessageToClient(this.session, ackMessage);
182+
183+
assertEquals(1, this.session.getSentMessages().size());
184+
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
185+
assertEquals("ERROR\n" + "message:Session closed.\n" + "content-length:0\n" +
186+
"\n\u0000", actual.getPayload());
187+
}
188+
189+
@Test
190+
public void handleMessageToClientWithSimpDisconnectAckAndReceipt() {
191+
192+
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.DISCONNECT);
193+
accessor.setReceipt("message-123");
194+
Message<?> connectMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
195+
196+
SimpMessageHeaderAccessor ackAccessor = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK);
197+
ackAccessor.setHeader(SimpMessageHeaderAccessor.DISCONNECT_MESSAGE_HEADER, connectMessage);
198+
Message<byte[]> ackMessage = MessageBuilder.createMessage(EMPTY_PAYLOAD, ackAccessor.getMessageHeaders());
199+
this.protocolHandler.handleMessageToClient(this.session, ackMessage);
200+
201+
assertEquals(1, this.session.getSentMessages().size());
202+
TextMessage actual = (TextMessage) this.session.getSentMessages().get(0);
203+
assertEquals("RECEIPT\n" + "receipt-id:message-123\n" + "\n\u0000", actual.getPayload());
204+
}
205+
172206
@Test
173207
public void handleMessageToClientWithSimpHeartbeat() {
174208

0 commit comments

Comments
 (0)