Skip to content

[release/7.0] [SignalR] Unblock user callbacks when waiting for client result #44014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,14 @@ async Task StartProcessingInvocationMessages(ChannelReader<InvocationMessage> in
{
while (invocationMessageChannelReader.TryRead(out var invocationMessage))
{
await DispatchInvocationAsync(invocationMessage, connectionState).ConfigureAwait(false);
var invokeTask = DispatchInvocationAsync(invocationMessage, connectionState);
// If a client result is expected we shouldn't block on user code as that could potentially permanently block the application
// Even if it doesn't permanently block, it would be better if non-client result handlers could still be called while waiting for a result
// e.g. chat while waiting for user input for a turn in a game
if (string.IsNullOrEmpty(invocationMessage.InvocationId))
{
await invokeTask.ConfigureAwait(false);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -876,5 +876,36 @@ public async Task ClientResultCanReturnNullResult()
await connection.DisposeAsync().DefaultTimeout();
}
}

[Fact]
public async Task ClientResultHandlerDoesNotBlockOtherHandlers()
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection);
try
{
await hubConnection.StartAsync().DefaultTimeout();

var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
hubConnection.On("Result", async () =>
{
await tcs.Task.DefaultTimeout();
return 1;
});
hubConnection.On("Other", () => tcs.SetResult());

await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout();
await connection.ReceiveTextAsync("{\"type\":1,\"target\":\"Other\",\"arguments\":[]}\u001e").DefaultTimeout();

var invokeMessage = await connection.ReadSentTextMessageAsync().DefaultTimeout();

Assert.Equal("{\"type\":3,\"invocationId\":\"1\",\"result\":1}", invokeMessage);
}
finally
{
await hubConnection.DisposeAsync().DefaultTimeout();
await connection.DisposeAsync().DefaultTimeout();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,7 @@ private final class ConnectionState implements InvocationBinder {
private Boolean handshakeReceived = false;
private ScheduledExecutorService handshakeTimeout = null;
private BehaviorSubject<InvocationMessage> messages = BehaviorSubject.create();
private ExecutorService resultInvocationPool = null;

public final Lock lock = new ReentrantLock();
public final CompletableSubject handshakeResponseSubject = CompletableSubject.create();
Expand Down Expand Up @@ -1506,7 +1507,7 @@ public void handleHandshake(ByteBuffer payload) {
}
handshakeReceived = true;
handshakeResponseSubject.onComplete();
handleInvocations();
startInvocationProcessing();
}
}

Expand All @@ -1528,66 +1529,81 @@ public void close() {
if (this.handshakeTimeout != null) {
this.handshakeTimeout.shutdownNow();
}

if (this.resultInvocationPool != null) {
this.resultInvocationPool.shutdownNow();
}
}

public void dispatchInvocation(InvocationMessage message) {
messages.onNext(message);
}

