From 2f943b2fff8439badde67b00a747e38f2f317ab3 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Mon, 18 Jul 2022 14:26:39 -0700 Subject: [PATCH 1/6] [SignalR] Avoid blocking common InvokeAsync usage --- .../server/Core/src/HubConnectionContext.cs | 3 +- .../Core/src/Internal/DefaultHubDispatcher.cs | 2 +- .../Core/src/Internal/HubCallerClients.cs | 48 ++++++++++++--- .../HubConnectionHandlerTests.ClientResult.cs | 61 ++++++++++++++++++- 4 files changed, 102 insertions(+), 12 deletions(-) diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 04e211d74f86..87d33b5aa4f4 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -83,7 +83,8 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo var maxInvokeLimit = contextOptions.MaximumParallelInvocations; if (maxInvokeLimit != 1) { - ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit); + // Don't specify max count, this is so InvokeAsync inside hub methods will not be able to soft-lock a connection if it's run on a separate thread from the hub method, or just not awaited + ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit); } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 59b06dbf101a..3b75a73c7953 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -73,7 +73,7 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { await using var scope = _serviceScopeFactory.CreateAsyncScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit is not null); + connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit); var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index e2a65ca7d1d9..016b05bbb9d8 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -8,19 +8,19 @@ internal sealed class HubCallerClients : IHubCallerClients private readonly string _connectionId; private readonly IHubClients _hubClients; private readonly string[] _currentConnectionId; - private readonly bool _parallelEnabled; + private readonly SemaphoreSlim? _parallelInvokes; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called // so we can prevent users from making blocking client calls by returning a custom ISingleClientProxy instance internal bool InvokeAllowed { get; set; } - public HubCallerClients(IHubClients hubClients, string connectionId, bool parallelEnabled) + public HubCallerClients(IHubClients hubClients, string connectionId, SemaphoreSlim? parallelInvokes) { _connectionId = connectionId; _hubClients = hubClients; _currentConnectionId = new[] { _connectionId }; - _parallelEnabled = parallelEnabled; + _parallelInvokes = parallelInvokes; } IClientProxy IHubCallerClients.Caller => Caller; @@ -28,7 +28,7 @@ public ISingleClientProxy Caller { get { - if (!_parallelEnabled) + if (_parallelInvokes is null) { return new NotParallelSingleClientProxy(_hubClients.Client(_connectionId)); } @@ -36,7 +36,7 @@ public ISingleClientProxy Caller { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return _hubClients.Client(_connectionId); + return new SingleClientProxy(_hubClients.Client(_connectionId), _parallelInvokes); } } @@ -52,7 +52,7 @@ public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) IClientProxy IHubClients.Client(string connectionId) => Client(connectionId); public ISingleClientProxy Client(string connectionId) { - if (!_parallelEnabled) + if (_parallelInvokes is null) { return new NotParallelSingleClientProxy(_hubClients.Client(connectionId)); } @@ -60,7 +60,7 @@ public ISingleClientProxy Client(string connectionId) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return _hubClients.Client(connectionId); + return new SingleClientProxy(_hubClients.Client(connectionId), _parallelInvokes); } public IClientProxy Group(string groupName) @@ -137,4 +137,38 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance return _proxy.SendCoreAsync(method, args, cancellationToken); } } + + private sealed class SingleClientProxy : ISingleClientProxy + { + private readonly ISingleClientProxy _proxy; + private readonly SemaphoreSlim _parallelInvokes; + + public SingleClientProxy(ISingleClientProxy hubClients, SemaphoreSlim parallelInvokes) + { + _proxy = hubClients; + _parallelInvokes = parallelInvokes; + } + + public async Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + // Releases the SemaphoreSlim that is blocking pending invokes, which in turn can block the receive loop. + // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever + _parallelInvokes.Release(); + try + { + var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); + return result; + } + finally + { + // Re-acquire the SemaphoreSlim, this is because when the hub method completes it will call release + await _parallelInvokes.WaitAsync(CancellationToken.None); + } + } + + public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) + { + return _proxy.SendCoreAsync(method, args, cancellationToken); + } + } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index 74b5dcec0b8f..b894f7cdfa87 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -16,7 +16,7 @@ public async Task CanReturnClientResultToHub() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + // Waiting for a client result requires multiple invocations enabled builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); @@ -47,7 +47,7 @@ public async Task CanReturnClientResultErrorToHub() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + // Waiting for a client result requires multiple invocations enabled builder.AddSignalR(o => { o.MaximumParallelInvocationsPerClient = 2; @@ -237,7 +237,7 @@ public async Task CanReturnClientResultToTypedHubTwoWays() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result blocks the hub dispatcher pipeline, need to allow multiple invocations + // Waiting for a client result requires multiple invocations enabled builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); @@ -266,6 +266,61 @@ public async Task CanReturnClientResultToTypedHubTwoWays() } } + [Fact] + public async Task ClientResultFromHubDoesNotBlockReceiveLoop() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + // Waiting for a client result requires multiple invocations enabled + builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + // block 1 of the 2 parallel invocations + _ = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.BlockingMethod), Array.Empty())).DefaultTimeout(); + + // make multiple invocations which would normally block the invocation processing + var invocationId = await client.SendHubMessageAsync(new InvocationMessage("2", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + var invocationId2 = await client.SendHubMessageAsync(new InvocationMessage("3", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + var invocationId3 = await client.SendHubMessageAsync(new InvocationMessage("4", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); + + // Read all 3 invocation messages from the server, shows that the hub processing continued even though parallel invokes is 2 + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var invocationMessage2 = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + var invocationMessage3 = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + + Assert.NotNull(invocationMessage.InvocationId); + Assert.NotNull(invocationMessage2.InvocationId); + Assert.NotNull(invocationMessage3.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + res = 5 + ((long)invocationMessage2.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage2.InvocationId, res)).DefaultTimeout(); + res = 6 + ((long)invocationMessage3.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage3.InvocationId, res)).DefaultTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(9L, completion.Result); + Assert.Equal(invocationId, completion.InvocationId); + + completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(10L, completion.Result); + Assert.Equal(invocationId2, completion.InvocationId); + + completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); + Assert.Equal(11L, completion.Result); + Assert.Equal(invocationId3, completion.InvocationId); + } + } + } + private class TestBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName) From 314654ec79a3fb1c525b2788fb7395a4c6ffe547 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Thu, 11 Aug 2022 15:36:33 -0700 Subject: [PATCH 2/6] channel --- .../server/Core/src/HubConnectionContext.cs | 10 ++-- .../server/Core/src/HubConnectionHandler.cs | 3 ++ .../Core/src/Internal/ChannelExtensions.cs | 38 +++++++++++++++ .../Core/src/Internal/DefaultHubDispatcher.cs | 4 +- .../Core/src/Internal/HubCallerClients.cs | 46 ++++-------------- .../src/Internal/SemaphoreSlimExtensions.cs | 36 -------------- .../HubConnectionHandlerTests.ClientResult.cs | 47 +------------------ 7 files changed, 60 insertions(+), 124 deletions(-) create mode 100644 src/SignalR/server/Core/src/Internal/ChannelExtensions.cs delete mode 100644 src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 87d33b5aa4f4..b42b17149029 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -7,6 +7,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO.Pipelines; using System.Security.Claims; +using System.Threading.Channels; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; @@ -79,12 +80,11 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo _systemClock = contextOptions.SystemClock ?? new SystemClock(); _lastSendTick = _systemClock.CurrentTicks; - // We'll be avoiding using the semaphore when the limit is set to 1, so no need to allocate it var maxInvokeLimit = contextOptions.MaximumParallelInvocations; - if (maxInvokeLimit != 1) + ActiveInvocationLimit = Channel.CreateBounded(maxInvokeLimit); + for (var i = 0; i < maxInvokeLimit; i++) { - // Don't specify max count, this is so InvokeAsync inside hub methods will not be able to soft-lock a connection if it's run on a separate thread from the hub method, or just not awaited - ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit); + ActiveInvocationLimit.Writer.TryWrite(1); } } @@ -107,7 +107,7 @@ internal StreamTracker StreamTracker internal Exception? CloseException { get; private set; } - internal SemaphoreSlim? ActiveInvocationLimit { get; } + internal Channel ActiveInvocationLimit { get; } /// /// Gets a that notifies when the connection is aborted. diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 3bb5566e6a7a..4b5fea1cace4 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -200,6 +200,9 @@ private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Excep // Ensure the connection is aborted before firing disconnect await connection.AbortAsync(); + // If a client result is requested in OnDisconnectedAsync we want to make sure it isn't blocked by the ActiveInvocationLimit + _ = connection.ActiveInvocationLimit.Reader.TryRead(out _); + try { await _dispatcher.OnDisconnectedAsync(connection, exception); diff --git a/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs b/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs new file mode 100644 index 000000000000..16e3797ae425 --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Channels; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +internal static class ChannelExtensions +{ + public static ValueTask RunAsync(this Channel semaphoreSlim, Func callback, TState state) + { + if (semaphoreSlim.Reader.TryRead(out _)) + { + _ = RunTask(callback, semaphoreSlim, state); + return ValueTask.CompletedTask; + } + + return RunSlowAsync(semaphoreSlim, callback, state); + } + + private static async ValueTask RunSlowAsync(this Channel semaphoreSlim, Func callback, TState state) + { + _ = await semaphoreSlim.Reader.ReadAsync(); + _ = RunTask(callback, semaphoreSlim, state); + } + + static async Task RunTask(Func callback, Channel semaphoreSlim, TState state) + { + try + { + await callback(state); + } + finally + { + await semaphoreSlim.Writer.WriteAsync(1); + } + } +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 3b75a73c7953..1ebc547a18a3 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -256,13 +256,13 @@ private Task ProcessInvocation(HubConnectionContext connection, else { bool isStreamCall = descriptor.StreamingParameters != null; - if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) + if (!isStreamCall && !isStreamResponse) { return connection.ActiveInvocationLimit.RunAsync(static state => { var (dispatcher, descriptor, connection, invocationMessage) = state; return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); - }, (this, descriptor, connection, hubMethodInvocationMessage)); + }, (this, descriptor, connection, hubMethodInvocationMessage)).AsTask(); } else { diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index 016b05bbb9d8..3f208843f0ff 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Threading.Channels; + namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class HubCallerClients : IHubCallerClients @@ -8,14 +10,14 @@ internal sealed class HubCallerClients : IHubCallerClients private readonly string _connectionId; private readonly IHubClients _hubClients; private readonly string[] _currentConnectionId; - private readonly SemaphoreSlim? _parallelInvokes; + private readonly Channel _parallelInvokes; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called // so we can prevent users from making blocking client calls by returning a custom ISingleClientProxy instance internal bool InvokeAllowed { get; set; } - public HubCallerClients(IHubClients hubClients, string connectionId, SemaphoreSlim? parallelInvokes) + public HubCallerClients(IHubClients hubClients, string connectionId, Channel parallelInvokes) { _connectionId = connectionId; _hubClients = hubClients; @@ -28,10 +30,6 @@ public ISingleClientProxy Caller { get { - if (_parallelInvokes is null) - { - return new NotParallelSingleClientProxy(_hubClients.Client(_connectionId)); - } if (!InvokeAllowed) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); @@ -52,10 +50,6 @@ public IClientProxy AllExcept(IReadOnlyList excludedConnectionIds) IClientProxy IHubClients.Client(string connectionId) => Client(connectionId); public ISingleClientProxy Client(string connectionId) { - if (_parallelInvokes is null) - { - return new NotParallelSingleClientProxy(_hubClients.Client(connectionId)); - } if (!InvokeAllowed) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); @@ -98,26 +92,6 @@ public IClientProxy Users(IReadOnlyList userIds) return _hubClients.Users(userIds); } - private sealed class NotParallelSingleClientProxy : ISingleClientProxy - { - private readonly ISingleClientProxy _proxy; - - public NotParallelSingleClientProxy(ISingleClientProxy hubClients) - { - _proxy = hubClients; - } - - public Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) - { - throw new InvalidOperationException("Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1."); - } - - public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) - { - return _proxy.SendCoreAsync(method, args, cancellationToken); - } - } - private sealed class NoInvokeSingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; @@ -141,9 +115,9 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance private sealed class SingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; - private readonly SemaphoreSlim _parallelInvokes; + private readonly Channel _parallelInvokes; - public SingleClientProxy(ISingleClientProxy hubClients, SemaphoreSlim parallelInvokes) + public SingleClientProxy(ISingleClientProxy hubClients, Channel parallelInvokes) { _proxy = hubClients; _parallelInvokes = parallelInvokes; @@ -151,9 +125,9 @@ public SingleClientProxy(ISingleClientProxy hubClients, SemaphoreSlim parallelIn public async Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) { - // Releases the SemaphoreSlim that is blocking pending invokes, which in turn can block the receive loop. + // Releases the Channel that is blocking pending invokes, which in turn can block the receive loop. // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever - _parallelInvokes.Release(); + await _parallelInvokes.Writer.WriteAsync(1, cancellationToken); try { var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); @@ -161,8 +135,8 @@ public async Task InvokeCoreAsync(string method, object?[] args, Cancellat } finally { - // Re-acquire the SemaphoreSlim, this is because when the hub method completes it will call release - await _parallelInvokes.WaitAsync(CancellationToken.None); + // Re-read from the channel, this is because when the hub method completes it will release (write an entry) which we already did above, so we need to reset the state + _ = await _parallelInvokes.Reader.ReadAsync(CancellationToken.None); } } diff --git a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs b/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs deleted file mode 100644 index a238d09643e3..000000000000 --- a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.SignalR.Internal; - -internal static class SemaphoreSlimExtensions -{ - public static Task RunAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) - { - if (semaphoreSlim.Wait(0)) - { - _ = RunTask(callback, semaphoreSlim, state); - return Task.CompletedTask; - } - - return RunSlowAsync(semaphoreSlim, callback, state); - } - - private static async Task RunSlowAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) - { - await semaphoreSlim.WaitAsync(); - return RunTask(callback, semaphoreSlim, state); - } - - static async Task RunTask(Func callback, SemaphoreSlim semaphoreSlim, TState state) - { - try - { - await callback(state); - } - finally - { - semaphoreSlim.Release(); - } - } -} diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index b894f7cdfa87..313d3624721b 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -14,11 +14,7 @@ public async Task CanReturnClientResultToHub() { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - // Waiting for a client result requires multiple invocations enabled - builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); - }, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient()) @@ -47,10 +43,8 @@ public async Task CanReturnClientResultErrorToHub() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result requires multiple invocations enabled builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -74,36 +68,6 @@ public async Task CanReturnClientResultErrorToHub() } } - [Fact] - public async Task ThrowsWhenParallelHubInvokesNotEnabled() - { - using (StartVerifiableLog(write => write.EventId.Name == "FailedInvokingHubMethod")) - { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - builder.AddSignalR(o => - { - o.MaximumParallelInvocationsPerClient = 1; - o.EnableDetailedErrors = true; - }); - }, LoggerFactory); - var connectionHandler = serviceProvider.GetService>(); - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); - - var invocationId = await client.SendHubMessageAsync(new InvocationMessage("1", nameof(MethodHub.GetClientResult), new object[] { 5 })).DefaultTimeout(); - - // Hub asks client for a result, this is an invocation message with an ID - var completionMessage = Assert.IsType(await client.ReadAsync().DefaultTimeout()); - Assert.Equal(invocationId, completionMessage.InvocationId); - Assert.Equal("An unexpected error occurred invoking 'GetClientResult' on the server. InvalidOperationException: Client results inside a Hub method requires HubOptions.MaximumParallelInvocationsPerClient to be greater than 1.", - completionMessage.Error); - } - } - } - [Fact] public async Task ThrowsWhenUsedInOnConnectedAsync() { @@ -113,7 +77,6 @@ public async Task ThrowsWhenUsedInOnConnectedAsync() { builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -141,7 +104,6 @@ public async Task ThrowsWhenUsedInOnDisconnectedAsync() { builder.AddSignalR(o => { - o.MaximumParallelInvocationsPerClient = 2; o.EnableDetailedErrors = true; }); }, LoggerFactory); @@ -235,11 +197,7 @@ public async Task CanReturnClientResultToTypedHubTwoWays() { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => - { - // Waiting for a client result requires multiple invocations enabled - builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); - }, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); using var client = new TestClient(invocationBinder: new GetClientResultTwoWaysInvocationBinder()); @@ -273,7 +231,6 @@ public async Task ClientResultFromHubDoesNotBlockReceiveLoop() { var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { - // Waiting for a client result requires multiple invocations enabled builder.AddSignalR(o => o.MaximumParallelInvocationsPerClient = 2); }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); From 39d54cd723245e206edc03762dae2ace86d7ba50 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Thu, 11 Aug 2022 18:05:15 -0700 Subject: [PATCH 3/6] fixup test --- .../server/SignalR/test/HubConnectionHandlerTests.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 8b35d4647590..32af4cfe6246 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2991,16 +2991,23 @@ public async Task HubMethodInvokeDoesNotCountTowardsClientTimeout() await client.SendHubMessageAsync(PingMessage.Instance); // Call long running hub method - var hubMethodTask = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); + var hubMethodTask1 = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); await tcsService.StartedMethod.Task.DefaultTimeout(); + // Wait for server to start reading again + await customDuplex.WrappedPipeReader.WaitForReadStart().DefaultTimeout(); + // Send another invocation to server, since we use Inline scheduling we know that once this call completes the server will have read and processed + // the message, it should be stuck waiting for the in-progress invoke now + _ = await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod)).DefaultTimeout(); + // Tick heartbeat while hub method is running to show that close isn't triggered client.TickHeartbeat(); // Unblock long running hub method tcsService.EndMethod.SetResult(null); - await hubMethodTask.DefaultTimeout(); + await hubMethodTask1.DefaultTimeout(); + await client.ReadAsync().DefaultTimeout(); // There is a small window when the hub method finishes and the timer starts again // So we need to delay a little before ticking the heart beat. From 8fcf442f136f8fcc14edc266049764c986350a18 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Fri, 12 Aug 2022 09:49:44 -0700 Subject: [PATCH 4/6] fb --- .../server/Core/src/HubConnectionContext.cs | 9 +---- .../server/Core/src/HubConnectionHandler.cs | 2 +- .../src/Internal/ChannelBasedSemaphore.cs | 37 +++++++++++++++++++ .../Core/src/Internal/ChannelExtensions.cs | 20 +++++----- .../Core/src/Internal/HubCallerClients.cs | 16 ++++---- .../HubConnectionHandlerTests.ClientResult.cs | 9 ++--- 6 files changed, 60 insertions(+), 33 deletions(-) create mode 100644 src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index b42b17149029..6b221be915cb 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -7,7 +7,6 @@ using System.Diagnostics.CodeAnalysis; using System.IO.Pipelines; using System.Security.Claims; -using System.Threading.Channels; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; @@ -81,11 +80,7 @@ public HubConnectionContext(ConnectionContext connectionContext, HubConnectionCo _lastSendTick = _systemClock.CurrentTicks; var maxInvokeLimit = contextOptions.MaximumParallelInvocations; - ActiveInvocationLimit = Channel.CreateBounded(maxInvokeLimit); - for (var i = 0; i < maxInvokeLimit; i++) - { - ActiveInvocationLimit.Writer.TryWrite(1); - } + ActiveInvocationLimit = new ChannelBasedSemaphore(maxInvokeLimit); } internal StreamTracker StreamTracker @@ -107,7 +102,7 @@ internal StreamTracker StreamTracker internal Exception? CloseException { get; private set; } - internal Channel ActiveInvocationLimit { get; } + internal ChannelBasedSemaphore ActiveInvocationLimit { get; } /// /// Gets a that notifies when the connection is aborted. diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 4b5fea1cace4..28d4e61e1ca7 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -201,7 +201,7 @@ private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Excep await connection.AbortAsync(); // If a client result is requested in OnDisconnectedAsync we want to make sure it isn't blocked by the ActiveInvocationLimit - _ = connection.ActiveInvocationLimit.Reader.TryRead(out _); + _ = connection.ActiveInvocationLimit.AttemptWait(); try { diff --git a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs new file mode 100644 index 000000000000..4ec8cf899d6b --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Channels; + +namespace Microsoft.AspNetCore.SignalR.Internal; + +// Use a Channel instead of a SemaphoreSlim so that we can potentially save task allocations (ValueTask!) +// Additionally initial perf results show faster RPS when using Channel instead of SemaphoreSlim +internal class ChannelBasedSemaphore +{ + internal readonly Channel _channel; + + public ChannelBasedSemaphore(int maxCapacity) + { + _channel = Channel.CreateBounded(maxCapacity); + for (var i = 0; i < maxCapacity; i++) + { + _channel.Writer.TryWrite(1); + } + } + + public bool AttemptWait() + { + return _channel.Reader.TryRead(out _); + } + + public ValueTask WaitAsync(CancellationToken cancellationToken = default) + { + return _channel.Reader.ReadAsync(cancellationToken); + } + + public ValueTask ReleaseAsync(CancellationToken cancellationToken = default) + { + return _channel.Writer.WriteAsync(1, cancellationToken); + } +} diff --git a/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs b/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs index 16e3797ae425..e96a2c367388 100644 --- a/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs +++ b/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs @@ -1,30 +1,28 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Threading.Channels; - namespace Microsoft.AspNetCore.SignalR.Internal; internal static class ChannelExtensions { - public static ValueTask RunAsync(this Channel semaphoreSlim, Func callback, TState state) + public static ValueTask RunAsync(this ChannelBasedSemaphore channelSemaphore, Func callback, TState state) { - if (semaphoreSlim.Reader.TryRead(out _)) + if (channelSemaphore.AttemptWait()) { - _ = RunTask(callback, semaphoreSlim, state); + _ = RunTask(callback, channelSemaphore, state); return ValueTask.CompletedTask; } - return RunSlowAsync(semaphoreSlim, callback, state); + return RunSlowAsync(channelSemaphore, callback, state); } - private static async ValueTask RunSlowAsync(this Channel semaphoreSlim, Func callback, TState state) + private static async ValueTask RunSlowAsync(this ChannelBasedSemaphore channelSemaphore, Func callback, TState state) { - _ = await semaphoreSlim.Reader.ReadAsync(); - _ = RunTask(callback, semaphoreSlim, state); + _ = await channelSemaphore.WaitAsync(); + _ = RunTask(callback, channelSemaphore, state); } - static async Task RunTask(Func callback, Channel semaphoreSlim, TState state) + static async Task RunTask(Func callback, ChannelBasedSemaphore channelSemaphore, TState state) { try { @@ -32,7 +30,7 @@ static async Task RunTask(Func callback, Channel sema } finally { - await semaphoreSlim.Writer.WriteAsync(1); + await channelSemaphore.ReleaseAsync(); } } } diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index 3f208843f0ff..8571965b5776 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -1,8 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Threading.Channels; - namespace Microsoft.AspNetCore.SignalR.Internal; internal sealed class HubCallerClients : IHubCallerClients @@ -10,14 +8,14 @@ internal sealed class HubCallerClients : IHubCallerClients private readonly string _connectionId; private readonly IHubClients _hubClients; private readonly string[] _currentConnectionId; - private readonly Channel _parallelInvokes; + private readonly ChannelBasedSemaphore _parallelInvokes; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called // so we can prevent users from making blocking client calls by returning a custom ISingleClientProxy instance internal bool InvokeAllowed { get; set; } - public HubCallerClients(IHubClients hubClients, string connectionId, Channel parallelInvokes) + public HubCallerClients(IHubClients hubClients, string connectionId, ChannelBasedSemaphore parallelInvokes) { _connectionId = connectionId; _hubClients = hubClients; @@ -115,9 +113,9 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance private sealed class SingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; - private readonly Channel _parallelInvokes; + private readonly ChannelBasedSemaphore _parallelInvokes; - public SingleClientProxy(ISingleClientProxy hubClients, Channel parallelInvokes) + public SingleClientProxy(ISingleClientProxy hubClients, ChannelBasedSemaphore parallelInvokes) { _proxy = hubClients; _parallelInvokes = parallelInvokes; @@ -127,7 +125,7 @@ public async Task InvokeCoreAsync(string method, object?[] args, Cancellat { // Releases the Channel that is blocking pending invokes, which in turn can block the receive loop. // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever - await _parallelInvokes.Writer.WriteAsync(1, cancellationToken); + await _parallelInvokes.ReleaseAsync(cancellationToken); try { var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); @@ -135,8 +133,8 @@ public async Task InvokeCoreAsync(string method, object?[] args, Cancellat } finally { - // Re-read from the channel, this is because when the hub method completes it will release (write an entry) which we already did above, so we need to reset the state - _ = await _parallelInvokes.Reader.ReadAsync(CancellationToken.None); + // Re-wait, this is because when the hub method completes it will release which we already did above, so we need to reset the state + _ = await _parallelInvokes.WaitAsync(CancellationToken.None); } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index 313d3624721b..c1f7f96cc358 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -258,19 +258,18 @@ public async Task ClientResultFromHubDoesNotBlockReceiveLoop() Assert.NotNull(invocationMessage3.InvocationId); var res = 4 + ((long)invocationMessage.Arguments[0]); await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); - res = 5 + ((long)invocationMessage2.Arguments[0]); - await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage2.InvocationId, res)).DefaultTimeout(); - res = 6 + ((long)invocationMessage3.Arguments[0]); - await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage3.InvocationId, res)).DefaultTimeout(); - var completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); Assert.Equal(9L, completion.Result); Assert.Equal(invocationId, completion.InvocationId); + res = 5 + ((long)invocationMessage2.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage2.InvocationId, res)).DefaultTimeout(); completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); Assert.Equal(10L, completion.Result); Assert.Equal(invocationId2, completion.InvocationId); + res = 6 + ((long)invocationMessage3.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage3.InvocationId, res)).DefaultTimeout(); completion = Assert.IsType(await client.ReadAsync().DefaultTimeout()); Assert.Equal(11L, completion.Result); Assert.Equal(invocationId3, completion.InvocationId); From 863cd4832140ed2e594cc2183fa9aac5bd6bec6d Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Fri, 12 Aug 2022 10:12:32 -0700 Subject: [PATCH 5/6] sealed --- src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs index 4ec8cf899d6b..2929bbde96c7 100644 --- a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs +++ b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs @@ -7,7 +7,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal; // Use a Channel instead of a SemaphoreSlim so that we can potentially save task allocations (ValueTask!) // Additionally initial perf results show faster RPS when using Channel instead of SemaphoreSlim -internal class ChannelBasedSemaphore +internal sealed class ChannelBasedSemaphore { internal readonly Channel _channel; From 8c0b9b3fbcb795a64e2532817d3eb7102c3a9e5e Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Thu, 18 Aug 2022 20:13:09 -0700 Subject: [PATCH 6/6] crazy --- .../server/Core/src/HubConnectionContext.cs | 1 - .../server/Core/src/HubConnectionHandler.cs | 4 +- .../src/Internal/ChannelBasedSemaphore.cs | 48 +++++++++++++++++-- .../Core/src/Internal/ChannelExtensions.cs | 36 -------------- .../Core/src/Internal/DefaultHubDispatcher.cs | 26 +++++----- .../Core/src/Internal/HubCallerClients.cs | 35 +++++++------- .../HubConnectionHandlerTestUtils/Hubs.cs | 18 +++++++ .../HubConnectionHandlerTests.ClientResult.cs | 40 ++++++++++++++++ 8 files changed, 135 insertions(+), 73 deletions(-) delete mode 100644 src/SignalR/server/Core/src/Internal/ChannelExtensions.cs diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 6b221be915cb..5ac7769617ce 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -98,7 +98,6 @@ internal StreamTracker StreamTracker } internal HubCallerContext HubCallerContext { get; } - internal HubCallerClients HubCallerClients { get; set; } = null!; internal Exception? CloseException { get; private set; } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 28d4e61e1ca7..3269d6f7afcd 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -200,8 +200,8 @@ private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Excep // Ensure the connection is aborted before firing disconnect await connection.AbortAsync(); - // If a client result is requested in OnDisconnectedAsync we want to make sure it isn't blocked by the ActiveInvocationLimit - _ = connection.ActiveInvocationLimit.AttemptWait(); + // If a client result is requested in OnDisconnectedAsync we want to avoid the SemaphoreFullException and get the better connection disconnected IOException + _ = connection.ActiveInvocationLimit.TryAcquire(); try { diff --git a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs index 2929bbde96c7..8b6bbbe0ec6f 100644 --- a/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs +++ b/src/SignalR/server/Core/src/Internal/ChannelBasedSemaphore.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR.Internal; @@ -9,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal; // Additionally initial perf results show faster RPS when using Channel instead of SemaphoreSlim internal sealed class ChannelBasedSemaphore { - internal readonly Channel _channel; + private readonly Channel _channel; public ChannelBasedSemaphore(int maxCapacity) { @@ -20,18 +21,57 @@ public ChannelBasedSemaphore(int maxCapacity) } } - public bool AttemptWait() + public bool TryAcquire() { return _channel.Reader.TryRead(out _); } + // The int result isn't important, only reason it's exposed is because ValueTask doesn't implement ValueTask so we can't cast like we could with Task to Task public ValueTask WaitAsync(CancellationToken cancellationToken = default) { return _channel.Reader.ReadAsync(cancellationToken); } - public ValueTask ReleaseAsync(CancellationToken cancellationToken = default) + public void Release() { - return _channel.Writer.WriteAsync(1, cancellationToken); + if (!_channel.Writer.TryWrite(1)) + { + throw new SemaphoreFullException(); + } + } + + public ValueTask RunAsync(Func> callback, TState state) + { + if (TryAcquire()) + { + _ = RunTask(callback, state); + return ValueTask.CompletedTask; + } + + return RunSlowAsync(callback, state); + } + + private async ValueTask RunSlowAsync(Func> callback, TState state) + { + _ = await WaitAsync(); + _ = RunTask(callback, state); + } + + private async Task RunTask(Func> callback, TState state) + { + try + { + var shouldRelease = await callback(state); + if (shouldRelease) + { + Release(); + } + } + catch + { + // DefaultHubDispatcher catches and handles exceptions + // It does write to the connection in exception cases which also can't throw because we catch and log in HubConnectionContext + Debug.Assert(false); + } } } diff --git a/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs b/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs deleted file mode 100644 index e96a2c367388..000000000000 --- a/src/SignalR/server/Core/src/Internal/ChannelExtensions.cs +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -namespace Microsoft.AspNetCore.SignalR.Internal; - -internal static class ChannelExtensions -{ - public static ValueTask RunAsync(this ChannelBasedSemaphore channelSemaphore, Func callback, TState state) - { - if (channelSemaphore.AttemptWait()) - { - _ = RunTask(callback, channelSemaphore, state); - return ValueTask.CompletedTask; - } - - return RunSlowAsync(channelSemaphore, callback, state); - } - - private static async ValueTask RunSlowAsync(this ChannelBasedSemaphore channelSemaphore, Func callback, TState state) - { - _ = await channelSemaphore.WaitAsync(); - _ = RunTask(callback, channelSemaphore, state); - } - - static async Task RunTask(Func callback, ChannelBasedSemaphore channelSemaphore, TState state) - { - try - { - await callback(state); - } - finally - { - await channelSemaphore.ReleaseAsync(); - } - } -} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 1ebc547a18a3..30c06e594c7c 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -73,13 +73,13 @@ public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContex public override async Task OnConnectedAsync(HubConnectionContext connection) { await using var scope = _serviceScopeFactory.CreateAsyncScope(); - connection.HubCallerClients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit); var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); try { - InitializeHub(hub, connection); + // OnConnectedAsync won't work with client results (ISingleClientProxy.InvokeAsync) + InitializeHub(hub, connection, invokeAllowed: false); if (_onConnectedMiddleware != null) { @@ -90,9 +90,6 @@ public override async Task OnConnectedAsync(HubConnectionContext connection) { await hub.OnConnectedAsync(); } - - // OnConnectedAsync is finished, allow hub methods to use client results (ISingleClientProxy.InvokeAsync) - connection.HubCallerClients.InvokeAllowed = true; } finally { @@ -271,11 +268,12 @@ private Task ProcessInvocation(HubConnectionContext connection, } } - private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, + private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall) { var methodExecutor = descriptor.MethodExecutor; + var wasSemaphoreReleased = false; var disposeScope = true; var scope = _serviceScopeFactory.CreateAsyncScope(); IHubActivator? hubActivator = null; @@ -290,12 +288,12 @@ private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext c Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized"); - return; + return true; } if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection)) { - return; + return true; } try @@ -308,7 +306,7 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, Log.InvalidHubParameters(_logger, hubMethodInvocationMessage.Target, ex); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); - return; + return true; } InitializeHub(hub, connection); @@ -404,9 +402,15 @@ await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, { if (disposeScope) { + if (hub?.Clients is HubCallerClients hubCallerClients) + { + wasSemaphoreReleased = Interlocked.CompareExchange(ref hubCallerClients.ShouldReleaseSemaphore, 0, 1) == 0; + } await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); } } + + return !wasSemaphoreReleased; } private static ValueTask CleanupInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMessage, IHubActivator? hubActivator, @@ -553,9 +557,9 @@ private static async Task SendInvocationError(string? invocationId, await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage)); } - private void InitializeHub(THub hub, HubConnectionContext connection) + private void InitializeHub(THub hub, HubConnectionContext connection, bool invokeAllowed = true) { - hub.Clients = connection.HubCallerClients; + hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId, connection.ActiveInvocationLimit) { InvokeAllowed = invokeAllowed }; hub.Context = connection.HubCallerContext; hub.Groups = _hubContext.Groups; } diff --git a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs index 8571965b5776..8e6ec0fa0dc9 100644 --- a/src/SignalR/server/Core/src/Internal/HubCallerClients.cs +++ b/src/SignalR/server/Core/src/Internal/HubCallerClients.cs @@ -7,8 +7,9 @@ internal sealed class HubCallerClients : IHubCallerClients { private readonly string _connectionId; private readonly IHubClients _hubClients; - private readonly string[] _currentConnectionId; - private readonly ChannelBasedSemaphore _parallelInvokes; + internal readonly ChannelBasedSemaphore _parallelInvokes; + + internal int ShouldReleaseSemaphore = 1; // Client results don't work in OnConnectedAsync // This property is set by the hub dispatcher when those methods are being called @@ -19,7 +20,6 @@ public HubCallerClients(IHubClients hubClients, string connectionId, ChannelBase { _connectionId = connectionId; _hubClients = hubClients; - _currentConnectionId = new[] { _connectionId }; _parallelInvokes = parallelInvokes; } @@ -32,11 +32,11 @@ public ISingleClientProxy Caller { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return new SingleClientProxy(_hubClients.Client(_connectionId), _parallelInvokes); + return new SingleClientProxy(_hubClients.Client(_connectionId), this); } } - public IClientProxy Others => _hubClients.AllExcept(_currentConnectionId); + public IClientProxy Others => _hubClients.AllExcept(new[] { _connectionId }); public IClientProxy All => _hubClients.All; @@ -52,7 +52,7 @@ public ISingleClientProxy Client(string connectionId) { return new NoInvokeSingleClientProxy(_hubClients.Client(_connectionId)); } - return new SingleClientProxy(_hubClients.Client(connectionId), _parallelInvokes); + return new SingleClientProxy(_hubClients.Client(connectionId), this); } public IClientProxy Group(string groupName) @@ -67,7 +67,7 @@ public IClientProxy Groups(IReadOnlyList groupNames) public IClientProxy OthersInGroup(string groupName) { - return _hubClients.GroupExcept(groupName, _currentConnectionId); + return _hubClients.GroupExcept(groupName, new[] { _connectionId }); } public IClientProxy GroupExcept(string groupName, IReadOnlyList excludedConnectionIds) @@ -113,29 +113,26 @@ public Task SendCoreAsync(string method, object?[] args, CancellationToken cance private sealed class SingleClientProxy : ISingleClientProxy { private readonly ISingleClientProxy _proxy; - private readonly ChannelBasedSemaphore _parallelInvokes; + private readonly HubCallerClients _hubCallerClients; - public SingleClientProxy(ISingleClientProxy hubClients, ChannelBasedSemaphore parallelInvokes) + public SingleClientProxy(ISingleClientProxy hubClients, HubCallerClients hubCallerClients) { _proxy = hubClients; - _parallelInvokes = parallelInvokes; + _hubCallerClients = hubCallerClients; } public async Task InvokeCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) { // Releases the Channel that is blocking pending invokes, which in turn can block the receive loop. // Because we are waiting for a result from the client we need to let the receive loop run otherwise we'll be blocked forever - await _parallelInvokes.ReleaseAsync(cancellationToken); - try - { - var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); - return result; - } - finally + var value = Interlocked.CompareExchange(ref _hubCallerClients.ShouldReleaseSemaphore, 0, 1); + // Only release once, and we set ShouldReleaseSemaphore to 0 so the DefaultHubDispatcher knows not to call Release again + if (value == 1) { - // Re-wait, this is because when the hub method completes it will release which we already did above, so we need to reset the state - _ = await _parallelInvokes.WaitAsync(CancellationToken.None); + _hubCallerClients._parallelInvokes.Release(); } + var result = await _proxy.InvokeCoreAsync(method, args, cancellationToken); + return result; } public Task SendCoreAsync(string method, object?[] args, CancellationToken cancellationToken = default) diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 0b659acb416e..dc4ad919292e 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -338,6 +338,24 @@ public async Task GetClientResult(int num) var sum = await Clients.Caller.InvokeAsync("Sum", num, cancellationToken: default); return sum; } + + public void BackgroundClientResult(TcsService tcsService) + { + var caller = Clients.Caller; + _ = Task.Run(async () => + { + try + { + await tcsService.StartedMethod.Task; + var result = await caller.InvokeAsync("GetResult", 1, CancellationToken.None); + tcsService.EndMethod.SetResult(result); + } + catch (Exception ex) + { + tcsService.EndMethod.SetException(ex); + } + }); + } } internal class SelfRef diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs index c1f7f96cc358..1320c674d622 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.ClientResult.cs @@ -277,6 +277,46 @@ public async Task ClientResultFromHubDoesNotBlockReceiveLoop() } } + [Fact] + public async Task ClientResultFromBackgroundThreadInHubMethodWorks() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).DefaultTimeout(); + + var completionMessage = await client.InvokeAsync(nameof(MethodHub.BackgroundClientResult)).DefaultTimeout(); + + tcsService.StartedMethod.SetResult(null); + + var task = await Task.WhenAny(tcsService.EndMethod.Task, client.ReadAsync()).DefaultTimeout(); + if (task == tcsService.EndMethod.Task) + { + await tcsService.EndMethod.Task; + } + // Hub asks client for a result, this is an invocation message with an ID + var invocationMessage = Assert.IsType(await (Task)task); + Assert.NotNull(invocationMessage.InvocationId); + var res = 4 + ((long)invocationMessage.Arguments[0]); + await client.SendHubMessageAsync(CompletionMessage.WithResult(invocationMessage.InvocationId, res)).DefaultTimeout(); + + Assert.Equal(5, await tcsService.EndMethod.Task.DefaultTimeout()); + + // Make sure we can still do a Hub invocation and that the semaphore state didn't get messed up + completionMessage = await client.InvokeAsync(nameof(MethodHub.ValueMethod)).DefaultTimeout(); + Assert.Equal(43L, completionMessage.Result); + } + } + } + private class TestBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName)