diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index 67669576e74d..a4d1e58c7dab 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -537,11 +537,15 @@ private void stopConnection(String errorMessage) { * @param args The arguments to be passed to the method. */ public void send(String method, Object... args) { - if (hubConnectionState != HubConnectionState.CONNECTED) { - throw new RuntimeException("The 'send' method cannot be called if the connection is not active."); + hubConnectionStateLock.lock(); + try { + if (hubConnectionState != HubConnectionState.CONNECTED) { + throw new RuntimeException("The 'send' method cannot be called if the connection is not active."); + } + sendInvocationMessage(method, args); + } finally { + hubConnectionStateLock.unlock(); } - - sendInvocationMessage(method, args); } private void sendInvocationMessage(String method, Object[] args) { @@ -605,26 +609,31 @@ Object[] checkUploadStream(Object[] args, List streamIds) { */ @SuppressWarnings("unchecked") public Completable invoke(String method, Object... args) { - if (hubConnectionState != HubConnectionState.CONNECTED) { - throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active."); - } + hubConnectionStateLock.lock(); + try { + if (hubConnectionState != HubConnectionState.CONNECTED) { + throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active."); + } - String id = connectionState.getNextInvocationId(); + String id = connectionState.getNextInvocationId(); - CompletableSubject subject = CompletableSubject.create(); - InvocationRequest irq = new InvocationRequest(null, id); - connectionState.addInvocation(irq); + CompletableSubject subject = CompletableSubject.create(); + InvocationRequest irq = new InvocationRequest(null, id); + connectionState.addInvocation(irq); - Subject pendingCall = irq.getPendingCall(); + Subject pendingCall = irq.getPendingCall(); - pendingCall.subscribe(result -> subject.onComplete(), - error -> subject.onError(error), - () -> subject.onComplete()); + pendingCall.subscribe(result -> subject.onComplete(), + error -> subject.onError(error), + () -> subject.onComplete()); - // Make sure the actual send is after setting up the callbacks otherwise there is a race - // where the map doesn't have the callbacks yet when the response is returned - sendInvocationMessage(method, args, id, false); - return subject; + // Make sure the actual send is after setting up the callbacks otherwise there is a race + // where the map doesn't have the callbacks yet when the response is returned + sendInvocationMessage(method, args, id, false); + return subject; + } finally { + hubConnectionStateLock.unlock(); + } } /** @@ -638,32 +647,37 @@ public Completable invoke(String method, Object... args) { */ @SuppressWarnings("unchecked") public Single invoke(Class returnType, String method, Object... args) { - if (hubConnectionState != HubConnectionState.CONNECTED) { - throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active."); - } + hubConnectionStateLock.lock(); + try { + if (hubConnectionState != HubConnectionState.CONNECTED) { + throw new RuntimeException("The 'invoke' method cannot be called if the connection is not active."); + } - String id = connectionState.getNextInvocationId(); + String id = connectionState.getNextInvocationId(); + InvocationRequest irq = new InvocationRequest(returnType, id); + connectionState.addInvocation(irq); - SingleSubject subject = SingleSubject.create(); - InvocationRequest irq = new InvocationRequest(returnType, id); - connectionState.addInvocation(irq); + SingleSubject subject = SingleSubject.create(); - // forward the invocation result or error to the user - // run continuations on a separate thread - Subject pendingCall = irq.getPendingCall(); - pendingCall.subscribe(result -> { - // Primitive types can't be cast with the Class cast function - if (returnType.isPrimitive()) { - subject.onSuccess((T)result); - } else { - subject.onSuccess(returnType.cast(result)); - } - }, error -> subject.onError(error)); + // forward the invocation result or error to the user + // run continuations on a separate thread + Subject pendingCall = irq.getPendingCall(); + pendingCall.subscribe(result -> { + // Primitive types can't be cast with the Class cast function + if (returnType.isPrimitive()) { + subject.onSuccess((T)result); + } else { + subject.onSuccess(returnType.cast(result)); + } + }, error -> subject.onError(error)); - // Make sure the actual send is after setting up the callbacks otherwise there is a race - // where the map doesn't have the callbacks yet when the response is returned - sendInvocationMessage(method, args, id, false); - return subject; + // Make sure the actual send is after setting up the callbacks otherwise there is a race + // where the map doesn't have the callbacks yet when the response is returned + sendInvocationMessage(method, args, id, false); + return subject; + } finally { + hubConnectionStateLock.unlock(); + } } /** @@ -677,33 +691,46 @@ public Single invoke(Class returnType, String method, Object... args) */ @SuppressWarnings("unchecked") public Observable stream(Class returnType, String method, Object ... args) { - String invocationId = connectionState.getNextInvocationId(); - AtomicInteger subscriptionCount = new AtomicInteger(); - InvocationRequest irq = new InvocationRequest(returnType, invocationId); - connectionState.addInvocation(irq); - ReplaySubject subject = ReplaySubject.create(); - - Subject pendingCall = irq.getPendingCall(); - pendingCall.subscribe(result -> { - // Primitive types can't be cast with the Class cast function - if (returnType.isPrimitive()) { - subject.onNext((T)result); - } else { - subject.onNext(returnType.cast(result)); - } - }, error -> subject.onError(error), - () -> subject.onComplete()); - - Observable observable = subject.doOnSubscribe((subscriber) -> subscriptionCount.incrementAndGet()); - sendInvocationMessage(method, args, invocationId, true); - return observable.doOnDispose(() -> { - if (subscriptionCount.decrementAndGet() == 0) { - CancelInvocationMessage cancelInvocationMessage = new CancelInvocationMessage(invocationId); - sendHubMessage(cancelInvocationMessage); - connectionState.tryRemoveInvocation(invocationId); - subject.onComplete(); + String invocationId; + InvocationRequest irq; + hubConnectionStateLock.lock(); + try { + if (hubConnectionState != HubConnectionState.CONNECTED) { + throw new RuntimeException("The 'stream' method cannot be called if the connection is not active."); } - }); + + invocationId = connectionState.getNextInvocationId(); + irq = new InvocationRequest(returnType, invocationId); + connectionState.addInvocation(irq); + + AtomicInteger subscriptionCount = new AtomicInteger(); + ReplaySubject subject = ReplaySubject.create(); + Subject pendingCall = irq.getPendingCall(); + pendingCall.subscribe(result -> { + // Primitive types can't be cast with the Class cast function + if (returnType.isPrimitive()) { + subject.onNext((T)result); + } else { + subject.onNext(returnType.cast(result)); + } + }, error -> subject.onError(error), + () -> subject.onComplete()); + + Observable observable = subject.doOnSubscribe((subscriber) -> subscriptionCount.incrementAndGet()); + sendInvocationMessage(method, args, invocationId, true); + return observable.doOnDispose(() -> { + if (subscriptionCount.decrementAndGet() == 0) { + CancelInvocationMessage cancelInvocationMessage = new CancelInvocationMessage(invocationId); + sendHubMessage(cancelInvocationMessage); + if (connectionState != null) { + connectionState.tryRemoveInvocation(invocationId); + } + subject.onComplete(); + } + }); + } finally { + hubConnectionStateLock.unlock(); + } } private void sendHubMessage(HubMessage message) { diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index b5104919ab44..0ff3be7fd0d9 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -1608,6 +1608,15 @@ public void cannotInvokeBeforeStart() { assertEquals("The 'invoke' method cannot be called if the connection is not active.", exception.getMessage()); } + @Test + public void cannotStreamBeforeStart() { + HubConnection hubConnection = TestUtils.createHubConnection("http://example.com"); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + + Throwable exception = assertThrows(RuntimeException.class, () -> hubConnection.stream(String.class, "inc", "arg1")); + assertEquals("The 'stream' method cannot be called if the connection is not active.", exception.getMessage()); + } + @Test public void doesNotErrorWhenReceivingInvokeWithIncorrectArgumentLength() { MockTransport mockTransport = new MockTransport(); @@ -2036,7 +2045,7 @@ public void authorizationHeaderFromNegotiateGetsSetToNewValue() { TestHttpClient client = new TestHttpClient() .on("POST", "http://example.com/negotiate", (req) -> { - if(redirectCount.get() == 0){ + if (redirectCount.get() == 0) { redirectCount.incrementAndGet(); redirectToken.set(req.getHeaders().get("Authorization")); return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"firstRedirectToken\"}")); diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java index 2631084a3424..1dbe653dcc1a 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java @@ -8,7 +8,6 @@ import java.util.HashMap; import java.util.stream.Stream; -import io.reactivex.Single; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/sample/Chat.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/sample/Chat.java index bd617fe88e45..5201a0915887 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/sample/Chat.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/sample/Chat.java @@ -37,6 +37,6 @@ public static void main(String[] args) { hubConnection.send("Send", message); } - hubConnection.stop(); + hubConnection.stop().blockingAwait(); } }