private void handleInvocations() {
messages.observeOn(Schedulers.io()).subscribe(invocationMessage -> {
List<InvocationHandler> handlers = this.connection.handlers.get(invocationMessage.getTarget());
boolean expectsResult = invocationMessage.getInvocationId() != null;
if (handlers == null) {
if (expectsResult) {
logger.warn("Failed to find a value returning handler for '{}' method. Sending error to server.", invocationMessage.getTarget());
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
null, "Client did not provide a result."));
} else {
logger.warn("Failed to find handler for '{}' method.", invocationMessage.getTarget());
}
return;
}
Object result = null;
Exception resultException = null;
Boolean hasResult = false;
for (InvocationHandler handler : handlers) {
try {
Object action = handler.getAction();
if (handler.getHasResult()) {
FunctionBase function = (FunctionBase)action;
result = function.invoke(invocationMessage.getArguments()).blockingGet();
hasResult = true;
} else {
((ActionBase)action).invoke(invocationMessage.getArguments()).blockingAwait();
}
} catch (Exception e) {
logger.error("Invoking client side method '{}' failed:", invocationMessage.getTarget(), e);
if (handler.getHasResult()) {
resultException = e;
}
}
private void startInvocationProcessing() {
this.resultInvocationPool = Executors.newCachedThreadPool();
this.messages.observeOn(Schedulers.io()).subscribe(invocationMessage -> {
// if client result expected, unblock the invocation processing thread
if (invocationMessage.getInvocationId() != null) {
this.resultInvocationPool.submit(() -> handleInvocation(invocationMessage));
} else {
handleInvocation(invocationMessage);
}
}, (e) -> {
stop(e.getMessage());
}, () -> {
});
}

private void handleInvocation(InvocationMessage invocationMessage)
{
List<InvocationHandler> handlers = this.connection.handlers.get(invocationMessage.getTarget());
boolean expectsResult = invocationMessage.getInvocationId() != null;
if (handlers == null) {
if (expectsResult) {
if (resultException != null) {
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
null, resultException.getMessage()));
} else if (hasResult) {
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
result, null));
logger.warn("Failed to find a value returning handler for '{}' method. Sending error to server.", invocationMessage.getTarget());
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
null, "Client did not provide a result."));
} else {
logger.warn("Failed to find handler for '{}' method.", invocationMessage.getTarget());
}
return;
}
Object result = null;
Exception resultException = null;
Boolean hasResult = false;
for (InvocationHandler handler : handlers) {
try {
Object action = handler.getAction();
if (handler.getHasResult()) {
FunctionBase function = (FunctionBase)action;
result = function.invoke(invocationMessage.getArguments()).blockingGet();
hasResult = true;
} else {
logger.warn("Failed to find a value returning handler for '{}' method. Sending error to server.", invocationMessage.getTarget());
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
null, "Client did not provide a result."));
((ActionBase)action).invoke(invocationMessage.getArguments()).blockingAwait();
}
} catch (Exception e) {
logger.error("Invoking client side method '{}' failed:", invocationMessage.getTarget(), e);
if (handler.getHasResult()) {
resultException = e;
}
}
}

if (expectsResult) {
if (resultException != null) {
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
null, resultException.getMessage()));
} else if (hasResult) {
logger.warn("Result given for '{}' method but server is not expecting a result.", invocationMessage.getTarget());
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
result, null));
} else {
logger.warn("Failed to find a value returning handler for '{}' method. Sending error to server.", invocationMessage.getTarget());
sendHubMessageWithLock(new CompletionMessage(null, invocationMessage.getInvocationId(),
null, "Client did not provide a result."));
}
}, (e) -> {
stop(e.getMessage());
}, () -> {
});
} else if (hasResult) {
logger.warn("Result given for '{}' method but server is not expecting a result.", invocationMessage.getTarget());
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -395,4 +395,38 @@ public void returnFromOnHandlerEightParams() {
String expected = "{\"type\":3,\"invocationId\":\"1\",\"result\":\"bob\"}" + RECORD_SEPARATOR;
assertEquals(expected, TestUtils.byteBufferToString(message));
}

@Test
public void clientResultHandlerDoesNotBlockOtherHandlers() {
MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport);
CompletableSubject resultCalled = CompletableSubject.create();
CompletableSubject completeResult = CompletableSubject.create();
CompletableSubject nonResultCalled = CompletableSubject.create();

hubConnection.onWithResult("inc", (i) -> {
resultCalled.onComplete();
completeResult.timeout(30, TimeUnit.SECONDS).blockingAwait();
return Single.just("bob");
}, String.class);

hubConnection.on("inc2", (i) -> {
nonResultCalled.onComplete();
}, String.class);

hubConnection.start().timeout(30, TimeUnit.SECONDS).blockingAwait();
SingleSubject<ByteBuffer> sendTask = mockTransport.getNextSentMessage();
mockTransport.receiveMessage("{\"type\":1,\"invocationId\":\"1\",\"target\":\"inc\",\"arguments\":[\"1\"]}" + RECORD_SEPARATOR);
resultCalled.timeout(30, TimeUnit.SECONDS).blockingAwait();

// Send an non-result invocation and make sure it's processed even with a blocking result invocation
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc2\",\"arguments\":[\"1\"]}" + RECORD_SEPARATOR);
nonResultCalled.timeout(30, TimeUnit.SECONDS).blockingAwait();

completeResult.onComplete();

ByteBuffer message = sendTask.timeout(30, TimeUnit.SECONDS).blockingGet();
String expected = "{\"type\":3,\"invocationId\":\"1\",\"result\":\"bob\"}" + RECORD_SEPARATOR;
assertEquals(expected, TestUtils.byteBufferToString(message));
}
}