Skip to content

Commit 8b9503e

Browse files
authored
Acquire HubConnectionStateLock before Send/Invoke/Stream (#12078)
1 parent 4c07e1e commit 8b9503e

File tree

4 files changed

+105
-70
lines changed

4 files changed

+105
-70
lines changed

src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java

Lines changed: 94 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,15 @@ private void stopConnection(String errorMessage) {
537537
* @param args The arguments to be passed to the method.
538538
*/
539539
public void send(String method, Object... args) {
540-
if (hubConnectionState != HubConnectionState.CONNECTED) {
541-
throw new RuntimeException("The 'send' method cannot be called if the connection is not active.");
540+
hubConnectionStateLock.lock();
541+
try {
542+
if (hubConnectionState != HubConnectionState.CONNECTED) {
543+
throw new RuntimeException("The 'send' method cannot be called if the connection is not active.");
544+
}
545+
sendInvocationMessage(method, args);
546+
} finally {
547+
hubConnectionStateLock.unlock();
542548
}
543-
544-
sendInvocationMessage(method, args);
545549
}
546550

547551
private void sendInvocationMessage(String method, Object[] args) {
@@ -605,26 +609,31 @@ Object[] checkUploadStream(Object[] args, List<String> streamIds) {
605609
*/
606610
@SuppressWarnings("unchecked")
607611
public Completable invoke(String method, Object... args) {
608-
if (hubConnectionState != HubConnectionState.CONNECTED) {
609-
throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active.");
610-
}
612+
hubConnectionStateLock.lock();
613+
try {
614+
if (hubConnectionState != HubConnectionState.CONNECTED) {
615+
throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active.");
616+
}
611617

612-
String id = connectionState.getNextInvocationId();
618+
String id = connectionState.getNextInvocationId();
613619

614-
CompletableSubject subject = CompletableSubject.create();
615-
InvocationRequest irq = new InvocationRequest(null, id);
616-
connectionState.addInvocation(irq);
620+
CompletableSubject subject = CompletableSubject.create();
621+
InvocationRequest irq = new InvocationRequest(null, id);
622+
connectionState.addInvocation(irq);
617623

618-
Subject<Object> pendingCall = irq.getPendingCall();
624+
Subject<Object> pendingCall = irq.getPendingCall();
619625

620-
pendingCall.subscribe(result -> subject.onComplete(),
621-
error -> subject.onError(error),
622-
() -> subject.onComplete());
626+
pendingCall.subscribe(result -> subject.onComplete(),
627+
error -> subject.onError(error),
628+
() -> subject.onComplete());
623629

624-
// Make sure the actual send is after setting up the callbacks otherwise there is a race
625-
// where the map doesn't have the callbacks yet when the response is returned
626-
sendInvocationMessage(method, args, id, false);
627-
return subject;
630+
// Make sure the actual send is after setting up the callbacks otherwise there is a race
631+
// where the map doesn't have the callbacks yet when the response is returned
632+
sendInvocationMessage(method, args, id, false);
633+
return subject;
634+
} finally {
635+
hubConnectionStateLock.unlock();
636+
}
628637
}
629638

630639
/**
@@ -638,32 +647,37 @@ public Completable invoke(String method, Object... args) {
638647
*/
639648
@SuppressWarnings("unchecked")
640649
public <T> Single<T> invoke(Class<T> returnType, String method, Object... args) {
641-
if (hubConnectionState != HubConnectionState.CONNECTED) {
642-
throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active.");
643-
}
650+
hubConnectionStateLock.lock();
651+
try {
652+
if (hubConnectionState != HubConnectionState.CONNECTED) {
653+
throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active.");
654+
}
644655

645-
String id = connectionState.getNextInvocationId();
656+
String id = connectionState.getNextInvocationId();
657+
InvocationRequest irq = new InvocationRequest(returnType, id);
658+
connectionState.addInvocation(irq);
646659

647-
SingleSubject<T> subject = SingleSubject.create();
648-
InvocationRequest irq = new InvocationRequest(returnType, id);
649-
connectionState.addInvocation(irq);
660+
SingleSubject<T> subject = SingleSubject.create();
650661

651-
// forward the invocation result or error to the user
652-
// run continuations on a separate thread
653-
Subject<Object> pendingCall = irq.getPendingCall();
654-
pendingCall.subscribe(result -> {
655-
// Primitive types can't be cast with the Class cast function
656-
if (returnType.isPrimitive()) {
657-
subject.onSuccess((T)result);
658-
} else {
659-
subject.onSuccess(returnType.cast(result));
660-
}
661-
}, error -> subject.onError(error));
662+
// forward the invocation result or error to the user
663+
// run continuations on a separate thread
664+
Subject<Object> pendingCall = irq.getPendingCall();
665+
pendingCall.subscribe(result -> {
666+
// Primitive types can't be cast with the Class cast function
667+
if (returnType.isPrimitive()) {
668+
subject.onSuccess((T)result);
669+
} else {
670+
subject.onSuccess(returnType.cast(result));
671+
}
672+
}, error -> subject.onError(error));
662673

663-
// Make sure the actual send is after setting up the callbacks otherwise there is a race
664-
// where the map doesn't have the callbacks yet when the response is returned
665-
sendInvocationMessage(method, args, id, false);
666-
return subject;
674+
// Make sure the actual send is after setting up the callbacks otherwise there is a race
675+
// where the map doesn't have the callbacks yet when the response is returned
676+
sendInvocationMessage(method, args, id, false);
677+
return subject;
678+
} finally {
679+
hubConnectionStateLock.unlock();
680+
}
667681
}
668682

669683
/**
@@ -677,33 +691,46 @@ public <T> Single<T> invoke(Class<T> returnType, String method, Object... args)
677691
*/
678692
@SuppressWarnings("unchecked")
679693
public <T> Observable<T> stream(Class<T> returnType, String method, Object ... args) {
680-
String invocationId = connectionState.getNextInvocationId();
681-
AtomicInteger subscriptionCount = new AtomicInteger();
682-
InvocationRequest irq = new InvocationRequest(returnType, invocationId);
683-
connectionState.addInvocation(irq);
684-
ReplaySubject<T> subject = ReplaySubject.create();
685-
686-
Subject<Object> pendingCall = irq.getPendingCall();
687-
pendingCall.subscribe(result -> {
688-
// Primitive types can't be cast with the Class cast function
689-
if (returnType.isPrimitive()) {
690-
subject.onNext((T)result);
691-
} else {
692-
subject.onNext(returnType.cast(result));
693-
}
694-
}, error -> subject.onError(error),
695-
() -> subject.onComplete());
696-
697-
Observable<T> observable = subject.doOnSubscribe((subscriber) -> subscriptionCount.incrementAndGet());
698-
sendInvocationMessage(method, args, invocationId, true);
699-
return observable.doOnDispose(() -> {
700-
if (subscriptionCount.decrementAndGet() == 0) {
701-
CancelInvocationMessage cancelInvocationMessage = new CancelInvocationMessage(invocationId);
702-
sendHubMessage(cancelInvocationMessage);
703-
connectionState.tryRemoveInvocation(invocationId);
704-
subject.onComplete();
694+
String invocationId;
695+
InvocationRequest irq;
696+
hubConnectionStateLock.lock();
697+
try {
698+
if (hubConnectionState != HubConnectionState.CONNECTED) {
699+
throw new RuntimeException("The 'stream' method cannot be called if the connection is not active.");
705700
}
706-
});
701+
702+
invocationId = connectionState.getNextInvocationId();
703+
irq = new InvocationRequest(returnType, invocationId);
704+
connectionState.addInvocation(irq);
705+
706+
AtomicInteger subscriptionCount = new AtomicInteger();
707+
ReplaySubject<T> subject = ReplaySubject.create();
708+
Subject<Object> pendingCall = irq.getPendingCall();
709+
pendingCall.subscribe(result -> {
710+
// Primitive types can't be cast with the Class cast function
711+
if (returnType.isPrimitive()) {
712+
subject.onNext((T)result);
713+
} else {
714+
subject.onNext(returnType.cast(result));
715+
}
716+
}, error -> subject.onError(error),
717+
() -> subject.onComplete());
718+
719+
Observable<T> observable = subject.doOnSubscribe((subscriber) -> subscriptionCount.incrementAndGet());
720+
sendInvocationMessage(method, args, invocationId, true);
721+
return observable.doOnDispose(() -> {
722+
if (subscriptionCount.decrementAndGet() == 0) {
723+
CancelInvocationMessage cancelInvocationMessage = new CancelInvocationMessage(invocationId);
724+
sendHubMessage(cancelInvocationMessage);
725+
if (connectionState != null) {
726+
connectionState.tryRemoveInvocation(invocationId);
727+
}
728+
subject.onComplete();
729+
}
730+
});
731+
} finally {
732+
hubConnectionStateLock.unlock();
733+
}
707734
}
708735

709736
private void sendHubMessage(HubMessage message) {

src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1608,6 +1608,15 @@ public void cannotInvokeBeforeStart() {
16081608
assertEquals("The 'invoke' method cannot be called if the connection is not active.", exception.getMessage());
16091609
}
16101610

1611+
@Test
1612+
public void cannotStreamBeforeStart() {
1613+
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");
1614+
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
1615+
1616+
Throwable exception = assertThrows(RuntimeException.class, () -> hubConnection.stream(String.class, "inc", "arg1"));
1617+
assertEquals("The 'stream' method cannot be called if the connection is not active.", exception.getMessage());
1618+
}
1619+
16111620
@Test
16121621
public void doesNotErrorWhenReceivingInvokeWithIncorrectArgumentLength() {
16131622
MockTransport mockTransport = new MockTransport();
@@ -2036,7 +2045,7 @@ public void authorizationHeaderFromNegotiateGetsSetToNewValue() {
20362045

20372046
TestHttpClient client = new TestHttpClient()
20382047
.on("POST", "http://example.com/negotiate", (req) -> {
2039-
if(redirectCount.get() == 0){
2048+
if (redirectCount.get() == 0) {
20402049
redirectCount.incrementAndGet();
20412050
redirectToken.set(req.getHeaders().get("Authorization"));
20422051
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"firstRedirectToken\"}"));

src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import java.util.HashMap;
99
import java.util.stream.Stream;
1010

11-
import io.reactivex.Single;
1211
import org.junit.jupiter.params.ParameterizedTest;
1312
import org.junit.jupiter.params.provider.Arguments;
1413
import org.junit.jupiter.params.provider.MethodSource;

src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/sample/Chat.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,6 @@ public static void main(String[] args) {
3737
hubConnection.send("Send", message);
3838
}
3939

40-
hubConnection.stop();
40+
hubConnection.stop().blockingAwait();
4141
}
4242
}

0 commit comments

Comments
 (0)