Skip to content

[SignalR] Add client return results #40811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,5 +310,11 @@ public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTime

[LoggerMessage(84, LogLevel.Trace, "Client threw an error for stream '{StreamId}'.", EventName = "ErroredStream")]
public static partial void ErroredStream(ILogger logger, string streamId, Exception exception);

[LoggerMessage(85, LogLevel.Warning, "Failed to find a value returning handler for '{Target}' method. Sending error to server.", EventName = "MissingResultHandler")]
public static partial void MissingResultHandler(ILogger logger, string target);

[LoggerMessage(86, LogLevel.Warning, "Result given for '{Target}' method but server is not expecting a result.", EventName = "ResultNotExpected")]
public static partial void ResultNotExpected(ILogger logger, string target);
}
}
126 changes: 98 additions & 28 deletions src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ private async Task StartAsyncInner(CancellationToken cancellationToken = default
throw new InvalidOperationException($"The {nameof(HubConnection)} cannot be started while {nameof(StopAsync)} is running.");
}

using (CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken))
using (CancellationTokenUtils.CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken))
{
await StartAsyncCore(linkedToken).ConfigureAwait(false);
}
Expand Down Expand Up @@ -312,6 +312,39 @@ public virtual async ValueTask DisposeAsync()
}
}

/// <summary>
/// Registers a handler that will be invoked when the hub method with the specified method name is invoked.
/// Returns value returned by handler to server if the server requests a result.
/// </summary>
/// <param name="methodName">The name of the hub method to define.</param>
/// <param name="parameterTypes">The parameters types expected by the hub method.</param>
/// <param name="handler">The handler that will be raised when the hub method is invoked.</param>
/// <param name="state">A state object that will be passed to the handler.</param>
/// <returns>A subscription that can be disposed to unsubscribe from the hub method.</returns>
/// <remarks>
/// This is a low level method for registering a handler. Using an <see cref="HubConnectionExtensions"/> <c>On</c> extension method is recommended.
/// </remarks>
public virtual IDisposable On(string methodName, Type[] parameterTypes, Func<object?[], object, Task<object?>> handler, object state)
{
Log.RegisteringHandler(_logger, methodName);

CheckDisposed();

// It's OK to be disposed while registering a callback, we'll just never call the callback anyway (as with all the callbacks registered before disposal).
var invocationHandler = new InvocationHandler(parameterTypes, handler, state);
var invocationList = _handlers.AddOrUpdate(methodName, _ => new InvocationHandlerList(invocationHandler),
(_, invocations) =>
{
lock (invocations)
{
invocations.Add(methodName, invocationHandler);
}
return invocations;
});

return new Subscription(invocationHandler, invocationList);
}

// If the registered callback blocks it can cause the client to stop receiving messages. If you need to block, get off the current thread first.
/// <summary>
/// Registers a handler that will be invoked when the hub method with the specified method name is invoked.
Expand All @@ -337,7 +370,7 @@ public virtual IDisposable On(string methodName, Type[] parameterTypes, Func<obj
{
lock (invocations)
{
invocations.Add(invocationHandler);
invocations.Add(methodName, invocationHandler);
}
return invocations;
});
Expand Down Expand Up @@ -988,27 +1021,73 @@ private async Task SendWithLock(ConnectionState expectedConnectionState, HubMess
return null;
}

private async Task DispatchInvocationAsync(InvocationMessage invocation)
private async Task DispatchInvocationAsync(InvocationMessage invocation, ConnectionState connectionState)
{
var expectsResult = !string.IsNullOrEmpty(invocation.InvocationId);
// Find the handler
if (!_handlers.TryGetValue(invocation.Target, out var invocationHandlerList))
{
Log.MissingHandler(_logger, invocation.Target);
if (expectsResult)
{
Log.MissingResultHandler(_logger, invocation.Target);
await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false);
}
else
{
Log.MissingHandler(_logger, invocation.Target);
}
return;
}

// Grabbing the current handlers
var copiedHandlers = invocationHandlerList.GetHandlers();
object? result = null;
Exception? resultException = null;
var hasResult = false;
foreach (var handler in copiedHandlers)
{
try
{
await handler.InvokeAsync(invocation.Arguments).ConfigureAwait(false);
var task = handler.InvokeAsync(invocation.Arguments);
if (handler.HasResult && task is Task<object?> resultTask)
{
result = await resultTask.ConfigureAwait(false);
hasResult = true;
}
else
{
await task.ConfigureAwait(false);
}
}
catch (Exception ex)
{
Log.ErrorInvokingClientSideMethod(_logger, invocation.Target, ex);
if (handler.HasResult)
{
resultException = ex;
}
}
}

if (expectsResult)
{
if (resultException is not null)
{
await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false);
}
else if (hasResult)
{
await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false);
}
else
{
Log.MissingResultHandler(_logger, invocation.Target);
await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false);
}
}
else if (hasResult)
{
Log.ResultNotExpected(_logger, invocation.Target);
}
}

Expand Down Expand Up @@ -1073,7 +1152,7 @@ private async Task HandshakeAsync(ConnectionState startingConnectionState, Cance
try
{
// cancellationToken already contains _state.StopCts.Token, so we don't have to link it again
using (CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken))
using (CancellationTokenUtils.CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken))
{
while (true)
{
Expand Down Expand Up @@ -1178,7 +1257,7 @@ async Task StartProcessingInvocationMessages(ChannelReader<InvocationMessage> in
{
while (invocationMessageChannelReader.TryRead(out var invocationMessage))
{
await DispatchInvocationAsync(invocationMessage).ConfigureAwait(false);
await DispatchInvocationAsync(invocationMessage, connectionState).ConfigureAwait(false);
}
}
}
Expand Down Expand Up @@ -1562,26 +1641,6 @@ async Task RunReconnectedEventAsync()
}
}

private static IDisposable? CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken)
{
if (!token1.CanBeCanceled)
{
linkedToken = token2;
return null;
}
else if (!token2.CanBeCanceled)
{
linkedToken = token1;
return null;
}
else
{
var cts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2);
linkedToken = cts.Token;
return cts;
}
}

// Debug.Assert plays havoc with Unit Tests. But I want something that I can "assert" only in Debug builds.
[Conditional("DEBUG")]
private static void SafeAssert(bool condition, string message, [CallerMemberName] string? memberName = null, [CallerFilePath] string? fileName = null, [CallerLineNumber] int lineNumber = 0)
Expand Down Expand Up @@ -1639,10 +1698,20 @@ internal InvocationHandler[] GetHandlers()
return handlers;
}

internal void Add(InvocationHandler handler)
internal void Add(string methodName, InvocationHandler handler)
{
lock (_invocationHandlers)
{
if (handler.HasResult)
{
foreach (var m in _invocationHandlers)
{
if (m.HasResult)
{
throw new InvalidOperationException($"'{methodName}' already has a value returning handler. Multiple return values are not supported.");
}
}
}
_invocationHandlers.Add(handler);
_copiedHandlers = null;
}
Expand All @@ -1663,6 +1732,7 @@ internal void Remove(InvocationHandler handler)
private readonly struct InvocationHandler
{
public Type[] ParameterTypes { get; }
public bool HasResult => _callback.Method.ReturnType == typeof(Task<object>);
private readonly Func<object?[], object, Task> _callback;
private readonly object _state;

Expand Down
Loading