From f8643634e0ce9b7639110df60bc0d6081de66e5b Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Thu, 26 Mar 2020 13:57:36 -0700 Subject: [PATCH 1/2] Pass CancellationToken to WaitAsync in client --- .../csharp/Client.Core/src/HubConnection.cs | 26 ++--- ...HttpConnectionTests.ConnectionLifecycle.cs | 25 ++++- .../HubConnectionTests.ConnectionLifecycle.cs | 2 +- .../test/UnitTests/HubConnectionTests.cs | 94 ++++++++++++++++++- .../src/HttpConnection.cs | 2 +- 5 files changed, 133 insertions(+), 16 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 8738545ffbde..22311e43df10 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -235,7 +235,7 @@ public async Task StartAsync(CancellationToken cancellationToken = default) private async Task StartAsyncInner(CancellationToken cancellationToken = default) { - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: cancellationToken); try { if (!_state.TryChangeState(HubConnectionState.Disconnected, HubConnectionState.Connecting)) @@ -601,7 +601,7 @@ async Task OnStreamCanceled(InvocationRequest irq) var readers = default(Dictionary); CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(nameof(StreamAsChannelCoreAsync)); + var connectionState = await _state.WaitForActiveConnectionAsync(nameof(StreamAsChannelCoreAsync), token: cancellationToken); ChannelReader channel; try @@ -704,7 +704,7 @@ async Task ReadChannelStream(CancellationTokenSource tokenSource) { while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item)) { - await SendWithLock(connectionState, new StreamItemMessage(streamId, item)); + await SendWithLock(connectionState, new StreamItemMessage(streamId, item), tokenSource.Token); Log.SendingStreamItem(_logger, streamId); } } @@ -722,7 +722,7 @@ async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource) await foreach (var streamValue in streamValues) { - await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue)); + await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue), tokenSource.Token); Log.SendingStreamItem(_logger, streamId); } } @@ -750,7 +750,9 @@ private async Task CommonStreaming(ConnectionState connectionState, string strea Log.CompletingStream(_logger, streamId); - await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cts.Token); + // Don't use cancellation token here + // this is triggered by a cancellation token to tell the server that the client is done streaming + await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError)); } private async Task InvokeCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) @@ -758,7 +760,7 @@ private async Task InvokeCoreAsyncCore(string methodName, Type returnTyp var readers = default(Dictionary); CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(nameof(InvokeCoreAsync)); + var connectionState = await _state.WaitForActiveConnectionAsync(nameof(InvokeCoreAsync), token: cancellationToken); Task invocationTask; try @@ -853,7 +855,7 @@ private async Task SendCoreAsyncCore(string methodName, object[] args, Cancellat var readers = default(Dictionary); CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(nameof(SendCoreAsync)); + var connectionState = await _state.WaitForActiveConnectionAsync(nameof(SendCoreAsync), token: cancellationToken); try { CheckDisposed(); @@ -875,7 +877,7 @@ private async Task SendCoreAsyncCore(string methodName, object[] args, Cancellat private async Task SendWithLock(ConnectionState expectedConnectionState, HubMessage message, CancellationToken cancellationToken = default, [CallerMemberName] string callerName = "") { CheckDisposed(); - var connectionState = await _state.WaitForActiveConnectionAsync(callerName); + var connectionState = await _state.WaitForActiveConnectionAsync(callerName, token: cancellationToken); try { CheckDisposed(); @@ -1954,10 +1956,10 @@ public void AssertConnectionValid([CallerMemberName] string memberName = null, [ SafeAssert(CurrentConnectionStateUnsynchronized != null, "We don't have a connection!", memberName, fileName, lineNumber); } - public Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + public Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0, CancellationToken token = default) { Log.WaitingOnConnectionLock(_logger, memberName, filePath, lineNumber); - return _connectionLock.WaitAsync(); + return _connectionLock.WaitAsync(token); } public bool TryAcquireConnectionLock() @@ -1966,9 +1968,9 @@ public bool TryAcquireConnectionLock() } // Don't call this method in a try/finally that releases the lock since we're also potentially releasing the connection lock here. - public async Task WaitForActiveConnectionAsync(string methodName, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + public async Task WaitForActiveConnectionAsync(string methodName, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0, CancellationToken token = default) { - await WaitConnectionLockAsync(methodName); + await WaitConnectionLockAsync(methodName, token: token); if (CurrentConnectionStateUnsynchronized == null || CurrentConnectionStateUnsynchronized.Stopping) { diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs index fa95fbc83b51..3b1617178780 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs @@ -443,7 +443,7 @@ await WithConnectionAsync( })), async (connection) => { - // We aggregate failures that happen when we start the transport. The operation cancelled exception will + // We aggregate failures that happen when we start the transport. The operation canceled exception will // be an inner exception. var ex = await Assert.ThrowsAsync(async () => await connection.StartAsync(cts.Token)).OrTimeout(); Assert.Equal(3, ex.InnerExceptions.Count); @@ -454,6 +454,29 @@ await WithConnectionAsync( } } + [Fact] + public async Task CanceledCancellationTokenPassedToStartThrows() + { + using (StartVerifiableLog()) + { + bool transportStartCalled = false; + var httpHandler = new TestHttpMessageHandler(); + + await WithConnectionAsync( + CreateConnection(httpHandler, + transport: new TestTransport(onTransportStart: () => { + transportStartCalled = true; + return Task.CompletedTask; + })), + async (connection) => + { + await Assert.ThrowsAsync(async () => await connection.StartAsync(new CancellationToken(canceled: true))).OrTimeout(); + }); + + Assert.False(transportStartCalled); + } + } + [Fact] public async Task SSECanBeCanceled() { diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs index f1d191ee8ca8..77f380602e6d 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs @@ -541,7 +541,7 @@ public async Task StartAsyncWithTriggeredCancellationTokenIsCanceled() var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); try { - await Assert.ThrowsAsync(() => hubConnection.StartAsync(new CancellationToken(canceled: true))).OrTimeout(); + await Assert.ThrowsAsync(() => hubConnection.StartAsync(new CancellationToken(canceled: true))).OrTimeout(); Assert.False(onStartCalled); } finally diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs index cbaec6f40d1c..32df3ab06230 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs @@ -118,6 +118,98 @@ bool ExpectedErrors(WriteContext writeContext) } } + [Fact] + public async Task PendingInvocationsAreCanceledWhenTokenTriggered() + { + using (StartVerifiableLog()) + { + var hubConnection = CreateHubConnection(new TestConnection(), loggerFactory: LoggerFactory); + + await hubConnection.StartAsync().OrTimeout(); + var cts = new CancellationTokenSource(); + var invokeTask = hubConnection.InvokeAsync("testMethod", cancellationToken: cts.Token).OrTimeout(); + cts.Cancel(); + + await Assert.ThrowsAsync(async () => await invokeTask); + } + } + + [Fact] + public async Task InvokeAsyncCanceledWhenPassedCanceledToken() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + + await hubConnection.StartAsync().OrTimeout(); + await Assert.ThrowsAsync(() => + hubConnection.InvokeAsync("testMethod", cancellationToken: new CancellationToken(canceled: true)).OrTimeout()); + + await hubConnection.StopAsync().OrTimeout(); + + // Assert that InvokeAsync didn't send a message + Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout()); + } + } + + [Fact] + public async Task SendAsyncCanceledWhenPassedCanceledToken() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + + await hubConnection.StartAsync().OrTimeout(); + await Assert.ThrowsAsync(() => + hubConnection.SendAsync("testMethod", cancellationToken: new CancellationToken(canceled: true)).OrTimeout()); + + await hubConnection.StopAsync().OrTimeout(); + + // Assert that SendAsync didn't send a message + Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout()); + } + } + + [Fact] + public async Task StreamAsChannelAsyncCanceledWhenPassedCanceledToken() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + + await hubConnection.StartAsync().OrTimeout(); + await Assert.ThrowsAsync(() => + hubConnection.StreamAsChannelAsync("testMethod", cancellationToken: new CancellationToken(canceled: true)).OrTimeout()); + + await hubConnection.StopAsync().OrTimeout(); + + // Assert that StreamAsChannelAsync didn't send a message + Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout()); + } + } + + [Fact] + public async Task StreamAsyncCanceledWhenPassedCanceledToken() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + + await hubConnection.StartAsync().OrTimeout(); + var result = hubConnection.StreamAsync("testMethod", cancellationToken: new CancellationToken(canceled: true)); + await Assert.ThrowsAsync(() => result.GetAsyncEnumerator().MoveNextAsync().OrTimeout()); + + await hubConnection.StopAsync().OrTimeout(); + + // Assert that StreamAsync didn't send a message + Assert.Null(await connection.ReadSentTextMessageAsync().OrTimeout()); + } + } + [Fact] public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages() { @@ -318,7 +410,7 @@ await connection.ReceiveJsonMessage( [Fact] [LogLevel(LogLevel.Trace)] - public async Task UploadStreamCancelationSendsStreamComplete() + public async Task UploadStreamCancellationSendsStreamComplete() { using (StartVerifiableLog()) { diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs index acfdce101045..1c2daa9cdfab 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/HttpConnection.cs @@ -205,7 +205,7 @@ private async Task StartAsyncCore(TransferFormat transferFormat, CancellationTok return; } - await _connectionLock.WaitAsync(); + await _connectionLock.WaitAsync(cancellationToken); try { CheckDisposed(); From a882539008eff2770af51297392cc9b7200fc98b Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Fri, 27 Mar 2020 16:14:47 -0700 Subject: [PATCH 2/2] fb --- .../csharp/Client.Core/src/HubConnection.cs | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 22311e43df10..70be20641e51 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -465,7 +465,7 @@ private async Task StopAsyncCore(bool disposing) // Potentially wait for StartAsync to finish, and block a new StartAsync from // starting until we've finished stopping. - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); // Ensure that ReconnectingState.ReconnectTask is not accessed outside of the lock. var reconnectTask = _state.ReconnectTask; @@ -478,7 +478,7 @@ private async Task StopAsyncCore(bool disposing) // The StopCts should prevent the HubConnection from restarting until it is reset. _state.ReleaseConnectionLock(); await reconnectTask; - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); } ConnectionState connectionState; @@ -574,7 +574,7 @@ private async Task> StreamAsChannelCoreAsyncCore(string me async Task OnStreamCanceled(InvocationRequest irq) { // We need to take the connection lock in order to ensure we a) have a connection and b) are the only one accessing the write end of the pipe. - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); try { if (_state.CurrentConnectionStateUnsynchronized != null) @@ -752,7 +752,7 @@ private async Task CommonStreaming(ConnectionState connectionState, string strea // Don't use cancellation token here // this is triggered by a cancellation token to tell the server that the client is done streaming - await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError)); + await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cancellationToken: default); } private async Task InvokeCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) @@ -874,7 +874,7 @@ private async Task SendCoreAsyncCore(string methodName, object[] args, Cancellat } } - private async Task SendWithLock(ConnectionState expectedConnectionState, HubMessage message, CancellationToken cancellationToken = default, [CallerMemberName] string callerName = "") + private async Task SendWithLock(ConnectionState expectedConnectionState, HubMessage message, CancellationToken cancellationToken, [CallerMemberName] string callerName = "") { CheckDisposed(); var connectionState = await _state.WaitForActiveConnectionAsync(callerName, token: cancellationToken); @@ -1248,7 +1248,7 @@ internal void OnServerTimeout() private async Task HandleConnectionClose(ConnectionState connectionState) { // Clear the connectionState field - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); try { SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState), @@ -1366,7 +1366,7 @@ private async Task ReconnectAsync(Exception closeException) { Log.ReconnectingStoppedDuringRetryDelay(_logger); - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); try { _state.ChangeState(HubConnectionState.Reconnecting, HubConnectionState.Disconnected); @@ -1381,7 +1381,7 @@ private async Task ReconnectAsync(Exception closeException) return; } - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); try { SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), @@ -1420,7 +1420,7 @@ private async Task ReconnectAsync(Exception closeException) nextRetryDelay = GetNextRetryDelay(previousReconnectAttempts++, DateTime.UtcNow - reconnectStartTime, retryReason); } - await _state.WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(token: default); try { SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), @@ -1956,7 +1956,7 @@ public void AssertConnectionValid([CallerMemberName] string memberName = null, [ SafeAssert(CurrentConnectionStateUnsynchronized != null, "We don't have a connection!", memberName, fileName, lineNumber); } - public Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0, CancellationToken token = default) + public Task WaitConnectionLockAsync(CancellationToken token, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) { Log.WaitingOnConnectionLock(_logger, memberName, filePath, lineNumber); return _connectionLock.WaitAsync(token); @@ -1968,9 +1968,9 @@ public bool TryAcquireConnectionLock() } // Don't call this method in a try/finally that releases the lock since we're also potentially releasing the connection lock here. - public async Task WaitForActiveConnectionAsync(string methodName, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0, CancellationToken token = default) + public async Task WaitForActiveConnectionAsync(string methodName, CancellationToken token, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) { - await WaitConnectionLockAsync(methodName, token: token); + await WaitConnectionLockAsync(token, methodName); if (CurrentConnectionStateUnsynchronized == null || CurrentConnectionStateUnsynchronized.Stopping) {