diff --git a/src/Directory.Packages.props b/src/Directory.Packages.props index cfb7f85797..768981a4f2 100644 --- a/src/Directory.Packages.props +++ b/src/Directory.Packages.props @@ -35,6 +35,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 950d7f83b1..cab7bfd88c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -783,6 +783,9 @@ Microsoft\Data\SqlClient\SSPI\SspiAuthenticationParameters.cs + + Microsoft\Data\SqlClient\Utilities\AsyncHelper.cs + Microsoft\Data\SqlClient\Utilities\ObjectPool.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 104b9261e7..e89d931229 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -977,6 +977,9 @@ Microsoft\Data\SqlClient\TransactionRequest.cs + + Microsoft\Data\SqlClient\Utilities\AsyncHelper.cs + Microsoft\Data\SqlClient\Utilities\BufferWriterExtensions.netfx.cs diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 9784117a5e..bf972048ba 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -15,6 +15,7 @@ using System.Threading.Tasks; using System.Xml; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { @@ -2050,10 +2051,11 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok } else { - AsyncHelper.ContinueTaskWithState(writeTask, tcs, + AsyncHelper.ContinueTaskWithState( + taskToContinue: writeTask, + taskCompletionSource: tcs, state: tcs, - onSuccess: static (object state) => ((TaskCompletionSource)state).SetResult(null) - ); + onSuccess: static tcs2 => tcs2.SetResult(null)); } }, ctoken); // We do not need to propagate exception, etc, from reconnect task, we just need to wait for it to finish. return tcs.Task; @@ -2362,19 +2364,20 @@ private Task CopyColumnsAsync(int col, TaskCompletionSource source = nul private void CopyColumnsAsyncSetupContinuation(TaskCompletionSource source, Task task, int i) { AsyncHelper.ContinueTaskWithState( - task, - source, - state: this, - onSuccess: (object state) => + taskToContinue: task, + taskCompletionSource: source, + state1: this, + state2: Tuple.Create(source, i), + onSuccess: static (this2, parameters) => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - if (i + 1 < sqlBulkCopy._sortedColumnMappings.Count) + if (parameters.Item2 + 1 < this2._sortedColumnMappings.Count) { - sqlBulkCopy.CopyColumnsAsync(i + 1, source); //continue from the next column + // continue from the next column + this2.CopyColumnsAsync(parameters.Item2 + 1, parameters.Item1); } else { - source.SetResult(null); + parameters.Item1.SetResult(null); } }); } @@ -2505,18 +2508,17 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, Task readTask = ReadFromRowSourceAsync(cts); // Read the next row. Caution: more is only valid if the task returns null. Otherwise, we wait for Task.Result if (readTask != null) { - if (source == null) - { - source = new TaskCompletionSource(); - } + source ??= new TaskCompletionSource(); resultTask = source.Task; AsyncHelper.ContinueTaskWithState( - readTask, - source, + taskToContinue: readTask, + taskCompletionSource: source, state: this, - onSuccess: (object state) => ((SqlBulkCopy)state).CopyRowsAsync(i + 1, totalRows, cts, source)); - return resultTask; // Associated task will be completed when all rows are copied to server/exception/cancelled. + onSuccess: this2 => this2.CopyRowsAsync(i + 1, totalRows, cts, source)); + + // Associated task will be completed when all rows are copied to server/exception/cancelled. + return resultTask; } } else @@ -2524,34 +2526,35 @@ private Task CopyRowsAsync(int rowsSoFar, int totalRows, CancellationToken cts, source = source ?? new TaskCompletionSource(); resultTask = source.Task; - AsyncHelper.ContinueTaskWithState(task, source, this, - onSuccess: (object state) => + AsyncHelper.ContinueTaskWithState( + taskToContinue: task, + taskCompletionSource: source, + state: this, + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - sqlBulkCopy.CheckAndRaiseNotification(); // Check for notification now as the current row copy is done at this moment. + // Check for notification now as the current row copy is done at this moment. + this2.CheckAndRaiseNotification(); - Task readTask = sqlBulkCopy.ReadFromRowSourceAsync(cts); - if (readTask == null) + Task readTask = this2.ReadFromRowSourceAsync(cts); + if (readTask is null) { - sqlBulkCopy.CopyRowsAsync(i + 1, totalRows, cts, source); + this2.CopyRowsAsync(i + 1, totalRows, cts, source); } else { AsyncHelper.ContinueTaskWithState( - readTask, - source, - state: sqlBulkCopy, - onSuccess: (object state2) => ((SqlBulkCopy)state2).CopyRowsAsync(i + 1, totalRows, cts, source)); + taskToContinue: readTask, + taskCompletionSource: source, + state: this2, + onSuccess: this3 => this3.CopyRowsAsync(i + 1, totalRows, cts, source)); } }); return resultTask; } } - if (source != null) - { - source.TrySetResult(null); // This is set only on the last call of async copy. But may not be set if everything runs synchronously. - } + // This is set only on the last call of async copy. But may not be set if everything runs synchronously. + source?.TrySetResult(null); } catch (Exception ex) { @@ -2613,17 +2616,21 @@ private Task CopyBatchesAsync(BulkCopySimpleResultSet internalResults, string up } AsyncHelper.ContinueTaskWithState( - commandTask, - source, + taskToContinue: commandTask, + taskCompletionSource: source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinued(internalResults, updateBulkCommandText, cts, source); - if (continuedTask == null) + Task continuedTask = this2.CopyBatchesAsyncContinued( + internalResults, + updateBulkCommandText, + cts, + source); + + if (continuedTask is null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + this2.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }); return source.Task; @@ -2683,18 +2690,19 @@ private Task CopyBatchesAsyncContinued(BulkCopySimpleResultSet internalResults, task, source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - Task continuedTask = sqlBulkCopy.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); + Task continuedTask = this2.CopyBatchesAsyncContinuedOnSuccess(internalResults, updateBulkCommandText, cts, source); if (continuedTask == null) { // Continuation finished sync, recall into CopyBatchesAsync to continue - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + this2.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); } }, - onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false), - onCancellation: static (object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: true)); + onFailure: static (this2, _) => + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: false), + onCancellation: static this2 => + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: true)); return source.Task; } @@ -2745,24 +2753,24 @@ private Task CopyBatchesAsyncContinuedOnSuccess(BulkCopySimpleResultSet internal writeTask, source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; try { - sqlBulkCopy.RunParser(); - sqlBulkCopy.CommitTransaction(); + this2.RunParser(); + this2.CommitTransaction(); } catch (Exception) { - sqlBulkCopy.CopyBatchesAsyncContinuedOnError(cleanupParser: false); + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: false); throw; } // Always call back into CopyBatchesAsync - sqlBulkCopy.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); + this2.CopyBatchesAsync(internalResults, updateBulkCommandText, cts, source); }, - onFailure: static (Exception _, object state) => ((SqlBulkCopy)state).CopyBatchesAsyncContinuedOnError(cleanupParser: false)); + onFailure: static (this2, _) => + this2.CopyBatchesAsyncContinuedOnError(cleanupParser: false)); return source.Task; } } @@ -2859,24 +2867,21 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int if (task != null) { - if (source == null) - { - source = new TaskCompletionSource(); - } + source ??= new TaskCompletionSource(); AsyncHelper.ContinueTaskWithState( - task, - source, + taskToContinue: task, + taskCompletionSource: source, state: this, - onSuccess: (object state) => + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; + // @TODO: Split into oncancellation, onfailure, etc. // Bulk copy task is completed at this moment. if (task.IsCanceled) { - sqlBulkCopy._localColumnMappings = null; + this2._localColumnMappings = null; try { - sqlBulkCopy.CleanUpStateObject(); + this2.CleanUpStateObject(); } finally { @@ -2889,10 +2894,10 @@ private void WriteToServerInternalRestContinuedAsync(BulkCopySimpleResultSet int } else { - sqlBulkCopy._localColumnMappings = null; + this2._localColumnMappings = null; try { - sqlBulkCopy.CleanUpStateObject(isCancelRequested: false); + this2.CleanUpStateObject(isCancelRequested: false); } finally { @@ -3006,38 +3011,51 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio reconnectTask, cancellableReconnectTS, state: cancellableReconnectTS, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); // No need to cancel timer since SqlBulkCopy creates specific task source for reconnection. AsyncHelper.SetTimeoutExceptionWithState( - completion: cancellableReconnectTS, - timeout: BulkCopyTimeout, + taskCompletionSource: cancellableReconnectTS, + timeoutInSeconds: BulkCopyTimeout, state: _destinationTableName, - onFailure: static state => + onTimeout: static state => SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()), cancellationToken: CancellationToken.None ); AsyncHelper.ContinueTaskWithState( - task: cancellableReconnectTS.Task, - completion: source, + taskToContinue:cancellableReconnectTS.Task, + taskCompletionSource: source, state: regReconnectCancel, - onSuccess: (object state) => + onSuccess: regReconnectCancel2 => { - ((StrongBox)state).Value.Dispose(); - if (_parserLock != null) + regReconnectCancel2.Value.Dispose(); + + if (_parserLock is not null) { _parserLock.Release(); - _parserLock = null; + _parserLock = null; // @TODO: Can be omitted b/c we reassign it directly below } _parserLock = _connection.GetOpenTdsConnection()._parserLock; _parserLock.Wait(canReleaseFromAnyThread: true); WriteToServerInternalRestAsync(cts, source); }, - onFailure: static (_, state) => ((StrongBox)state).Value.Dispose(), - onCancellation: static state => ((StrongBox)state).Value.Dispose(), - exceptionConverter: ex => SQL.BulkLoadInvalidDestinationTable(_destinationTableName, ex) - ); + onFailure: (regReconnectCancel2, exception) => + { + regReconnectCancel2.Value.Dispose(); + + // Convert exception and set it on the source + // Note: This is safe because the helper will only try to set the + // exception and b/c it is already set will pass without setting + // to the original exception. + Exception convertedException = SQL.BulkLoadInvalidDestinationTable( + _destinationTableName, + exception); + source.TrySetException(convertedException); + }, + onCancellation: static regReconnectCancel2 => + regReconnectCancel2.Value.Dispose()); + return; } else @@ -3086,10 +3104,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio if (internalResultsTask != null) { AsyncHelper.ContinueTaskWithState( - internalResultsTask, - source, + taskToContinue: internalResultsTask, + taskCompletionSource: source, state: this, - onSuccess: (object state) => ((SqlBulkCopy)state).WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source)); + onSuccess: this2 => + this2.WriteToServerInternalRestContinuedAsync(internalResultsTask.Result, cts, source)); } else { @@ -3158,17 +3177,21 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) else { Debug.Assert(_isAsyncBulkCopy, "Read must not return a Task in the Sync mode"); - AsyncHelper.ContinueTaskWithState(readTask, source, this, - onSuccess: (object state) => + AsyncHelper.ContinueTaskWithState( + taskToContinue: readTask, + taskCompletionSource: source, + state: this, + onSuccess: this2 => { - SqlBulkCopy sqlBulkCopy = (SqlBulkCopy)state; - if (!sqlBulkCopy._hasMoreRowToCopy) + if (!this2._hasMoreRowToCopy) { - source.SetResult(null); // No rows to copy! + // No rows to copy! + source.SetResult(null); } else { - sqlBulkCopy.WriteToServerInternalRestAsync(ctoken, source); // Passing the same completion which will be completed by the Callee. + // Passing the same completion which will be completed by the Callee. + this2.WriteToServerInternalRestAsync(ctoken, source); } }); return resultTask; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs index 11fd3a316d..1dadde00cd 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs @@ -330,6 +330,18 @@ private string GetFormattedMessage(string className, string memberName, string e #region Trace #region Traces without if statements + + internal void TraceEvent(string message) + { + Trace(message); + } + + [NonEvent] + internal void TraceEvent(string message, T0 args0) + { + Trace(string.Format(message, args0?.ToString() ?? NullStr)); + } + [NonEvent] internal void TraceEvent(string message, T0 args0, T1 args1) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs index cf90d29632..b93f14c1d7 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Encryption.cs @@ -12,6 +12,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { @@ -250,24 +251,22 @@ private SqlDataReader GetParameterEncryptionDataReader( bool isRetry) { returnTask = AsyncHelper.CreateContinuationTaskWithState( - task: fetchInputParameterEncryptionInfoTask, + taskToContinue: fetchInputParameterEncryptionInfoTask, state: this, - onSuccess: state => + onSuccess: this2 => { - SqlCommand command = (SqlCommand)state; bool processFinallyBlockAsync = true; bool decrementAsyncCountInFinallyBlockAsync = true; try { // Check for any exceptions on network write, before reading. - command.CheckThrowSNIException(); + this2.CheckThrowSNIException(); // If it is async, then TryFetchInputParameterEncryptionInfo -> // RunExecuteReaderTds would have incremented the async count. Decrement it // when we are about to complete async execute reader. - SqlInternalConnectionTds internalConnectionTds = - command._activeConnection.GetOpenTdsConnection(); + SqlInternalConnectionTds internalConnectionTds = this2._activeConnection.GetOpenTdsConnection(); if (internalConnectionTds is not null) { internalConnectionTds.DecrementAsyncCount(); @@ -276,13 +275,13 @@ private SqlDataReader GetParameterEncryptionDataReader( // Complete executereader. // @TODO: If we can remove this reference, this could be a static lambda - describeParameterEncryptionDataReader = command.CompleteAsyncExecuteReader( + describeParameterEncryptionDataReader = this2.CompleteAsyncExecuteReader( isInternal: false, forDescribeParameterEncryption: true); - Debug.Assert(command._stateObj is null, "non-null state object in PrepareForTransparentEncryption."); + Debug.Assert(this2._stateObj is null, "non-null state object in PrepareForTransparentEncryption."); // Read the results of describe parameter encryption. - command.ReadDescribeEncryptionParameterResults( + this2.ReadDescribeEncryptionParameterResults( describeParameterEncryptionDataReader, describeParameterEncryptionRpcOriginalRpcMap, isRetry); @@ -302,7 +301,7 @@ private SqlDataReader GetParameterEncryptionDataReader( } finally { - command.PrepareTransparentEncryptionFinallyBlock( + this2.PrepareTransparentEncryptionFinallyBlock( closeDataReader: processFinallyBlockAsync, decrementAsyncCount: decrementAsyncCountInFinallyBlockAsync, clearDataStructures: processFinallyBlockAsync, @@ -311,10 +310,9 @@ private SqlDataReader GetParameterEncryptionDataReader( describeParameterEncryptionDataReader: describeParameterEncryptionDataReader); } }, - onFailure: static (exception, state) => + onFailure: static (this2, exception) => { - SqlCommand command = (SqlCommand)state; - command.CachedAsyncState?.ResetAsyncState(); + this2.CachedAsyncState?.ResetAsyncState(); if (exception is not null) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs index 28bd1cf2ef..eb27f488c1 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.NonQuery.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; @@ -219,14 +220,12 @@ private IAsyncResult BeginExecuteNonQueryInternal( if (execNonQuery is not null) { AsyncHelper.ContinueTaskWithState( - task: execNonQuery, - completion: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => - { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteNonQueryInternalReadStage(parameters.Item2); - }); + taskToContinue: execNonQuery, + taskCompletionSource: localCompletion, + state1: this, + state2: localCompletion, + onSuccess: static (this2, localCompletion2) => + this2.BeginExecuteNonQueryInternalReadStage(localCompletion2)); } else { @@ -871,8 +870,8 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( timeoutCts.Token); AsyncHelper.ContinueTask( - reconnectTask, - completion, + taskToContinue: reconnectTask, + taskCompletionSource: completion, onSuccess: () => { if (completion.Task.IsCompleted) @@ -896,10 +895,10 @@ private void RunExecuteNonQueryTdsSetupReconnnectContinuation( else { AsyncHelper.ContinueTaskWithState( - subTask, - completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); } }); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs index 332195efb5..b186306272 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Reader.cs @@ -10,10 +10,10 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; -using Microsoft.Data.SqlClient.Utilities; #endif namespace Microsoft.Data.SqlClient @@ -308,14 +308,12 @@ private IAsyncResult BeginExecuteReaderInternal( if (writeTask is not null) { AsyncHelper.ContinueTaskWithState( - writeTask, - localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => - { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteReaderInternalReadStage(parameters.Item2); - }); + taskToContinue: writeTask, + taskCompletionSource: localCompletion, + state1: this, + state2: localCompletion, + onSuccess: static (this2, localCompletion2) => + this2.BeginExecuteReaderInternalReadStage(localCompletion2)); } else { @@ -1605,21 +1603,19 @@ private Task RunExecuteReaderTdsSetupContinuation( string optionSettings, Task writeTask) { - // @TODO: Why use the state version if we can't make this a static helper? return AsyncHelper.CreateContinuationTaskWithState( - task: writeTask, - state: _activeConnection, - onSuccess: state => + taskToContinue: writeTask, + state1: this, + state2: Tuple.Create(ds, runBehavior, optionSettings), + onSuccess: static (this2, parameters) => { // This will throw if the connection is closed. // @TODO: So... can we have something that specifically does that? - ((SqlConnection)state).GetOpenTdsConnection(); - CachedAsyncState.SetAsyncReaderState(ds, runBehavior, optionSettings); + this2._activeConnection.GetOpenTdsConnection(); + this2.CachedAsyncState.SetAsyncReaderState(parameters.Item1, parameters.Item2, parameters.Item3); }, - onFailure: static (exception, state) => - { - ((SqlConnection)state).GetOpenTdsConnection().DecrementAsyncCount(); - }); + onFailure: static (this2, _, _) => + this2._activeConnection.GetOpenTdsConnection().DecrementAsyncCount()); } // @TODO: This is way too many parameters being shoveled back and forth. We can do better. @@ -1640,13 +1636,12 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( AsyncHelper.SetTimeoutException( completion, timeout, - onFailure: static () => SQL.CR_ReconnectTimeout(), + onTimeout: static () => SQL.CR_ReconnectTimeout(), timeoutCts.Token); - // @TODO: With an object to pass around we can use the state-based version AsyncHelper.ContinueTask( - reconnectTask, - completion, + taskToContinue: reconnectTask, + taskCompletionSource: completion, onSuccess: () => { if (completion.Task.IsCompleted) @@ -1675,10 +1670,10 @@ private void RunExecuteReaderTdsSetupReconnectContinuation( else { AsyncHelper.ContinueTaskWithState( - subTask, - completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static completion2 => completion2.SetResult(null)); } }); } @@ -1711,14 +1706,13 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( // @TODO: This is a prime candidate for proper async-await execution TaskCompletionSource completion = new TaskCompletionSource(); AsyncHelper.ContinueTaskWithState( - task: describeParameterEncryptionTask, - completion: completion, + taskToContinue: describeParameterEncryptionTask, + taskCompletionSource: completion, state: this, - onSuccess: state => + onSuccess: this2 => { - SqlCommand command = (SqlCommand)state; - command.GenerateEnclavePackage(); - command.RunExecuteReaderTds( + this2.GenerateEnclavePackage(); + this2.RunExecuteReaderTds( cmdBehavior, runBehavior, returnStream, @@ -1737,24 +1731,22 @@ private SqlDataReader RunExecuteReaderTdsWithTransparentParameterEncryption( else { AsyncHelper.ContinueTaskWithState( - task: subTask, - completion: completion, + taskToContinue: subTask, + taskCompletionSource: completion, state: completion, - onSuccess: static state => ((TaskCompletionSource)state).SetResult(null)); + onSuccess: static state => state.SetResult(null)); } }, - onFailure: static (exception, state) => + onFailure: static (this2, exception) => { - ((SqlCommand)state).CachedAsyncState?.ResetAsyncState(); + this2.CachedAsyncState?.ResetAsyncState(); if (exception is not null) { + // @TODO: This doesn't do anything, afaik. throw exception; } }, - onCancellation: static state => - { - ((SqlCommand)state).CachedAsyncState?.ResetAsyncState(); - }); + onCancellation: static this2 => this2.CachedAsyncState?.ResetAsyncState()); task = completion.Task; return ds; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs index 6a18b6b1b6..360e8f4af2 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlCommand.Xml.cs @@ -10,6 +10,7 @@ using System.Xml; using Microsoft.Data.Common; using Microsoft.Data.SqlClient.Server; +using Microsoft.Data.SqlClient.Utilities; #if NETFRAMEWORK using System.Security.Permissions; @@ -269,14 +270,12 @@ private IAsyncResult BeginExecuteXmlReaderInternal( if (writeTask is not null) { AsyncHelper.ContinueTaskWithState( - task: writeTask, - completion: localCompletion, - state: Tuple.Create(this, localCompletion), - onSuccess: static state => - { - var parameters = (Tuple>)state; - parameters.Item1.BeginExecuteXmlReaderInternalReadStage(parameters.Item2); - }); + taskToContinue: writeTask, + taskCompletionSource: localCompletion, + state1: this, + state2: localCompletion, + onSuccess: static (this2, localCompletion2) => + this2.BeginExecuteXmlReaderInternalReadStage(localCompletion2)); } else { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs index 4f9d9ca14f..f3780d2936 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -23,7 +23,8 @@ using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.ConnectionPool; using Microsoft.Data.SqlClient.Diagnostics; -using Microsoft.SqlServer.Server; +using Microsoft.Data.SqlClient.Utilities; + #if NETFRAMEWORK using System.Runtime.CompilerServices; using System.Security.Permissions; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 9f7858f2c9..6431a0d0d4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -19,6 +19,7 @@ using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.Connection; using Microsoft.Data.SqlClient.ConnectionPool; +using Microsoft.Data.SqlClient.Utilities; using Microsoft.Identity.Client; namespace Microsoft.Data.SqlClient diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs index fd814370b5..2027cd98b3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -49,250 +49,6 @@ internal static ArgumentOutOfRangeException InvalidMinAndMaxPair(string minParam => new ArgumentOutOfRangeException(minParamName, StringsHelper.GetString(Strings.SqlRetryLogic_InvalidMinMaxPair, minValue, maxValue, minParamName, maxParamName)); } - internal static class AsyncHelper - { - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - Action onFailure = null) - { - if (task == null) - { - onSuccess(); - return null; - } - else - { - TaskCompletionSource completion = new TaskCompletionSource(); - ContinueTaskWithState( - task, - completion, - state: Tuple.Create(onSuccess, onFailure, completion), - onSuccess: static (object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action success = parameters.Item1; - TaskCompletionSource taskCompletionSource = parameters.Item3; - success(); - taskCompletionSource.SetResult(null); - }, - onFailure: static (Exception exception, object state) => - { - var parameters = (Tuple, TaskCompletionSource>)state; - Action failure = parameters.Item2; - failure?.Invoke(exception); - } - ); - return completion.Task; - } - } - - internal static Task CreateContinuationTaskWithState(Task task, object state, Action onSuccess, Action onFailure = null) - { - if (task == null) - { - onSuccess(state); - return null; - } - else - { - var completion = new TaskCompletionSource(); - ContinueTaskWithState(task, completion, state, - onSuccess: (object continueState) => - { - onSuccess(continueState); - completion.SetResult(null); - }, - onFailure: onFailure - ); - return completion.Task; - } - } - - internal static Task CreateContinuationTask( - Task task, - Action onSuccess, - T1 arg1, - T2 arg2, - Action onFailure = null) - { - return CreateContinuationTask(task, () => onSuccess(arg1, arg2), onFailure); - } - - internal static void ContinueTask(Task task, - TaskCompletionSource completion, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - tsk => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - try - { - onFailure?.Invoke(exc); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, TaskScheduler.Default - ); - } - - // the same logic as ContinueTask but with an added state parameter to allow the caller to avoid the use of a closure - // the parameter allocation cannot be avoided here and using closure names is clearer than Tuple numbered properties - internal static void ContinueTaskWithState(Task task, - TaskCompletionSource completion, - object state, - Action onSuccess, - Action onFailure = null, - Action onCancellation = null, - Func exceptionConverter = null) - { - task.ContinueWith( - (Task tsk, object state2) => - { - if (tsk.Exception != null) - { - Exception exc = tsk.Exception.InnerException; - if (exceptionConverter != null) - { - exc = exceptionConverter(exc); - } - - try - { - onFailure?.Invoke(exc, state2); - } - finally - { - completion.TrySetException(exc); - } - } - else if (tsk.IsCanceled) - { - try - { - onCancellation?.Invoke(state2); - } - finally - { - completion.TrySetCanceled(); - } - } - else - { - try - { - onSuccess(state2); - } - // @TODO: CER Exception Handling was removed here (see GH#3581) - catch (Exception e) - { - completion.SetException(e); - } - } - }, - state: state, - scheduler: TaskScheduler.Default - ); - } - - internal static void WaitForCompletion(Task task, int timeout, Action onTimeout = null, bool rethrowExceptions = true) - { - try - { - task.Wait(timeout > 0 ? (1000 * timeout) : Timeout.Infinite); - } - catch (AggregateException ae) - { - if (rethrowExceptions) - { - Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); - ExceptionDispatchInfo.Capture(ae.InnerException).Throw(); - } - } - if (!task.IsCompleted) - { - task.ContinueWith(static t => { var ignored = t.Exception; }); //Ensure the task does not leave an unobserved exception - onTimeout?.Invoke(); - } - } - - internal static void SetTimeoutException(TaskCompletionSource completion, int timeout, Func onFailure, CancellationToken ctoken) - { - if (timeout > 0) - { - Task.Delay(timeout * 1000, ctoken).ContinueWith( - (Task task) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure()); - } - } - ); - } - } - - internal static void SetTimeoutExceptionWithState( - TaskCompletionSource completion, - int timeout, - object state, - Func onFailure, - CancellationToken cancellationToken) - { - if (timeout <= 0) - { - return; - } - - Task.Delay(timeout * 1000, cancellationToken).ContinueWith( - (task, innerState) => - { - if (!task.IsCanceled && !completion.Task.IsCompleted) - { - completion.TrySetException(onFailure(innerState)); - } - }, - state: state, - cancellationToken: CancellationToken.None); - } - } - internal static class SQL { // The class SQL defines the exceptions that are specific to the SQL Adapter. diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index 084cea2d0a..1da54fb70a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -1361,11 +1361,11 @@ internal void TdsLogin( int feOffset = length; // calculate and reserve the required bytes for the featureEx length = ApplyFeatureExData( - requestedFeatures, - recoverySessionData, + requestedFeatures, + recoverySessionData, fedAuthFeatureExtensionData, UserAgentInfo.UserAgentCachedJsonPayload.ToArray(), - useFeatureExt, + useFeatureExt, length ); @@ -10786,11 +10786,23 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet } // This is in its own method to avoid always allocating the lambda in TDSExecuteRPCParameter - private void TDSExecuteRPCParameterSetupWriteCompletion(SqlCommand cmd, IList<_SqlRPC> rpcArray, int timeout, bool inSchema, SqlNotificationRequest notificationRequest, TdsParserStateObject stateObj, bool isCommandProc, bool sync, TaskCompletionSource completion, int startRpc, int startParam, Task writeParamTask) + private void TDSExecuteRPCParameterSetupWriteCompletion( + SqlCommand cmd, + IList<_SqlRPC> rpcArray, + int timeout, + bool inSchema, + SqlNotificationRequest notificationRequest, + TdsParserStateObject stateObj, + bool isCommandProc, + bool sync, + TaskCompletionSource completion, + int startRpc, + int startParam, + Task writeParamTask) { AsyncHelper.ContinueTask( - writeParamTask, - completion, + taskToContinue: writeParamTask, + taskCompletionSource: completion, onSuccess: () => TdsExecuteRPC( cmd, rpcArray, @@ -10802,8 +10814,7 @@ private void TDSExecuteRPCParameterSetupWriteCompletion(SqlCommand cmd, IList<_S sync, completion, startRpc, - startParam - ), + startParam), onFailure: exc => TdsExecuteRPC_OnFailure(exc, stateObj)); } @@ -12240,11 +12251,11 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } else { - return AsyncHelper.CreateContinuationTask( - unterminatedWriteTask, - onSuccess: WriteInt, - arg1: 0, - arg2: stateObj); + return AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: unterminatedWriteTask, + state1: this, + state2: stateObj, + onSuccess: static (this2, stateObj2) => this2.WriteInt(0, stateObj2)); } } else @@ -13207,11 +13218,11 @@ private Task WriteEncryptionMetadata(Task terminatedWriteTask, SqlColumnEncrypti else { // Otherwise, create a continuation task to write the encryption metadata after the previous write completes. - return AsyncHelper.CreateContinuationTask( - terminatedWriteTask, - onSuccess: WriteEncryptionMetadata, - arg1: columnEncryptionParameterInfo, - arg2: stateObj); + return AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: terminatedWriteTask, + state1: columnEncryptionParameterInfo, + state2: stateObj, + onSuccess: WriteEncryptionMetadata); } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index 7d29ebc23f..5f61d6dd14 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -16,17 +16,10 @@ using Microsoft.Data.Common; using Microsoft.Data.ProviderBase; using Microsoft.Data.SqlClient.ManagedSni; - -#if NETFRAMEWORK -using System.Runtime.ConstrainedExecution; -#endif +using Microsoft.Data.SqlClient.Utilities; namespace Microsoft.Data.SqlClient { -#if NETFRAMEWORK - using RuntimeHelpers = System.Runtime.CompilerServices.RuntimeHelpers; -#endif - sealed internal class LastIOTimer { internal long _value; @@ -1277,13 +1270,12 @@ internal Task ExecuteFlush() else { return AsyncHelper.CreateContinuationTaskWithState( - task: writePacketTask, + taskToContinue: writePacketTask, state: this, - onSuccess: static (object state) => + onSuccess: static this2 => { - TdsParserStateObject stateObject = (TdsParserStateObject)state; - stateObject.HasPendingData = true; - stateObject._messageStatus = 0; + this2.HasPendingData = true; + this2._messageStatus = 0; } ); } @@ -3051,7 +3043,10 @@ internal Task WritePacket(byte flushMode, bool canAccumulate = false) if (willCancel) { // If we have been canceled, then ensure that we write the ATTN packet as well - task = AsyncHelper.CreateContinuationTask(task, CancelWritePacket); + task = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: task, + state: this, + onSuccess: static this2 => this2.CancelWritePacket()); } return task; @@ -4358,9 +4353,16 @@ private Task WriteBytes(ReadOnlySpan b, int len, int offsetBuffer, bool ca // This is in its own method to avoid always allocating the lambda in WriteBytes private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSource completion, int offset, Task packetTask) { - AsyncHelper.ContinueTask(packetTask, completion, - onSuccess: () => WriteBytes(ReadOnlySpan.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array) - ); + AsyncHelper.ContinueTask( + taskToContinue: packetTask, + taskCompletionSource: completion, + onSuccess: () => WriteBytes( + ReadOnlySpan.Empty, + len: len, + offsetBuffer: offset, + canAccumulate: false, + completion: completion, + array)); } /// diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs new file mode 100644 index 0000000000..05dcb0d017 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Utilities/AsyncHelper.cs @@ -0,0 +1,796 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics; +using System.Runtime.ExceptionServices; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace Microsoft.Data.SqlClient.Utilities +{ + /// + /// Provides helpers for interacting with asynchronous tasks. + /// + /// + /// These helpers mainly provide continuation and timeout functionality. They utilize + /// at their core, and as such are fairly antiquated + /// implementations. If possible these methods should be utilized less and async/await native + /// constructs should be used. + /// + internal static class AsyncHelper + { + /// + /// Continues a task and signals failure of the continuation via the provided + /// . + /// + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// helper will try to set an exception on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this current + /// continuation. + /// + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static void ContinueTask( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + ContinuationState continuationState = new ContinuationState( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (tsk, continuationState2) => + { + ContinuationState typedState = (ContinuationState)continuationState2!; + + if (tsk.Exception != null) + { + Exception innerException = tsk.Exception.InnerException ?? tsk.Exception; + try + { + typedState.OnFailure?.Invoke(innerException); + } + finally + { + typedState.TaskCompletionSource.TrySetException(innerException); + } + } + else if (tsk.IsCanceled) + { + try + { + typedState.OnCancellation?.Invoke(); + } + finally + { + typedState.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState.OnSuccess(); + } + // @TODO: CER Exception Handling was removed here (see GH#3581) + catch (Exception e) + { + typedState.TaskCompletionSource.TrySetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + } + + /// + /// Continues a task and signals failure of the continuation via the provided + /// . This overload provides a single state object + /// to the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// helper will try to set an exception on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this current + /// continuation. + /// + /// Type of the state object to provide to the callbacks + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// State object to provide to callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static void ContinueTaskWithState( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + TState state, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = (ContinuationState)continuationState2!; + + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState2.OnFailure?.Invoke(typedState2.State, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State); + // @TODO: The one unpleasant thing with this code is that the TCS is not set completed and left to the caller to do or not do (which is more unpleasant) + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + } + + /// + /// Continues a task and signals failure of the continuation via the provided + /// . This overload provides two state objects to + /// the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * Will try to set exception on + /// * Cancelled + /// * is called (if provided) + /// * Will try to set cancelled on + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// helper will try to set an exception on the . + /// * is *not* with result on success. This + /// is to allow the task completion source to be continued even more after this + /// current continuation. + /// + /// Task to continue with provided callbacks + /// + /// Completion source used to track completion of the continuation, see remarks for details + /// + /// Type of the first state object to provide to callbacks + /// Type of the second state object to provide to callbacks + /// First state object to provide to callbacks + /// Second state object to provide to callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static void ContinueTaskWithState( + Task taskToContinue, + TaskCompletionSource taskCompletionSource, + TState1 state1, + TState2 state2, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State1: state1, + State2: state2, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = + (ContinuationState)continuationState2!; + + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State1, typedState2.State2); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State1, typedState2.State2); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + } + + /// + /// Continues a task and returns the continuation task. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTask( + Task? taskToContinue, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + if (taskToContinue is null) + { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. + onSuccess(); + return null; + } + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState = (ContinuationState)continuationState2!; + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState.OnFailure?.Invoke(innerException); + } + finally + { + typedState.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) + { + try + { + typedState.OnCancellation?.Invoke(); + } + finally + { + typedState.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState.OnSuccess(); + typedState.TaskCompletionSource.TrySetResult(null); + } + catch (Exception e) + { + typedState.TaskCompletionSource.TrySetException(e); + } + } + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + + return taskCompletionSource.Task; + } + + /// + /// Continues a task and returns the continuation task. This overload allows a state object + /// to be passed into the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// Type of the state object to pass to callbacks + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// State object to pass to the callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTaskWithState( + Task? taskToContinue, + TState state, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + if (taskToContinue is null) + { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. + onSuccess(state); + return null; + } + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State: state, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = (ContinuationState)continuationState2!; + + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState2.OnFailure?.Invoke(typedState2.State, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State); + typedState2.TaskCompletionSource.TrySetResult(null); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + + return taskCompletionSource.Task; + } + + /// + /// Continues a task and returns the continuation task. This overload allows two state + /// objects to be passed into the callbacks. + /// + /// + /// When possible, use static lambdas for callbacks. + /// + /// If the completes in the following states, these + /// actions will be taken: + /// * With exception + /// * is called (if provided) + /// * The task will be completed with an exception. + /// * Cancelled + /// * is called (if provided) + /// * The task will be completed as cancelled. + /// * Successfully + /// * is called + /// * IF an exception is thrown during execution of , the + /// task will be completed with the exception. + /// * The task will be completed as successful. + /// + /// Type of the first state object to pass to callbacks + /// Type of the second state object to pass to callbacks + /// + /// Task to continue with provided callbacks, if null, null will be returned. + /// + /// First state object to pass to the callbacks + /// Second state object to pass to the callbacks + /// Callback to execute on successful completion of the task + /// Callback to execute on failure of the task (optional) + /// Callback to execute on cancellation of the task (optional) + internal static Task? CreateContinuationTaskWithState( + Task? taskToContinue, + TState1 state1, + TState2 state2, + Action onSuccess, + Action? onFailure = null, + Action? onCancellation = null) + { + if (taskToContinue is null) + { + // This is a remnant of ye olde async/sync code that return null tasks when + // executing synchronously. It's still desirable that the onSuccess executes + // regardless of whether the preceding action was synchronous or asynchronous. + onSuccess(state1, state2); + return null; + } + + // @TODO: Can totally use a non-generic TaskCompletionSource + TaskCompletionSource taskCompletionSource = new(); + ContinuationState continuationState = new( + OnCancellation: onCancellation, + OnFailure: onFailure, + OnSuccess: onSuccess, + State1: state1, + State2: state2, + TaskCompletionSource: taskCompletionSource); + + Task continuationTask = taskToContinue.ContinueWith( + static (task, continuationState2) => + { + ContinuationState typedState2 = + (ContinuationState)continuationState2!; + + if (task.Exception is not null) + { + Exception innerException = task.Exception.InnerException ?? task.Exception; + try + { + typedState2.OnFailure?.Invoke(typedState2.State1, typedState2.State2, innerException); + } + finally + { + typedState2.TaskCompletionSource.TrySetException(innerException); + } + } + else if (task.IsCanceled) + { + try + { + typedState2.OnCancellation?.Invoke(typedState2.State1, typedState2.State2); + } + finally + { + typedState2.TaskCompletionSource.TrySetCanceled(); + } + } + else + { + try + { + typedState2.OnSuccess(typedState2.State1, typedState2.State2); + typedState2.TaskCompletionSource.TrySetResult(null); + } + catch (Exception e) + { + typedState2.TaskCompletionSource.TrySetException(e); + } + } + + }, + state: continuationState, + scheduler: TaskScheduler.Default); + + // Explicitly follow up by observing any exception thrown during continuation + ObserveContinuationException(continuationTask); + + return taskCompletionSource.Task; + } + + /// + /// Executes a timeout task in parallel with the provided + /// . If the timeout completes before the task + /// completion source, the provided is executed and the + /// exception returned is set as the exception that completes the task completion source. + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// + /// Callback to execute when the task does not complete within the allotted time. The + /// exception returned by the callback is set on the . + /// + /// Cancellation token to prematurely cancel timeout + internal static void SetTimeoutException( + TaskCompletionSource taskCompletionSource, + int timeoutInSeconds, + Func onTimeout, + CancellationToken cancellationToken) + { + if (timeoutInSeconds <= 0) + { + return; + } + + Task.Delay(TimeSpan.FromSeconds(timeoutInSeconds), cancellationToken) + .ContinueWith( + task => + { + // If the timeout ran to completion AND the task to complete did not complete + // then the timeout expired first, run the timeout handler + if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) + { + taskCompletionSource.TrySetException(onTimeout()); + } + }, + cancellationToken: CancellationToken.None); + } + + /// + /// Executes a timeout task in parallel with the provided + /// . If the timeout completes before the task + /// completion source, the provided is executed and the + /// exception returned is set as the exception that completes the task completion source. + /// This overload provides a state object to the timeout callback. + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// State object to pass to the callback + /// + /// Callback to execute when the task does not complete within the allotted time. The + /// exception returned by the callback is set on the . + /// + /// Cancellation token to prematurely cancel timeout + internal static void SetTimeoutExceptionWithState( + TaskCompletionSource taskCompletionSource, + int timeoutInSeconds, + TState state, + Func onTimeout, + CancellationToken cancellationToken) + { + if (timeoutInSeconds <= 0) + { + return; + } + + Task.Delay(TimeSpan.FromSeconds(timeoutInSeconds), cancellationToken) + .ContinueWith( + (task, state2) => + { + // If the timeout ran to completion AND the task to complete did not complete + // then the timeout expired first, run the timeout handler + if (!task.IsCanceled && !taskCompletionSource.Task.IsCompleted) + { + taskCompletionSource.TrySetException(onTimeout((TState)state2!)); + } + }, + state: state, + cancellationToken: CancellationToken.None); + } + + /// + /// Waits for a maximum of seconds for completion of + /// the provided . + /// + /// Task to execute with a timeout + /// Number of seconds to wait until timing out the task + /// + /// Callback to execute when the task does not complete within the allotted time. + /// + /// + /// If true, the inner exception of any raised + /// during execution, including timeout of the task, will be rethrown. + /// + internal static void WaitForCompletion( + Task task, + int timeoutInSeconds, + Action? onTimeout = null, + bool rethrowExceptions = true) + { + try + { + TimeSpan timeout = timeoutInSeconds > 0 + ? TimeSpan.FromSeconds(timeoutInSeconds) + : Timeout.InfiniteTimeSpan; + task.Wait(timeout); + } + catch (AggregateException ae) + { + if (rethrowExceptions) + { + Debug.Assert(ae.InnerException is not null, "Inner exception is null"); + Debug.Assert(ae.InnerExceptions.Count == 1, "There is more than one exception in AggregateException"); + ExceptionDispatchInfo.Capture(ae.InnerException!).Throw(); + } + } + + if (!task.IsCompleted) + { + // Ensure the task does not leave an unobserved exception + task.ContinueWith(static t => { _ = t.Exception; }); + onTimeout?.Invoke(); + } + } + + /// + /// This method is intended to be used within the above helpers to ensure that any + /// exceptions thrown during callbacks do not go unobserved. If these exceptions were + /// to go unobserved, they will trigger events to be raised by the default task scheduler. + /// Neither situation is ideal: + /// * If an application assigns a listener to this event, it will generate events that + /// should be reported to us. But, because it happens outside the stack that caused the + /// exception, most of the context of the exception is lost. Furthermore, the event is + /// triggered when the GC runs, so the event happens asynchronous to the action that + /// caused it. + /// * Adding this forced observation of the exception prevents applications from receiving + /// the event, effectively swallowing it. + /// * However, if we log the exception when we observe it, we can still log that the + /// unobserved exception happened without causing undue disruption to the application + /// or leaking resources and causing overhead by raising the event. + /// + private static void ObserveContinuationException(Task continuationTask) + { + continuationTask.ContinueWith( + static task => + { + SqlClientEventSource.Log.TryTraceEvent($"Unobserved task exception: {task.Exception}"); + return _ = task.Exception; + }, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously); + } + + private record ContinuationState( + Action? OnCancellation, + Action? OnFailure, + Action OnSuccess, + TaskCompletionSource TaskCompletionSource); + + private record ContinuationState( + Action? OnCancellation, + Action? OnFailure, + Action OnSuccess, + TState State, + TaskCompletionSource TaskCompletionSource); + + private record ContinuationState( + Action? OnCancellation, + Action? OnFailure, + Action OnSuccess, + TState1 State1, + TState2 State2, + TaskCompletionSource TaskCompletionSource); + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj index 7f6d8abd2c..28389b0335 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj @@ -63,7 +63,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs deleted file mode 100644 index 44286b8c0e..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlHelperTest.cs +++ /dev/null @@ -1,62 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; -using Xunit; - -namespace Microsoft.Data.SqlClient.Tests -{ - public class SqlHelperTest - { - private void TimeOutATask() - { - var sqlClientAssembly = Assembly.GetAssembly(typeof(SqlCommand)); - //We're using reflection to avoid exposing the internals - MethodInfo waitForCompletion = sqlClientAssembly.GetType("Microsoft.Data.SqlClient.AsyncHelper") - ?.GetMethod("WaitForCompletion", BindingFlags.Static | BindingFlags.NonPublic); - - Assert.False(waitForCompletion == null, "Running a test on SqlUtil.WaitForCompletion but could not find this method"); - TaskCompletionSource tcs = new TaskCompletionSource(); - waitForCompletion.Invoke(null, new object[] { tcs.Task, 1, null, true }); //Will time out as task uncompleted - tcs.SetException(new TimeoutException("Dummy timeout exception")); //Our task now completes with an error - } - - private Exception UnwrapException(Exception e) - { - return e?.InnerException != null ? UnwrapException(e.InnerException) : e; - } - - [Fact] - public void WaitForCompletion_DoesNotCreateUnobservedException() - { - var unobservedExceptionHappenedEvent = new AutoResetEvent(false); - Exception unhandledException = null; - void handleUnobservedException(object o, UnobservedTaskExceptionEventArgs a) - { unhandledException = a.Exception; unobservedExceptionHappenedEvent.Set(); } - - TaskScheduler.UnobservedTaskException += handleUnobservedException; - - try - { - TimeOutATask(); //Create the task in another function so the task has no reference remaining - GC.Collect(); //Force collection of unobserved task - GC.WaitForPendingFinalizers(); - - bool unobservedExceptionHappend = unobservedExceptionHappenedEvent.WaitOne(1); - if (unobservedExceptionHappend) //Save doing string interpolation in the happy case - { - var e = UnwrapException(unhandledException); - Assert.Fail($"Did not expect an unobserved exception, but found a {e?.GetType()} with message \"{e?.Message}\""); - } - } - finally - { - TaskScheduler.UnobservedTaskException -= handleUnobservedException; - } - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs index 4992c55974..72be245d3b 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/EventSourceTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Collections.Generic; using System.Linq; using Xunit; @@ -12,22 +13,28 @@ public class EventSourceTest [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] public void EventSourceTestAll() { - using DataTestUtility.MDSEventListener TraceListener = new(); - using (SqlConnection connection = new(DataTestUtility.TCPConnectionString)) + using DataTestUtility.MDSEventListener traceListener = new(); + + using SqlConnection connection = new(DataTestUtility.TCPConnectionString); + connection.Open(); + + using SqlCommand command = new("SELECT @@VERSION", connection); + using SqlDataReader reader = command.ExecuteReader(); + while (reader.Read()) { - connection.Open(); - using SqlCommand command = new("SELECT @@VERSION", connection); - using SqlDataReader reader = command.ExecuteReader(); - while (reader.Read()) - { - // Flush data - } + // Flush data } - // Need to investigate better way of validating traces in sequential runs, - // For now we're collecting all traces to improve code coverage. + // TODO: Need to investigate better way of validating traces in sequential runs, for now we're collecting all traces to improve code coverage. - Assert.All(TraceListener.IDs, item => { Assert.Contains(item, Enumerable.Range(1, 21)); }); + // Assert + // - Collected trace event IDs are in the range of official trace event IDs + // @TODO: This is brittle, refactor the SqlClientEventSource code so the event IDs it can throw are accessible here + HashSet acceptableEventIds = new HashSet(Enumerable.Range(0, 21)); + foreach (int id in traceListener.IDs) + { + Assert.Contains(id, acceptableEventIds); + } } } } diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj index a105ccdf29..7f25bfe919 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj @@ -11,7 +11,6 @@ - runtime; build; native; contentfiles; analyzers; buildtransitive @@ -28,6 +27,8 @@ + + diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs new file mode 100644 index 0000000000..39783ed0b3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/Utilities/AsyncHelperTest.cs @@ -0,0 +1,1353 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient.UnitTests.Utilities; +using Microsoft.Data.SqlClient.Utilities; +using Moq; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient.Utilities +{ + public class AsyncHelperTest + { + // This timeout is set fairly high. The tests are expected to complete quickly, but are + // dependent on congestion of the thread pool. If the thread pool is congested, like on a + // full CI run, short timeouts may elapse even if the code under test would behave as + // expected. As such, we set a long timeout to ride out reasonable congestion on the + // thread pool, but still trigger a failure if the code under test hangs. + // @TODO: If suite-level timeouts are added, these timeouts can likely be removed. + private static readonly TimeSpan RunTimeout = TimeSpan.FromSeconds(30); + + #region ContinueTask + + [Fact] + public async Task ContinueTask_TaskCompletes() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Note: We have to set up mockOnSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock mockOnSuccess = new(); + mockOnSuccess.Setup(action => action()) + .Callback(() => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTask( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTask_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + + // - mockOnSuccess handler throws + Mock mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTask_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that is cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + + Mock mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(), Times.Once); + } + + [Fact] + public async Task ContinueTask_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue that is cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTask_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that is faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTask_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue that is cancelled + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTask( + taskToContinue, + taskCompletionSource, + mockOnSuccess.Object, + onFailure: null, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region ContinueTaskWithState + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCompletes() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Note: We have to set up mockOnSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(state1)) + .Callback(_ => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(It.IsAny())).Throws(); + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have faulted + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_1Generic_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + mockOnSuccess.Object, + onFailure: null, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region ContinueTaskWithState + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCompletes() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Note: We have to set up mockOnSuccess to set a result on the task completion source, + // since the AsyncHelper will not do it, and without that, we cannot reliably + // know when the continuation completed. We will use SetResult b/c it will throw + // if it has already been set. + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(action => action(state1, state2)) + .Callback((_, _) => taskCompletionSource.SetResult(0)); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue: taskToContinue, + taskCompletionSource: taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue that completed successfully + Task taskToContinue = Task.CompletedTask; + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234 + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.Setup(o => o(It.IsAny(), It.IsAny())).Throws(); + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have faulted + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + // - mockOnSuccess was called with state obj + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnCancellation throwing + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1, state2), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue that was cancelled + Task taskToContinue = GetCancelledTask(); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Canceled, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task ContinueTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled, regardless of mockOnSuccess throwing + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + } + + [Fact] + public async Task ContinueTaskWithState_2Generics_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue that faulted + Task taskToContinue = Task.FromException(new Exception()); + TaskCompletionSource taskCompletionSource = GetTaskCompletionSource(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + AsyncHelper.ContinueTaskWithState( + taskToContinue, + taskCompletionSource, + state1, + state2, + mockOnSuccess.Object, + onFailure: null, + mockOnCancellation.Object); + await RunWithTimeout(taskCompletionSource.Task, RunTimeout); + + // Assert + // - taskCompletionSource should have been cancelled + Assert.Equal(TaskStatus.Faulted, taskCompletionSource.Task.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTask + + [Fact] + public void CreateContinuationTask_NullTask() + { + // Arrange + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue: null, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskCompletes() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + Mock mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + Mock> mockOnFailure = new(); + Mock mockOnCancellation = new(); + + // - mockOnSuccess handler throws + Mock mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.Verify(action => action(), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTask_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + Mock> mockOnFailure = new(); + Mock mockOnSuccess = new(); + + Mock mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(), Times.Once); + } + + [Fact] + public async Task CreateContinuationTask_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = GetCancelledTask(); + Mock> mockOnFailure = new(); + Mock mockOnSuccess = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTask_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTask_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.FromException(new Exception()); + Mock mockOnSuccess = new(); + Mock mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTask( + taskToContinue, + mockOnSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTaskWithState + + [Fact] + public void CreateContinuationTaskWithState_1Generic_NullTask() + { + // Arrange + const int state1 = 123; + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: null, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCompletes() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_1Generic_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1), Times.Once); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_1Generic_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_1Generic_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + mockOnSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region CreateContinuationTaskWithState + + [Fact] + public void CreateContinuationTaskWithState_2Generics_NullTask() + { + // Arrange + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue: null, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + + // Assert + Assert.Null(continuationTask); + + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCompletes() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.RanToCompletion, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCompletesHandlerThrows() + { + // Arrange + // - Task to continue completed successfully + Task taskToContinue = Task.CompletedTask; + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnCancellation = new(); + + // - mockOnSuccess handler throws + Mock> mockOnSuccess = new(); + mockOnSuccess.SetupThrows(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.Verify(action => action(state1, state2), Times.Once); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_2Generics_TaskCancels(bool handlerShouldThrow) + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + Mock> mockOnCancellation = new(); + if (handlerShouldThrow) + { + mockOnCancellation.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + mockOnCancellation.Verify(action => action(state1, state2), Times.Once); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskCancelsNoHandler() + { + // Arrange + // - Task to continue was cancelled + Task taskToContinue = GetCancelledTask(); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnFailure = new(); + Mock> mockOnSuccess = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Canceled, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.VerifyNeverCalled(); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task CreateContinuationTaskWithState_2Generics_TaskFaults(bool handlerShouldThrow) + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + Mock> mockOnFailure = new(); + if (handlerShouldThrow) + { + mockOnFailure.SetupThrows(); + } + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + mockOnFailure.Object, + mockOnCancellation.Object); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnFailure.Verify(action => action(state1, state2, It.IsAny()), Times.Once); + mockOnCancellation.VerifyNeverCalled(); + } + + [Fact] + public async Task CreateContinuationTaskWithState_2Generics_TaskFaultsNoHandler() + { + // Arrange + // - Task to continue faulted + Task taskToContinue = Task.FromException(new Exception()); + const int state1 = 123; + const int state2 = 234; + + Mock> mockOnSuccess = new(); + Mock> mockOnCancellation = new(); + + // Act + Task? continuationTask = AsyncHelper.CreateContinuationTaskWithState( + taskToContinue, + state1, + state2, + mockOnSuccess.Object, + onFailure: null, + onCancellation: null); + await RunWithTimeout(continuationTask, RunTimeout); + + // Assert + Assert.Equal(TaskStatus.Faulted, continuationTask.Status); + mockOnSuccess.VerifyNeverCalled(); + mockOnCancellation.VerifyNeverCalled(); + } + + #endregion + + #region WaitForCompletion + + [Fact] + public void WaitForCompletion_DoesNotCreateUnobservedException() + { + // Arrange + // - Create a handler to capture any unhandled exception + Exception? unhandledException = null; + EventHandler handleUnobservedException = + (_, args) => unhandledException = args.Exception; + + // @TODO: Can we do this with a custom scheduler to avoid changing global state? + TaskScheduler.UnobservedTaskException += handleUnobservedException; + + try + { + // Act + // - Run task that will always time out + TaskCompletionSource tcs = new(); + AsyncHelper.WaitForCompletion( + tcs.Task, + timeoutInSeconds: 1, + onTimeout: null, + rethrowExceptions: true); + + // - Force collection of unobserved task + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // Assert + // - Make sure no unobserved tasks happened + Assert.Null(unhandledException); + } + finally + { + // Cleanup + // - Remove the unobserved task handler + TaskScheduler.UnobservedTaskException -= handleUnobservedException; + } + } + + #endregion + + private static Task GetCancelledTask() + { + using CancellationTokenSource cts = new(); + cts.Cancel(); + + return Task.FromCanceled(cts.Token); + } + + private static TaskCompletionSource GetTaskCompletionSource() + => new(TaskCreationOptions.RunContinuationsAsynchronously); + + private static async Task RunWithTimeout([NotNull] Task? taskToRun, TimeSpan timeout) + { + if (taskToRun is null) + { + Assert.Fail("Expected non-null task for timeout"); + } + + Task winner = await Task.WhenAny(taskToRun, Task.Delay(timeout)); + if (winner != taskToRun) + { + Assert.Fail("Timeout elapsed."); + } + + // Force observation of any exception + _ = taskToRun.Exception; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs new file mode 100644 index 0000000000..d1cfe89e2e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Utilities/MockExtensions.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Moq; + +namespace Microsoft.Data.SqlClient.UnitTests.Utilities +{ + public static class MockExtensions + { + public static void SetupThrows(this Mock mock) + where TException : Exception, new() + { + mock.Setup(action => action()) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny())) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny(), It.IsAny())) + .Throws(); + } + + public static void SetupThrows(this Mock> mock) + where TException : Exception, new() + { + mock.Setup(action => action(It.IsAny(), It.IsAny(), It.IsAny())) + .Throws(); + } + + public static void VerifyNeverCalled(this Mock mock) => + mock.Verify(action => action(), Times.Never); + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny()), + Times.Never); + } + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny(), It.IsAny()), + Times.Never); + } + + public static void VerifyNeverCalled(this Mock> mock) + { + mock.Verify( + action => action(It.IsAny(), It.IsAny(), It.IsAny()), + Times.Never); + } + } +}