diff --git a/build.proj b/build.proj index 0443bbcf4e..79635218ff 100644 --- a/build.proj +++ b/build.proj @@ -58,9 +58,14 @@ + + + + - + + @@ -220,6 +225,7 @@ -p:TestTargetOS=Windows$(TargetGroup) --collect "Code coverage" --results-directory $(ResultsDirectory) + --filter "category!=failing%26category!=flaky" --logger:"trx;LogFilePrefix=Unit-Windows$(TargetGroup)-$(TestSet)" $(TestCommand.Replace($([System.Environment]::NewLine), " ")) @@ -240,8 +246,9 @@ -p:TestTargetOS=Unixnetcoreapp --collect "Code coverage" --results-directory $(ResultsDirectory) + --filter "category!=failing%26category!=flaky" --logger:"trx;LogFilePrefix=Unit-Unixnetcoreapp-$(TestSet)" - + $(TestCommand.Replace($([System.Environment]::NewLine), " ")) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 2434582205..66dd0e7340 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1712,8 +1712,11 @@ private void LoginNoFailover(ServerInfo serverInfo, continue; } + // If state != closed, indicates that the parser encountered an error while processing the + // login response (e.g. an explicit error token). Transient network errors that impact + // connectivity will result in parser state being closed. if (_parser == null - || TdsParserState.Closed != _parser.State + || _parser.State != TdsParserState.Closed || IsDoNotRetryConnectError(sqlex) || timeout.IsExpired) { @@ -1993,6 +1996,9 @@ TimeoutTimer timeout throw; // Caller will call LoginFailure() } + // TODO: It doesn't make sense to connect to an azure sql server instance with a failover partner + // specified. Azure SQL Server does not support failover partners. Other availability technologies + // like Failover Groups should be used instead. if (!ADP.IsAzureSqlServerEndpoint(connectionOptions.DataSource) && IsConnectionDoomed) { throw; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs index 248fe0f985..9920417fb0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/AdapterUtil.cs @@ -24,6 +24,7 @@ using Microsoft.Identity.Client; using Microsoft.SqlServer.Server; using System.Security.Authentication; +using System.Collections.Generic; #if NETFRAMEWORK using System.Reflection; @@ -787,7 +788,7 @@ internal static Version GetAssemblyVersion() /// This array includes endpoint URLs for Azure SQL in global, Germany, US Government, /// China, and Fabric environments. These endpoints are used to identify and interact with Azure SQL services /// in their respective regions or environments. - internal static readonly string[] s_azureSqlServerEndpoints = { AZURE_SQL, + internal static readonly List s_azureSqlServerEndpoints = new() { AZURE_SQL, AZURE_SQL_GERMANY, AZURE_SQL_USGOV, AZURE_SQL_CHINA, @@ -827,7 +828,7 @@ internal static bool IsAzureSqlServerEndpoint(string dataSource) } // This method assumes dataSource parameter is in TCP connection string format. - private static bool IsEndpoint(string dataSource, string[] endpoints) + private static bool IsEndpoint(string dataSource, ICollection endpoints) { int length = dataSource.Length; // remove server port diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ConnectionString/DbConnectionStringDefaults.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ConnectionString/DbConnectionStringDefaults.cs index 20831c4912..757b3522fc 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ConnectionString/DbConnectionStringDefaults.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/ConnectionString/DbConnectionStringDefaults.cs @@ -57,7 +57,7 @@ internal static class DbConnectionStringDefaults #if NETFRAMEWORK internal const bool ConnectionReset = true; - internal static readonly bool TransparentNetworkIpResolution = !LocalAppContextSwitches.DisableTnirByDefault; + internal static bool TransparentNetworkIpResolution => !LocalAppContextSwitches.DisableTnirByDefault; internal const string NetworkLibrary = ""; #endif } diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs index 77e7eee950..3676abb237 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Globalization; using System.Threading; +using Microsoft.SqlServer.TDS.Servers; using Xunit; namespace Microsoft.Data.SqlClient.Tests @@ -55,9 +56,11 @@ private string GetLocalizedErrorMessage(string culture) Thread.CurrentThread.CurrentCulture = new CultureInfo(culture); Thread.CurrentThread.CurrentUICulture = new CultureInfo(culture); - using TestTdsServer server = TestTdsServer.StartTestServer(); - var connStr = server.ConnectionString; - connStr = connStr.Replace("localhost", "dummy"); + using TdsServer server = new TdsServer(new TdsServerArguments()); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"dummy,{server.EndPoint.Port}" + }.ConnectionString; using SqlConnection connection = new SqlConnection(connStr); try diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj index 91a5a505b9..730d96ee19 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj @@ -36,8 +36,6 @@ - - @@ -64,8 +62,6 @@ - - @@ -91,6 +87,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionReadOnlyRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionReadOnlyRoutingTests.cs deleted file mode 100644 index c3574dbc13..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionReadOnlyRoutingTests.cs +++ /dev/null @@ -1,140 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Net; -using System.Threading.Tasks; -using Microsoft.SqlServer.TDS.Servers; -using Xunit; - -namespace Microsoft.Data.SqlClient.Tests -{ - public class SqlConnectionReadOnlyRoutingTests - { - [Fact] - public void NonRoutedConnection() - { - using TestTdsServer server = TestTdsServer.StartTestServer(); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - connection.Open(); - } - - [Fact] - public async Task NonRoutedAsyncConnection() - { - using TestTdsServer server = TestTdsServer.StartTestServer(); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - await connection.OpenAsync(); - } - - [Fact] - public void RoutedConnection() - => RecursivelyRoutedConnection(1); - - [Fact] - public async Task RoutedAsyncConnection() - => await RecursivelyRoutedAsyncConnection(1); - - [Theory] - [InlineData(2)] - [InlineData(9)] - [InlineData(11)] // The driver rejects more than 10 redirects (11 layers of redirecting servers) - public void RecursivelyRoutedConnection(int layers) - { - TestTdsServer innerServer = TestTdsServer.StartTestServer(); - IPEndPoint lastEndpoint = innerServer.Endpoint; - Stack routingLayers = new(layers + 1); - string lastConnectionString = innerServer.ConnectionString; - - try - { - routingLayers.Push(innerServer); - for (int i = 0; i < layers; i++) - { - TestRoutingTdsServer router = TestRoutingTdsServer.StartTestServer(lastEndpoint); - - routingLayers.Push(router); - lastEndpoint = router.Endpoint; - lastConnectionString = router.ConnectionString; - } - - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - connection.Open(); - } - finally - { - while (routingLayers.Count > 0) - { - GenericTDSServer layer = routingLayers.Pop(); - - if (layer is IDisposable disp) - { - disp.Dispose(); - } - } - } - } - - [Theory] - [InlineData(2)] - [InlineData(9)] - [InlineData(11)] // The driver rejects more than 10 redirects (11 layers of redirecting servers) - public async Task RecursivelyRoutedAsyncConnection(int layers) - { - TestTdsServer innerServer = TestTdsServer.StartTestServer(); - IPEndPoint lastEndpoint = innerServer.Endpoint; - Stack routingLayers = new(layers + 1); - string lastConnectionString = innerServer.ConnectionString; - - try - { - routingLayers.Push(innerServer); - for (int i = 0; i < layers; i++) - { - TestRoutingTdsServer router = TestRoutingTdsServer.StartTestServer(lastEndpoint); - - routingLayers.Push(router); - lastEndpoint = router.Endpoint; - lastConnectionString = router.ConnectionString; - } - - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - await connection.OpenAsync(); - } - finally - { - while (routingLayers.Count > 0) - { - GenericTDSServer layer = routingLayers.Pop(); - - if (layer is IDisposable disp) - { - disp.Dispose(); - } - } - } - } - - [Fact] - public void ConnectionRoutingLimit() - { - SqlException sqlEx = Assert.Throws(() => RecursivelyRoutedConnection(12)); // This will fail on the 11th redirect - - Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Fact] - public async Task AsyncConnectionRoutingLimit() - { - SqlException sqlEx = await Assert.ThrowsAsync(() => RecursivelyRoutedAsyncConnection(12)); // This will fail on the 11th redirect - - Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestRoutingTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestRoutingTdsServer.cs deleted file mode 100644 index 130b50cad9..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestRoutingTdsServer.cs +++ /dev/null @@ -1,64 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Runtime.CompilerServices; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Servers; - -namespace Microsoft.Data.SqlClient.Tests -{ - internal class TestRoutingTdsServer : RoutingTDSServer, IDisposable - { - private const int DefaultConnectionTimeout = 5; - - private TDSServerEndPoint _endpoint = null; - - private SqlConnectionStringBuilder _connectionStringBuilder; - - public TestRoutingTdsServer(RoutingTDSServerArguments args) : base(args) { } - - public static TestRoutingTdsServer StartTestServer(IPEndPoint destinationEndpoint, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, [CallerMemberName] string methodName = "") - { - RoutingTDSServerArguments args = new RoutingTDSServerArguments() - { - Log = enableLog ? Console.Out : null, - RoutingTCPHost = destinationEndpoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : destinationEndpoint.Address.ToString(), - RoutingTCPPort = (ushort)destinationEndpoint.Port, - }; - - if (enableFedAuth) - { - args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; - } - if (excludeEncryption) - { - args.Encryption = SqlServer.TDS.PreLogin.TDSPreLoginTokenEncryptionType.None; - } - - TestRoutingTdsServer server = new TestRoutingTdsServer(args); - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - int port = server._endpoint.ServerEndPoint.Port; - server._connectionStringBuilder = excludeEncryption - // Allow encryption to be set when encryption is to be excluded from pre-login response. - ? new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Mandatory } - : new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Optional }; - server.ConnectionString = server._connectionStringBuilder.ConnectionString; - server.Endpoint = server._endpoint.ServerEndPoint; - return server; - } - - public void Dispose() => _endpoint?.Stop(); - - public string ConnectionString { get; private set; } - - public IPEndPoint Endpoint { get; private set; } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs deleted file mode 100644 index a5976fd6d5..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs +++ /dev/null @@ -1,76 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Runtime.CompilerServices; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Servers; - -namespace Microsoft.Data.SqlClient.Tests -{ - internal class TestTdsServer : GenericTDSServer, IDisposable - { - private const int DefaultConnectionTimeout = 5; - - private TDSServerEndPoint _endpoint = null; - - private SqlConnectionStringBuilder _connectionStringBuilder; - - public TestTdsServer(TDSServerArguments args) : base(args) { } - - public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args) - { - Engine = engine; - } - - public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, Version serverVersion = null, [CallerMemberName] string methodName = "") - { - TDSServerArguments args = new TDSServerArguments() - { - Log = enableLog ? Console.Out : null, - }; - - if (enableFedAuth) - { - args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; - } - if (excludeEncryption) - { - args.Encryption = SqlServer.TDS.PreLogin.TDSPreLoginTokenEncryptionType.None; - } - if (serverVersion != null) - { - args.ServerVersion = serverVersion; - } - - TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args); - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - int port = server._endpoint.ServerEndPoint.Port; - server._connectionStringBuilder = excludeEncryption - // Allow encryption to be set when encryption is to be excluded from pre-login response. - ? new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Mandatory } - : new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Optional }; - server.ConnectionString = server._connectionStringBuilder.ConnectionString; - server.Endpoint = server._endpoint.ServerEndPoint; - return server; - } - - public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, Version serverVersion = null, [CallerMemberName] string methodName = "") - { - return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, excludeEncryption, serverVersion, methodName); - } - - public void Dispose() => _endpoint?.Stop(); - - public string ConnectionString { get; private set; } - - public IPEndPoint Endpoint { get; private set; } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 94819c9bb4..cba929bdf6 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -298,7 +298,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs index 48b0c9273e..169f71704a 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs @@ -12,8 +12,10 @@ using System.ServiceProcess; using System.Text; using Microsoft.Data.SqlClient.ManualTesting.Tests.DataCommon; +using Microsoft.SqlServer.TDS.Servers; using Microsoft.Win32; using Xunit; +#nullable enable namespace Microsoft.Data.SqlClient.ManualTesting.Tests { @@ -129,18 +131,19 @@ private void ConnectionTest(ConnectionTestParameters connectionTestParameters) string userId = string.IsNullOrWhiteSpace(builder.UserID) ? "user" : builder.UserID; string password = string.IsNullOrWhiteSpace(builder.Password) ? "password" : builder.Password; - using TestTdsServer server = TestTdsServer.StartTestServer(enableFedAuth: false, enableLog: false, connectionTimeout: 15, - methodName: "", -#if NET9_0_OR_GREATER - X509CertificateLoader.LoadPkcs12FromFile(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), -#else - new X509Certificate2(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), -#endif - encryptionProtocols: connectionTestParameters.EncryptionProtocols, - encryptionType: connectionTestParameters.TdsEncryptionType); + using TdsServer server = new TdsServer(new TdsServerArguments + { + EncryptionCertificate = GetEncryptionCertificate(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), + EncryptionProtocols = connectionTestParameters.EncryptionProtocols, + Encryption = connectionTestParameters.TdsEncryptionType, + }); + + server.Start(); - builder = new(server.ConnectionString) + builder = new() { + DataSource = $"localhost,{server.EndPoint.Port}", + ConnectTimeout = 15, UserID = userId, Password = password, TrustServerCertificate = connectionTestParameters.TrustServerCertificate, @@ -231,6 +234,22 @@ private static void RunPowershellScript(string script) } } + /// + /// Loads the specified certificate. + /// + /// The full path of the certificate. + /// The certificate's password. + /// Key storage flags to apply when loading the certificate + /// An instance. + private X509Certificate2 GetEncryptionCertificate(string fileName, string? password, X509KeyStorageFlags keyStorageFlags) + { +#if NET9_0_OR_GREATER + return X509CertificateLoader.LoadPkcs12FromFile(fileName, password, keyStorageFlags); +#else + return new X509Certificate2(fileName, password, keyStorageFlags); +#endif + } + private void RemoveCertificate() { string thumbprint = File.ReadAllText(s_fullPathTothumbprint); @@ -249,7 +268,7 @@ private void RemoveCertificate() private static void RemoveForceEncryptionFromRegistryPath(string registryPath) { - RegistryKey key = Registry.LocalMachine.OpenSubKey(registryPath, true); + RegistryKey? key = Registry.LocalMachine.OpenSubKey(registryPath, true); key?.SetValue("ForceEncryption", 0, RegistryValueKind.DWord); key?.SetValue("Certificate", "", RegistryValueKind.String); ServiceController sc = new($"{s_instanceNamePrefix}{s_instanceName}"); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs index 6ee0681a0d..c44ed97ed0 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.SqlServer.TDS.Servers; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests @@ -23,8 +24,14 @@ public class ConnectionExceptionTest [ConditionalFact(nameof(IsNotKerberos))] public void TestConnectionStateWithErrorClass20() { - using TestTdsServer server = TestTdsServer.StartTestServer(); - using SqlConnection conn = new(server.ConnectionString); + using TdsServer server = new TdsServer(); + server.Start(); + using SqlConnection conn = new( + new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString); conn.Open(); SqlCommand cmd = conn.CreateCommand(); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs index 8ee8e28058..47e60a8db4 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs @@ -530,7 +530,7 @@ public void ConnectionOpenAsyncErrorTest() }).Dispose(); } - private static void CollectStatisticsDiagnostics(Action sqlOperation, bool enableServerLogging = false, [CallerMemberName] string methodName = "") + private static void CollectStatisticsDiagnostics(Action sqlOperation, [CallerMemberName] string methodName = "") { bool statsLogged = false; bool operationHasError = false; @@ -717,10 +717,19 @@ private static void CollectStatisticsDiagnostics(Action sqlOperation, bo { Console.WriteLine(string.Format("Test: {0} Enabled Listeners", methodName)); - using (var server = TestTdsServer.StartServerWithQueryEngine(new DiagnosticsQueryEngine(), enableLog: enableServerLogging, methodName: methodName)) + + using (var server = new TdsServer(new DiagnosticsQueryEngine(), new TdsServerArguments())) { + server.Start(methodName); Console.WriteLine(string.Format("Test: {0} Started Server", methodName)); - sqlOperation(server.ConnectionString); + + var connectionString = new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + + sqlOperation(connectionString); Console.WriteLine(string.Format("Test: {0} SqlOperation Successful", methodName)); @@ -906,11 +915,17 @@ private static async Task CollectStatisticsDiagnosticsAsync(Func s using (DiagnosticListener.AllListeners.Subscribe(diagnosticListenerObserver)) { Console.WriteLine(string.Format("Test: {0} Enabled Listeners", methodName)); - using (var server = TestTdsServer.StartServerWithQueryEngine(new DiagnosticsQueryEngine(), methodName: methodName)) + using (var server = new TdsServer(new DiagnosticsQueryEngine(), new TdsServerArguments())) { + server.Start(methodName); Console.WriteLine(string.Format("Test: {0} Started Server", methodName)); - await sqlOperation(server.ConnectionString); + var connectionString = new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + await sqlOperation(connectionString); Console.WriteLine(string.Format("Test: {0} SqlOperation Successful", methodName)); @@ -937,7 +952,7 @@ private static T GetPropertyValueFromType(object obj, string propName) public class DiagnosticsQueryEngine : QueryEngine { - public DiagnosticsQueryEngine() : base(new TDSServerArguments()) + public DiagnosticsQueryEngine() : base(new TdsServerArguments()) { } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs deleted file mode 100644 index 45a817c46e..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Linq; -using System.Net; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Security.Authentication; -using System.Security.Cryptography.X509Certificates; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.PreLogin; -using Microsoft.SqlServer.TDS.Servers; - -namespace Microsoft.Data.SqlClient.ManualTesting.Tests -{ - internal class TestTdsServer : GenericTDSServer, IDisposable - { - private const int DefaultConnectionTimeout = 5; - - private TDSServerEndPoint _endpoint = null; - - private SqlConnectionStringBuilder _connectionStringBuilder; - - public TestTdsServer(TDSServerArguments args) : base(args) { } - - public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args) - { - Engine = engine; - } - - public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, - int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "", - X509Certificate2 encryptionCertificate = null, SslProtocols encryptionProtocols = SslProtocols.Tls12, TDSPreLoginTokenEncryptionType encryptionType = TDSPreLoginTokenEncryptionType.NotSupported) - { - TDSServerArguments args = new TDSServerArguments() - { - Log = enableLog ? Console.Out : null, - }; - - if (enableFedAuth) - { - args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; - } - - args.EncryptionCertificate = encryptionCertificate; - args.EncryptionProtocols = encryptionProtocols; - args.Encryption = encryptionType; - - TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args); - - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - int port = server._endpoint.ServerEndPoint.Port; - - server._connectionStringBuilder = new SqlConnectionStringBuilder() - { - DataSource = "localhost," + port, - ConnectTimeout = connectionTimeout, - }; - - if (encryptionType == TDSPreLoginTokenEncryptionType.Off || - encryptionType == TDSPreLoginTokenEncryptionType.None || - encryptionType == TDSPreLoginTokenEncryptionType.NotSupported) - { - server._connectionStringBuilder.Encrypt = SqlConnectionEncryptOption.Optional; - } - else - { - server._connectionStringBuilder.Encrypt = SqlConnectionEncryptOption.Mandatory; - } - - server.ConnectionString = server._connectionStringBuilder.ConnectionString; - return server; - } - - public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, - int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "", - X509Certificate2 encryptionCertificate = null, SslProtocols encryptionProtocols = SslProtocols.Tls12, TDSPreLoginTokenEncryptionType encryptionType = TDSPreLoginTokenEncryptionType.NotSupported) - { - return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, methodName, encryptionCertificate, encryptionProtocols, encryptionType); - } - - public void Dispose() => _endpoint?.Stop(); - - public string ConnectionString { get; private set; } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/ADPHelper.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/ADPHelper.cs new file mode 100644 index 0000000000..d78c36f785 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/ADPHelper.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.Data.Common +{ + internal class ADPHelper : IDisposable + { + List _originalAzureSqlServerEndpoints; + + internal ADPHelper() + { + _originalAzureSqlServerEndpoints = [.. ADP.s_azureSqlServerEndpoints]; + } + + internal void AddAzureSqlServerEndpoint(string endpoint) + { + ADP.s_azureSqlServerEndpoints.Add(endpoint); + } + + public void Dispose() + { + ADP.s_azureSqlServerEndpoints.Clear(); + ADP.s_azureSqlServerEndpoints.AddRange(_originalAzureSqlServerEndpoints); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj index fcef431b36..16f09d82cc 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj @@ -11,6 +11,7 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -27,6 +28,10 @@ + + + + diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/SqlConnectionStringTest.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/SqlConnectionStringTest.cs new file mode 100644 index 0000000000..14bc51d520 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft/Data/SqlClient/SqlConnectionStringTest.cs @@ -0,0 +1,68 @@ +using System; +using Microsoft.Data.SqlClient.Tests.Common; +using Xunit; +using static Microsoft.Data.SqlClient.Tests.Common.LocalAppContextSwitchesHelper; + +namespace Microsoft.Data.SqlClient.UnitTests.Microsoft.Data.SqlClient +{ + public class SqlConnectionStringTest : IDisposable + { + private LocalAppContextSwitchesHelper _appContextSwitchHelper; + public SqlConnectionStringTest() + { + // Ensure that the app context switch is set to the default value + _appContextSwitchHelper = new LocalAppContextSwitchesHelper(); + } + +#if NETFRAMEWORK + [Theory] + [InlineData("test.database.windows.net", true, Tristate.True, true)] + [InlineData("test.database.windows.net", false, Tristate.True, false)] + [InlineData("test.database.windows.net", null, Tristate.True, false)] + [InlineData("test.database.windows.net", true, Tristate.False, true)] + [InlineData("test.database.windows.net", false, Tristate.False, false)] + [InlineData("test.database.windows.net", null, Tristate.False, true)] + [InlineData("test.database.windows.net", true, Tristate.NotInitialized, true)] + [InlineData("test.database.windows.net", false, Tristate.NotInitialized, false)] + [InlineData("test.database.windows.net", null, Tristate.NotInitialized, true)] + [InlineData("my.test.server", true, Tristate.True, true)] + [InlineData("my.test.server", false, Tristate.True, false)] + [InlineData("my.test.server", null, Tristate.True, false)] + [InlineData("my.test.server", true, Tristate.False, true)] + [InlineData("my.test.server", false, Tristate.False, false)] + [InlineData("my.test.server", null, Tristate.False, true)] + [InlineData("my.test.server", true, Tristate.NotInitialized, true)] + [InlineData("my.test.server", false, Tristate.NotInitialized, false)] + [InlineData("my.test.server", null, Tristate.NotInitialized, true)] + public void TestDefaultTnir(string dataSource, bool? tnirEnabledInConnString, Tristate tnirDisabledAppContext, bool expectedValue) + { + // Note: TNIR is only supported on .NET Framework. + // Note: TNIR is disabled by default for Azure SQL Database servers (i.e. *.database.windows.net) + // and when using federated auth unless explicitly set in the connection string. + // However, this evaluation only happens at login time so TNIR behavior may not match + // the value of TransparentNetworkIPResolution property in SqlConnectionString. + + // Arrange + _appContextSwitchHelper.DisableTnirByDefaultField = tnirDisabledAppContext; + + // Act + SqlConnectionStringBuilder builder = new(); + builder.DataSource = dataSource; + if (tnirEnabledInConnString.HasValue) + { + builder.TransparentNetworkIPResolution = tnirEnabledInConnString.Value; + } + SqlConnectionString connectionString = new(builder.ConnectionString); + + // Assert + Assert.Equal(expectedValue, connectionString.TransparentNetworkIPResolution); + } +#endif + + public void Dispose() + { + // Clean up any resources if necessary + _appContextSwitchHelper.Dispose(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs new file mode 100644 index 0000000000..3fa98d1e18 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs @@ -0,0 +1,525 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests +{ + [Trait("Category", "flaky")] + [Collection("SimulatedServerTests")] + public class ConnectionFailoverTests + { + //TODO parameterize for transient errors + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_NoFailover_DoesNotClearPool(uint errorCode) + { + // When connecting to a server with a configured failover partner, + // transient errors returned during the login ack should not clear the connection pool. + + // Arrange + using TdsServer failoverServer = new(new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234" + }); + failoverServer.Start(); + var failoverDataSource = $"localhost,{failoverServer.EndPoint.Port}"; + + // Errors are off to start to allow the pool to warm up + using TransientTdsErrorTdsServer initialServer = new(new TransientTdsErrorTdsServerArguments + { + FailoverPartner = failoverDataSource + }); + initialServer.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + initialServer.EndPoint.Port, + ConnectRetryInterval = 1, + ConnectTimeout = 30, + Encrypt = SqlConnectionEncryptOption.Optional, + InitialCatalog = "test" + }; + + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + + // Act + initialServer.SetErrorBehavior(true, errorCode); + using SqlConnection secondConnection = new(builder.ConnectionString); + // Should not trigger a failover, will retry against the same server + secondConnection.Open(); + + // Request a new connection, should initiate a fresh connection attempt if the pool was cleared. + connection.Close(); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal(ConnectionState.Open, secondConnection.State); + Assert.Equal($"localhost,{initialServer.EndPoint.Port}", connection.DataSource); + Assert.Equal($"localhost,{initialServer.EndPoint.Port}", secondConnection.DataSource); + + // 1 for the initial connection, 2 for the second connection + Assert.Equal(3, initialServer.PreLoginCount); + // A failover should not be triggered, so prelogin count to the failover server should be 0 + Assert.Equal(0, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkError_TriggersFailover_ClearsPool() + { + // When connecting to a server with a configured failover partner, + // network errors returned during prelogin should clear the connection pool. + + // Arrange + using TdsServer failoverServer = new(new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234" + }); + failoverServer.Start(); + var failoverDataSource = $"localhost,{failoverServer.EndPoint.Port}"; + + // Errors are off to start to allow the pool to warm up + using TdsServer initialServer = new(new TdsServerArguments + { + FailoverPartner = failoverDataSource + }); + initialServer.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + initialServer.EndPoint.Port, + ConnectRetryInterval = 1, + ConnectTimeout = 30, + Encrypt = SqlConnectionEncryptOption.Optional, + InitialCatalog = "test", + MultiSubnetFailover = false, +#if NETFRAMEWORK + TransparentNetworkIPResolution = false, +#endif + }; + + // Open the initial connection to warm up the pool and populate failover partner information + // for the pool group. + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{initialServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, initialServer.PreLoginCount); + Assert.Equal(0, failoverServer.PreLoginCount); + + // Act + // Should trigger a failover because the initial server is unavailable + initialServer.Dispose(); + using SqlConnection secondConnection = new(builder.ConnectionString); + secondConnection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, secondConnection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", secondConnection.DataSource); + Assert.Equal(1, initialServer.PreLoginCount); + Assert.Equal(1, failoverServer.PreLoginCount); + + + // Act + // Request a new connection, should initiate a fresh connection attempt if the pool was cleared. + connection.Close(); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, initialServer.PreLoginCount); + Assert.Equal(2, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkTimeout_ShouldFail() + { + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(2000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master",// Required for failover partner to work + ConnectTimeout = 1, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + MultiSubnetFailover = false, +#if NETFRAMEWORK + TransparentNetworkIPResolution = false, +#endif + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + var e = Assert.Throws(() => connection.Open()); + + // Assert + Assert.Contains("Connection Timeout Expired", e.Message); + Assert.Equal(ConnectionState.Closed, connection.State); + Assert.Equal(1, server.PreLoginCount); + Assert.Equal(0, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkDelay_ShouldConnectToPrimary() + { + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(1000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master",// Required for failover partner to work + ConnectTimeout = 5, + Encrypt = false, + MultiSubnetFailover = false, +#if NETFRAMEWORK + TransparentNetworkIPResolution = false, +#endif + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + // On the first connection attempt, no failover partner information is available, + // so the connection will retry on the same server. + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, server.PreLoginCount); + Assert.Equal(0, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkError_WithUserProvidedPartner_RetryDisabled_ShouldConnectToFailoverPartner() + { + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(10000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master", // Required for failover partner to work + ConnectTimeout = 5, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", // User provided failover partner + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + // On the first connection attempt, failover partner information is available in the connection string, + // so the connection will retry on the failover server. + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, failoverServer.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + } + + [Fact] + public void NetworkError_WithUserProvidedPartner_RetryEnabled_ShouldConnectToFailoverPartner() + { + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(10000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master", // Required for failover partner to work + ConnectTimeout = 5, + ConnectRetryInterval = 1, + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", // User provided failover partner + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + // Act + connection.Open(); + + // Assert + // On the first connection attempt, failover partner information is available in the connection string, + // so the connection will retry on the failover server. + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, server.PreLoginCount); + Assert.Equal(1, failoverServer.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_ShouldConnectToPrimary(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (SqlException e) + { + Assert.Equal((int)errorCode, e.Number); + return; + } + + Assert.Fail(); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_WithUserProvidedPartner_ShouldConnectToPrimary(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", // User provided failover partner + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_WithUserProvidedPartner_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", // User provided failover partner + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (SqlException e) + { + Assert.Equal((int)errorCode, e.Number); + return; + } + + Assert.Fail(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs new file mode 100644 index 0000000000..f0618ac269 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs @@ -0,0 +1,156 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Threading.Tasks; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests +{ + [Collection("SimulatedServerTests")] + public class ConnectionReadOnlyRoutingTests + { + [Fact] + public void NonRoutedConnection() + { + using TdsServer server = new(); + server.Start(); + SqlConnectionStringBuilder builder = new() { + DataSource = $"localhost,{server.EndPoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = SqlConnectionEncryptOption.Optional + }; + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + } + + [Fact] + public async Task NonRoutedAsyncConnection() + { + using TdsServer server = new(); + server.Start(); + SqlConnectionStringBuilder builder = new() { + DataSource = $"localhost,{server.EndPoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = SqlConnectionEncryptOption.Optional + }; + using SqlConnection connection = new(builder.ConnectionString); + await connection.OpenAsync(); + } + + [Fact] + public void RoutedConnection() => RecursivelyRoutedConnection(1); + + [Fact] + public async Task RoutedAsyncConnection() => await RecursivelyRoutedAsyncConnection(1); + + [Theory] + [InlineData(11)] // 11 layers of routing should succeed, 12 should fail + public void RecursivelyRoutedConnection(int layers) + { + using TdsServer innerServer = new(); + innerServer.Start(); + IPEndPoint lastEndpoint = innerServer.EndPoint; + Stack routingLayers = new(layers + 1); + string lastConnectionString = (new SqlConnectionStringBuilder() { DataSource = $"localhost,{lastEndpoint.Port}" }).ConnectionString; + + try + { + for (int i = 0; i < layers; i++) + { + RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)lastEndpoint.Port, + }); + router.Start(); + routingLayers.Push(router); + lastEndpoint = router.EndPoint; + lastConnectionString = (new SqlConnectionStringBuilder() { + DataSource = $"localhost,{lastEndpoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = false + }).ConnectionString; + } + + SqlConnectionStringBuilder builder = new(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + } + finally + { + while (routingLayers.Count > 0) + { + routingLayers.Pop().Dispose(); + } + } + } + + [Theory] + [InlineData(11)] // 11 layers of routing should succeed, 12 should fail + public async Task RecursivelyRoutedAsyncConnection(int layers) + { + using TdsServer innerServer = new(); + innerServer.Start(); + IPEndPoint lastEndpoint = innerServer.EndPoint; + Stack routingLayers = new(layers + 1); + string lastConnectionString = (new SqlConnectionStringBuilder() { DataSource = $"localhost,{lastEndpoint.Port}" }).ConnectionString; + + try + { + for (int i = 0; i < layers; i++) + { + RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)lastEndpoint.Port, + }); + router.Start(); + routingLayers.Push(router); + lastEndpoint = router.EndPoint; + lastConnectionString = (new SqlConnectionStringBuilder() { + DataSource = $"localhost,{lastEndpoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = false + }).ConnectionString; + } + + SqlConnectionStringBuilder builder = new(lastConnectionString) { + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + await connection.OpenAsync(); + } + finally + { + while (routingLayers.Count > 0) + { + routingLayers.Pop().Dispose(); + } + } + } + + [Fact] + public void ConnectionRoutingLimit() + { + SqlException sqlEx = Assert.Throws(() => RecursivelyRoutedConnection(12)); // This will fail on the 11th redirect + + Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task AsyncConnectionRoutingLimit() + { + SqlException sqlEx = await Assert.ThrowsAsync(() => RecursivelyRoutedAsyncConnection(12)); // This will fail on the 11th redirect + + Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs new file mode 100644 index 0000000000..108118dda7 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using Microsoft.Data.Common; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests +{ + [Trait("Category", "flaky")] + [Collection("SimulatedServerTests")] + public class ConnectionRoutingTests + { + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFaultAtRoutedLocation_ShouldReturnToGateway(uint errorCode) + { + // Arrange + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + //RoutingTCPHost = server.EndPoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : server.EndPoint.Address.ToString(), + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + // Routing does not update the connection's data source + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, router.PreLoginCount); + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFaultAtRoutedLocation_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + //RoutingTCPHost = server.EndPoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : server.EndPoint.Address.ToString(), + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + + //Act and Assert + Assert.Throws(() => connection.Open()); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void NetworkDelayAtRoutedLocation_RetryDisabled_ShouldSucceed(bool multiSubnetFailoverEnabled) + { + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(1000), + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 5, + ConnectRetryCount = 0, // disable retry + Encrypt = false, + MultiSubnetFailover = multiSubnetFailoverEnabled, +#if NETFRAMEWORK + TransparentNetworkIPResolution = multiSubnetFailoverEnabled, +#endif + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, router.PreLoginCount); + if (multiSubnetFailoverEnabled) + { + Assert.True(server.PreLoginCount > 1); + } + else + { + Assert.Equal(1, server.PreLoginCount); + } + } + + [Fact] + public void NetworkTimeoutAtRoutedLocation_RetryDisabled_ShouldFail() + { + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(2000), + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 1, + ConnectRetryCount = 0, // disable retry + Encrypt = false, + MultiSubnetFailover = false, +#if NETFRAMEWORK + TransparentNetworkIPResolution = false +#endif + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + var e = Assert.Throws(connection.Open); + + // Assert + Assert.Equal(ConnectionState.Closed, connection.State); + Assert.Contains("Connection Timeout Expired", e.Message); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs new file mode 100644 index 0000000000..dd945e37f3 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTestsAzure.cs @@ -0,0 +1,200 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using Microsoft.Data.Common; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests +{ + [Trait("Category", "flaky")] + [Collection("SimulatedServerTests")] + public class ConnectionRoutingTestsAzure : IDisposable + { + private ADPHelper adpHelper; + + public ConnectionRoutingTestsAzure() + { + adpHelper = new ADPHelper(); + adpHelper.AddAzureSqlServerEndpoint("localhost"); + } + + public void Dispose() + { + adpHelper.Dispose(); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFaultAtRoutedLocation_ShouldReturnToGateway(uint errorCode) + { + // Arrange + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + // Routing does not update the connection's data source + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, router.PreLoginCount); + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFaultAtRoutedLocation_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + Assert.Throws(() => connection.Open()); + } + + [Fact] + public void NetworkDelayAtRoutedLocation_RetryDisabled_ShouldSucceed() + { + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(1000), + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 5, + ConnectRetryCount = 0, // disable retry + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, router.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + } + + [Fact] + public void NetworkTimeoutAtRoutedLocation_RetryDisabled_ShouldFail() + { + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(2000), + }); + + server.Start(); + + using RoutingTdsServer router = new( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 1, + ConnectRetryCount = 0, // disable retry + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + var e = Assert.Throws(connection.Open); + + // Assert + Assert.Equal(ConnectionState.Closed, connection.State); + Assert.Contains("Connection Timeout Expired", e.Message); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionTests.cs similarity index 63% rename from src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs rename to src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionTests.cs index 616a8fec6f..7fd18395e0 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionTests.cs @@ -9,7 +9,6 @@ using System.Globalization; using System.Linq; using System.Reflection; -using System.Runtime.InteropServices; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -20,26 +19,36 @@ using Microsoft.SqlServer.TDS.Servers; using Xunit; -namespace Microsoft.Data.SqlClient.Tests +namespace Microsoft.Data.SqlClient.UnitTests.SimulatedServerTests { - public class SqlConnectionBasicTests + public class ConnectionTests { [Fact] public void ConnectionTest() { - using TestTdsServer server = TestTdsServer.StartTestServer(); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + using TdsServer server = new(new TdsServerArguments() { }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using SqlConnection connection = new(connStr); connection.Open(); } - [ConditionalFact(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Fact] [PlatformSpecific(TestPlatforms.Windows)] public void IntegratedAuthConnectionTest() { - using TestTdsServer server = TestTdsServer.StartTestServer(); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString); + using TdsServer server = new(new TdsServerArguments() { }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + SqlConnectionStringBuilder builder = new(connStr); builder.IntegratedSecurity = true; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); + using SqlConnection connection = new(builder.ConnectionString); connection.Open(); } @@ -49,117 +58,264 @@ public void IntegratedAuthConnectionTest() /// when client enables encryption using Encrypt=true or uses default encryption setting. /// [Fact] - public async Task PreLoginEncryptionExcludedTest() + public async Task RequestEncryption_ServerDoesNotSupportEncryption_ShouldFail() { - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, 5, excludeEncryption: true); - SqlConnectionStringBuilder builder = new(server.ConnectionString) - { - IntegratedSecurity = true - }; + using TdsServer server = new(new TdsServerArguments() {Encryption = TDSPreLoginTokenEncryptionType.None }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}" + }.ConnectionString; - using SqlConnection connection = new(builder.ConnectionString); + using SqlConnection connection = new(connStr); Exception ex = await Assert.ThrowsAsync(async () => await connection.OpenAsync()); Assert.Contains("The instance of SQL Server you attempted to connect to does not support encryption.", ex.Message, StringComparison.OrdinalIgnoreCase); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Trait("Category", "flaky")] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public async Task TransientFaultTestAsync(uint errorCode) + public async Task TransientFault_RetryEnabled_ShouldSucceed_Async(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); await connection.OpenAsync(); Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + Assert.Equal(2, server.PreLoginCount); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Trait("Category", "flaky")] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public void TransientFaultTest(uint errorCode) + public void TransientFault_RetryEnabled_ShouldSucceed(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); - try - { - connection.Open(); - Assert.Equal(ConnectionState.Open, connection.State); - } - catch (Exception e) - { - Assert.Fail(e.Message); - } + connection.Open(); + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + Assert.Equal(2, server.PreLoginCount); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Trait("Category", "flaky")] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public void TransientFaultDisabledTestAsync(uint errorCode) + public async Task TransientFault_RetryDisabled_ShouldFail_Async(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, ConnectRetryCount = 0, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); - Task e = Assert.ThrowsAsync(async () => await connection.OpenAsync()); - Assert.Equal(20, e.Result.Class); + SqlException e = await Assert.ThrowsAsync(async () => await connection.OpenAsync()); + Assert.Equal((int)errorCode, e.Number); Assert.Equal(ConnectionState.Closed, connection.State); + Assert.Equal(1, server.PreLoginCount); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Trait("Category", "flaky")] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public void TransientFaultDisabledTest(uint errorCode) + public void TransientFault_RetryDisabled_ShouldFail(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientTdsErrorTdsServer server = new( + new TransientTdsErrorTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, ConnectRetryCount = 0, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); SqlException e = Assert.Throws(() => connection.Open()); - Assert.Equal(20, e.Class); + Assert.Equal((int)errorCode, e.Number); Assert.Equal(ConnectionState.Closed, connection.State); + Assert.Equal(1, server.PreLoginCount); + } + + [Trait("Category", "flaky")] + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task NetworkError_RetryEnabled_ShouldSucceed_Async(bool multiSubnetFailoverEnabled) + { + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(1000), + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + Encrypt = SqlConnectionEncryptOption.Optional, + ConnectTimeout = 5, + MultiSubnetFailover = multiSubnetFailoverEnabled, +#if NETFRAMEWORK + TransparentNetworkIPResolution = multiSubnetFailoverEnabled +#endif + }; + + using SqlConnection connection = new(builder.ConnectionString); + await connection.OpenAsync(); + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + if (multiSubnetFailoverEnabled) + { + Assert.True(server.PreLoginCount > 1, "Expected multiple pre-login attempts due to retry."); + } + else + { + Assert.Equal(1, server.PreLoginCount); + } + } + + [Trait("Category", "flaky")] + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task NetworkDelay_RetryDisabled_Async(bool multiSubnetFailoverEnabled) + { + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(1000), + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + ConnectTimeout = 5, + ConnectRetryCount = 0, + Encrypt = SqlConnectionEncryptOption.Optional, + MultiSubnetFailover = multiSubnetFailoverEnabled, +#if NETFRAMEWORK + TransparentNetworkIPResolution = multiSubnetFailoverEnabled, +#endif + }; + + using SqlConnection connection = new(builder.ConnectionString); + + // Act + await connection.OpenAsync(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + + if (multiSubnetFailoverEnabled) + { + Assert.True(server.PreLoginCount > 1, "Expected multiple pre-login attempts due to retry."); + } + else + { + Assert.Equal(1, server.PreLoginCount); + } + } + + [Trait("Category", "flaky")] + [Theory] + [InlineData(true)] + [InlineData(false)] + public void NetworkDelay_RetryDisabled(bool multiSubnetFailoverEnabled) + { + // Arrange + using TransientDelayTdsServer server = new( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientDelay = true, + DelayDuration = TimeSpan.FromMilliseconds(1000), + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + ConnectRetryCount = 0, + Encrypt = SqlConnectionEncryptOption.Optional, + ConnectTimeout = 5, + MultiSubnetFailover = multiSubnetFailoverEnabled, +#if NETFRAMEWORK + TransparentNetworkIPResolution = multiSubnetFailoverEnabled, +#endif + }; + + using SqlConnection connection = new(builder.ConnectionString); + + // Act + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + + if (multiSubnetFailoverEnabled) + { + Assert.True(server.PreLoginCount > 1, "Expected multiple pre-login attempts due to retry."); + } + else + { + Assert.Equal(1, server.PreLoginCount); + } } [Fact] public void SqlConnectionDbProviderFactoryTest() { SqlConnection con = new(); - PropertyInfo dbProviderFactoryProperty = con.GetType().GetProperty("DbProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance); + PropertyInfo? dbProviderFactoryProperty = con.GetType().GetProperty("DbProviderFactory", BindingFlags.NonPublic | BindingFlags.Instance); Assert.NotNull(dbProviderFactoryProperty); - DbProviderFactory factory = dbProviderFactoryProperty.GetValue(con) as DbProviderFactory; + DbProviderFactory? factory = dbProviderFactoryProperty.GetValue(con) as DbProviderFactory; Assert.NotNull(factory); Assert.Same(typeof(SqlClientFactory), factory.GetType()); Assert.Same(SqlClientFactory.Instance, factory); @@ -206,7 +362,7 @@ public void ClosedConnectionSchemaRetrieval() [InlineData("RandomStringForTargetServer", true, false)] [InlineData(null, false, false)] [InlineData("", false, false)] - public void RetrieveWorkstationId(string workstation, bool withDispose, bool shouldMatchSetWorkstationId) + public void RetrieveWorkstationId(string? workstation, bool withDispose, bool shouldMatchSetWorkstationId) { string connectionString = $"Workstation Id={workstation}"; SqlConnection conn = new(connectionString); @@ -214,7 +370,7 @@ public void RetrieveWorkstationId(string workstation, bool withDispose, bool sho { conn.Dispose(); } - string expected = shouldMatchSetWorkstationId ? workstation : Environment.MachineName; + string? expected = shouldMatchSetWorkstationId ? workstation : Environment.MachineName; Assert.Equal(expected, conn.WorkstationId); } @@ -302,23 +458,27 @@ public void ConnectionTestValidCredentialCombination() [Theory] [InlineData(60)] - [InlineData(30)] - [InlineData(15)] [InlineData(10)] - [InlineData(5)] [InlineData(1)] public void ConnectionTimeoutTest(int timeout) { // Start a server with connection timeout from the inline data. - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, timeout); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + //TODO: do we even need a server for this test? + using TdsServer server = new(); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + ConnectTimeout = timeout, + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + using SqlConnection connection = new(connStr); // Dispose the server to force connection timeout server.Dispose(); // Measure the actual time it took to timeout and compare it with configured timeout Stopwatch timer = new(); - Exception ex = null; + Exception? ex = null; // Open a connection with the server disposed. try @@ -341,29 +501,34 @@ public void ConnectionTimeoutTest(int timeout) [Theory] [InlineData(60)] - [InlineData(30)] - [InlineData(15)] [InlineData(10)] - [InlineData(5)] [InlineData(1)] public async Task ConnectionTimeoutTestAsync(int timeout) { // Start a server with connection timeout from the inline data. - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, timeout); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + //TODO: do we even need a server for this test? + using TdsServer server = new(); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + ConnectTimeout = timeout, + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + using SqlConnection connection = new(connStr); // Dispose the server to force connection timeout server.Dispose(); // Measure the actual time it took to timeout and compare it with configured timeout Stopwatch timer = new(); - Exception ex = null; + Exception? ex = null; // Open a connection with the server disposed. try { //an asyn call with a timeout token to cancel the operation after the specific time - using CancellationTokenSource cts = new CancellationTokenSource(timeout * 1000); + using CancellationTokenSource cts = new(timeout * 1000); timer.Start(); await connection.OpenAsync(cts.Token).ConfigureAwait(false); } @@ -385,7 +550,11 @@ public void ConnectionInvalidTimeoutTest() { Assert.Throws(() => { - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, -5); + var connectionString = new SqlConnectionStringBuilder() + { + DataSource = "localhost", + ConnectTimeout = -5 // Invalid timeout + }.ConnectionString; }); } @@ -401,8 +570,15 @@ public void ConnectionTestWithCultureTH() Thread.CurrentThread.CurrentCulture = new CultureInfo("th-TH"); Thread.CurrentThread.CurrentUICulture = new CultureInfo("th-TH"); - using TestTdsServer server = TestTdsServer.StartTestServer(); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + //TODO: do we even need a server for this test? + using TdsServer server = new(); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + using SqlConnection connection = new(connStr); connection.Open(); Assert.Equal(ConnectionState.Open, connection.State); } @@ -504,9 +680,20 @@ public void ConnectionTestAccessTokenCallbackCombinations() [InlineData(11, 0, 3000)] // SQL Server 2012-2022 public void ConnectionTestPermittedVersion(int major, int minor, int build) { - Version simulatedServerVersion = new Version(major, minor, build); - using TestTdsServer server = TestTdsServer.StartTestServer(serverVersion: simulatedServerVersion); - using SqlConnection conn = new SqlConnection(server.ConnectionString); + Version simulatedServerVersion = new(major, minor, build); + + using TdsServer server = new( + new TdsServerArguments + { + ServerVersion = simulatedServerVersion, + }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using SqlConnection conn = new(connStr); conn.Open(); Assert.Equal(ConnectionState.Open, conn.State); @@ -522,9 +709,19 @@ public void ConnectionTestPermittedVersion(int major, int minor, int build) [InlineData(8, 0, 384)] // SQL Server 2000 SP1 public void ConnectionTestDeniedVersion(int major, int minor, int build) { - Version simulatedServerVersion = new Version(major, minor, build); - using TestTdsServer server = TestTdsServer.StartTestServer(serverVersion: simulatedServerVersion); - using SqlConnection conn = new SqlConnection(server.ConnectionString); + Version simulatedServerVersion = new(major, minor, build); + using TdsServer server = new( + new TdsServerArguments + { + ServerVersion = simulatedServerVersion, + }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using SqlConnection conn = new(connStr); Assert.Throws(() => conn.Open()); } @@ -542,7 +739,8 @@ public void ConnectionTestDeniedVersion(int major, int minor, int build) public void TestConnWithVectorFeatExtVersionNegotiation(bool expectedConnectionResult, byte serverVersion, byte expectedNegotiatedVersion) { // Start the test TDS server. - using var server = TestTdsServer.StartTestServer(); + using var server = new TdsServer(); + server.Start(); server.ServerSupportedVectorFeatureExtVersion = serverVersion; server.EnableVectorFeatureExt = serverVersion == 0xFF ? false : true; @@ -594,7 +792,12 @@ public void TestConnWithVectorFeatExtVersionNegotiation(bool expectedConnectionR }; // Connect to the test TDS server. - using var connection = new SqlConnection(server.ConnectionString); + var connStr = new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using var connection = new SqlConnection(connStr); if (expectedConnectionResult) { connection.Open(); diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs index e81139c63a..349b23ceee 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs @@ -30,7 +30,8 @@ public override TDSServerEndPointConnection CreateConnection(TcpClient newConnec /// /// General server handler /// - public abstract class ServerEndPointHandler where T : ServerEndPointConnection + public abstract class ServerEndPointHandler : IDisposable + where T : ServerEndPointConnection { /// /// Gets/Sets the event log for the proxy server @@ -100,21 +101,6 @@ public void Start() // Update ServerEndpoint with the actual address/port, e.g. if port=0 was given ServerEndPoint = (IPEndPoint)ListenerSocket.LocalEndpoint; - Log($"{GetType().Name} {EndpointName} Is Server Socket Bound: {ListenerSocket.Server.IsBound} Testing connectivity to the endpoint created for the server."); - using (TcpClient client = new TcpClient()) - { - try - { - client.Connect("localhost", ServerEndPoint.Port); - } - catch (Exception e) - { - Log($"{GetType().Name} {EndpointName} Error occurred while testing server endpoint {e.Message}"); - throw; - } - } - Log($"{GetType().Name} {EndpointName} Endpoint test successful."); - // Initialize the listener ListenerThread = new Thread(new ThreadStart(_RequestListener)) { IsBackground = true }; ListenerThread.Name = "TDS Server EndPoint Listener"; @@ -148,7 +134,7 @@ public void Stop() foreach (T connection in unlockedConnections) { // Request to stop - connection.Stop(); + connection.Dispose(); } // If server failed to start there is no thread to join @@ -167,6 +153,12 @@ public void Stop() } } + public void Dispose() + { + // Stop the listener + Stop(); + } + /// /// Processes all incoming requests /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs index 6327189691..3d24f9c397 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs @@ -8,6 +8,7 @@ using System.Net; using System.Net.Sockets; using System.Threading; +using System.Threading.Tasks; namespace Microsoft.SqlServer.TDS.EndPoint { @@ -44,12 +45,12 @@ public override void ProcessData(Stream rawStream) /// /// Connection to a single client /// - public abstract class ServerEndPointConnection + public abstract class ServerEndPointConnection : IDisposable { /// /// Worker thread /// - protected Thread ProcessorThread { get; set; } + protected Task ProcessorTask { get; set; } /// /// Gets/Sets the event log for the proxy server @@ -76,11 +77,6 @@ public abstract class ServerEndPointConnection /// protected TcpClient Connection { get; set; } - /// - /// The flag indicates whether server is being stopped - /// - protected bool StopRequested { get; set; } - /// /// Initialization constructor /// @@ -124,29 +120,8 @@ public ServerEndPointConnection(ITDSServer server, TcpClient connection) /// internal void Start() { - // Start with active connection - StopRequested = false; - // Prepare and start a thread - ProcessorThread = new Thread(new ThreadStart(_ConnectionHandler)) { IsBackground = true }; - ProcessorThread.Name = string.Format("TDS Server Connection {0} Thread", Connection.Client.RemoteEndPoint); - ProcessorThread.Start(); - } - - /// - /// Stop the connection - /// - internal void Stop() - { - // Request the listener thread to stop - StopRequested = true; - - // If connection failed to start there's no processor thread - if (ProcessorThread != null) - { - // Wait for termination - ProcessorThread.Join(); - } + ProcessorTask = RunConnectionHandler(); } /// @@ -159,10 +134,27 @@ internal void Stop() /// public abstract void ProcessData(Stream rawStream); + public void Dispose() + { + if (Connection != null) + { + Connection.Close(); + Connection.Dispose(); + Connection = null; + } + + // TODO: there's a deadlock condition when awaiting the processor task + // only dispose of it if it's already completed + if (ProcessorTask.Status == TaskStatus.RanToCompletion) + { + ProcessorTask.Dispose(); + } + } + /// /// Worker thread /// - private void _ConnectionHandler() + private async Task RunConnectionHandler() { try { @@ -171,7 +163,7 @@ private void _ConnectionHandler() PrepareForProcessingData(rawStream); // Process the packet sequence - while (Connection.Connected && !StopRequested) + while (Connection.Connected) { // Check incoming buffer if (rawStream.DataAvailable) @@ -187,7 +179,7 @@ private void _ConnectionHandler() } // Sleep a bit to reduce load on CPU - Thread.Sleep(10); + await Task.Delay(10); } } } @@ -212,6 +204,8 @@ private void _ConnectionHandler() { OnConnectionClosed(this, null); } + + return; } /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServer.cs deleted file mode 100644 index 06261e2c8f..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServer.cs +++ /dev/null @@ -1,220 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using Microsoft.SqlServer.TDS.Done; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Error; -using Microsoft.SqlServer.TDS.Login7; - -namespace Microsoft.SqlServer.TDS.Servers -{ - /// - /// TDS Server that authenticates clients according to the requested parameters - /// - public class AuthenticatingTDSServer : GenericTDSServer - { - /// - /// Initialization constructor - /// - public AuthenticatingTDSServer() : - this(new AuthenticatingTDSServerArguments()) - { - } - - /// - /// Initialization constructor - /// - public AuthenticatingTDSServer(AuthenticatingTDSServerArguments arguments) : - base(arguments) - { - } - - /// - /// Handler for login request - /// - public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) - { - // Inflate login7 request from the message - TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; - - // Check if arguments are of the authenticating TDS server - if (Arguments is AuthenticatingTDSServerArguments) - { - // Cast to authenticating TDS server arguments - AuthenticatingTDSServerArguments ServerArguments = Arguments as AuthenticatingTDSServerArguments; - - // Check if we're still processing normal login - if (ServerArguments.ApplicationIntentFilter != ApplicationIntentFilterType.All) - { - // Check filter - if ((ServerArguments.ApplicationIntentFilter == ApplicationIntentFilterType.ReadOnly && loginRequest.TypeFlags.ReadOnlyIntent != TDSLogin7TypeFlagsReadOnlyIntent.ReadOnly) - || (ServerArguments.ApplicationIntentFilter == ApplicationIntentFilterType.None)) - { - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - - // Prepare ERROR token with the denial details - TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, "Received application intent: " + loginRequest.TypeFlags.ReadOnlyIntent.ToString(), Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the error token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - - // Prepare ERROR token for the final decision - errorToken = new TDSErrorToken(18456, 1, 14, "Connection is denied by application intent filter", Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the error token into the response packet - responseMessage.Add(errorToken); - - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); - - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - // Put a single message into the collection and return it - return new TDSMessageCollection(responseMessage); - } - } - - // Check if we're still processing normal login and there's a filter to check - if (ServerArguments.ServerNameFilterType != ServerNameFilterType.None) - { - // Check each algorithm - if ((ServerArguments.ServerNameFilterType == ServerNameFilterType.Equals && string.Compare(ServerArguments.ServerNameFilter, loginRequest.ServerName, true) != 0) - || (ServerArguments.ServerNameFilterType == ServerNameFilterType.StartsWith && !loginRequest.ServerName.StartsWith(ServerArguments.ServerNameFilter)) - || (ServerArguments.ServerNameFilterType == ServerNameFilterType.EndsWith && !loginRequest.ServerName.EndsWith(ServerArguments.ServerNameFilter)) - || (ServerArguments.ServerNameFilterType == ServerNameFilterType.Contains && !loginRequest.ServerName.Contains(ServerArguments.ServerNameFilter))) - { - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - - // Prepare ERROR token with the reason - TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, string.Format("Received server name \"{0}\", expected \"{1}\" using \"{2}\" algorithm", loginRequest.ServerName, ServerArguments.ServerNameFilter, ServerArguments.ServerNameFilterType), Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the errorToken token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - - // Prepare ERROR token with the final errorToken - errorToken = new TDSErrorToken(18456, 1, 14, "Connection is denied by server name filter", Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the errorToken token into the response packet - responseMessage.Add(errorToken); - - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); - - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - // Return only a single message with the collection - return new TDSMessageCollection(responseMessage); - } - } - - // Check if packet size filter is applied - if (ServerArguments.PacketSizeFilter != null) - { - // Check if requested packet size is the same as the filter specified - if (loginRequest.PacketSize != ServerArguments.PacketSizeFilter.Value) - { - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - - // Prepare ERROR token with the reason - TDSErrorToken errorToken = new TDSErrorToken(1919, 1, 14, string.Format("Received packet size \"{0}\", expected \"{1}\"", loginRequest.PacketSize, ServerArguments.PacketSizeFilter.Value), Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the errorToken token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - - // Prepare ERROR token with the final errorToken - errorToken = new TDSErrorToken(1919, 1, 14, "Connection is denied by packet size filter", Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the errorToken token into the response packet - responseMessage.Add(errorToken); - - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); - - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - // Return only a single message with the collection - return new TDSMessageCollection(responseMessage); - } - } - - // If we have an application name filter - if (ServerArguments.ApplicationNameFilter != null) - { - // If we are supposed to block this connection attempt - if (loginRequest.ApplicationName.Equals(ServerArguments.ApplicationNameFilter, System.StringComparison.OrdinalIgnoreCase)) - { - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - - // Prepare ERROR token with the denial details - TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, "Received application name: " + loginRequest.ApplicationName, Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the error token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - - // Prepare ERROR token for the final decision - errorToken = new TDSErrorToken(18456, 1, 14, "Connection is denied by application name filter", Arguments.ServerName); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the error token into the response packet - responseMessage.Add(errorToken); - - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); - - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - // Put a single message into the collection and return it - return new TDSMessageCollection(responseMessage); - } - } - } - - // Return login response from the base class - return base.OnLogin7Request(session, request); - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServer.cs new file mode 100644 index 0000000000..4db5439845 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServer.cs @@ -0,0 +1,213 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.SqlServer.TDS.Done; +using Microsoft.SqlServer.TDS.EndPoint; +using Microsoft.SqlServer.TDS.Error; +using Microsoft.SqlServer.TDS.Login7; + +namespace Microsoft.SqlServer.TDS.Servers +{ + /// + /// TDS Server that authenticates clients according to the requested parameters + /// + public class AuthenticatingTdsServer : GenericTdsServer + { + /// + /// Initialization constructor + /// + public AuthenticatingTdsServer() + : this(new AuthenticatingTdsServerArguments()) + { + } + + /// + /// Initialization constructor + /// + public AuthenticatingTdsServer(AuthenticatingTdsServerArguments arguments) : + base(arguments) + { + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) + { + // Inflate login7 request from the message + TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; + + // Check if we're still processing normal login + if (Arguments.ApplicationIntentFilter != ApplicationIntentFilterType.All) + { + // Check filter + if ((Arguments.ApplicationIntentFilter == ApplicationIntentFilterType.ReadOnly && loginRequest.TypeFlags.ReadOnlyIntent != TDSLogin7TypeFlagsReadOnlyIntent.ReadOnly) + || (Arguments.ApplicationIntentFilter == ApplicationIntentFilterType.None)) + { + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + + // Prepare ERROR token with the denial details + TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, "Received application intent: " + loginRequest.TypeFlags.ReadOnlyIntent.ToString(), Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the error token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + + // Prepare ERROR token for the final decision + errorToken = new TDSErrorToken(18456, 1, 14, "Connection is denied by application intent filter", Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the error token into the response packet + responseMessage.Add(errorToken); + + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); + + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); + + // Put a single message into the collection and return it + return new TDSMessageCollection(responseMessage); + } + } + + // Check if we're still processing normal login and there's a filter to check + if (Arguments.ServerNameFilterType != ServerNameFilterType.None) + { + // Check each algorithm + if ((Arguments.ServerNameFilterType == ServerNameFilterType.Equals && string.Compare(Arguments.ServerNameFilter, loginRequest.ServerName, true) != 0) + || (Arguments.ServerNameFilterType == ServerNameFilterType.StartsWith && !loginRequest.ServerName.StartsWith(Arguments.ServerNameFilter)) + || (Arguments.ServerNameFilterType == ServerNameFilterType.EndsWith && !loginRequest.ServerName.EndsWith(Arguments.ServerNameFilter)) + || (Arguments.ServerNameFilterType == ServerNameFilterType.Contains && !loginRequest.ServerName.Contains(Arguments.ServerNameFilter))) + { + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + + // Prepare ERROR token with the reason + TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, string.Format("Received server name \"{0}\", expected \"{1}\" using \"{2}\" algorithm", loginRequest.ServerName, Arguments.ServerNameFilter, Arguments.ServerNameFilterType), Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the errorToken token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + + // Prepare ERROR token with the final errorToken + errorToken = new TDSErrorToken(18456, 1, 14, "Connection is denied by server name filter", Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the errorToken token into the response packet + responseMessage.Add(errorToken); + + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); + + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); + + // Return only a single message with the collection + return new TDSMessageCollection(responseMessage); + } + } + + // Check if packet size filter is applied + if (Arguments.PacketSizeFilter != null) + { + // Check if requested packet size is the same as the filter specified + if (loginRequest.PacketSize != Arguments.PacketSizeFilter.Value) + { + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + + // Prepare ERROR token with the reason + TDSErrorToken errorToken = new TDSErrorToken(1919, 1, 14, string.Format("Received packet size \"{0}\", expected \"{1}\"", loginRequest.PacketSize, Arguments.PacketSizeFilter.Value), Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the errorToken token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + + // Prepare ERROR token with the final errorToken + errorToken = new TDSErrorToken(1919, 1, 14, "Connection is denied by packet size filter", Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the errorToken token into the response packet + responseMessage.Add(errorToken); + + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); + + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); + + // Return only a single message with the collection + return new TDSMessageCollection(responseMessage); + } + } + + // If we have an application name filter + if (Arguments.ApplicationNameFilter != null) + { + // If we are supposed to block this connection attempt + if (loginRequest.ApplicationName.Equals(Arguments.ApplicationNameFilter, System.StringComparison.OrdinalIgnoreCase)) + { + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + + // Prepare ERROR token with the denial details + TDSErrorToken errorToken = new TDSErrorToken(18456, 1, 14, "Received application name: " + loginRequest.ApplicationName, Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the error token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + + // Prepare ERROR token for the final decision + errorToken = new TDSErrorToken(18456, 1, 14, "Connection is denied by application name filter", Arguments.ServerName); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the error token into the response packet + responseMessage.Add(errorToken); + + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); + + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); + + // Put a single message into the collection and return it + return new TDSMessageCollection(responseMessage); + } + } + + // Return login response from the base class + return base.OnLogin7Request(session, request); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServerArguments.cs similarity index 58% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServerArguments.cs index dcb812a648..d9f36a5a49 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServerArguments.cs @@ -7,43 +7,31 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Arguments for authenticating TDS Server /// - public class AuthenticatingTDSServerArguments : TDSServerArguments + public class AuthenticatingTdsServerArguments : TdsServerArguments { /// /// Type of the application intent filter /// - public ApplicationIntentFilterType ApplicationIntentFilter { get; set; } + public ApplicationIntentFilterType ApplicationIntentFilter { get; set; } = ApplicationIntentFilterType.All; /// /// Filter for server name /// - public string ServerNameFilter { get; set; } + public string ServerNameFilter { get; set; } = string.Empty; /// /// Type of the filtering algorithm to use /// - public ServerNameFilterType ServerNameFilterType { get; set; } + public ServerNameFilterType ServerNameFilterType { get; set; } = ServerNameFilterType.None; /// /// TDS packet size filtering /// - public ushort? PacketSizeFilter { get; set; } + public ushort? PacketSizeFilter { get; set; } = null; /// /// Filter for application name /// - public string ApplicationNameFilter { get; set; } - - /// - /// Initialization constructor - /// - public AuthenticatingTDSServerArguments() - { - // Allow everyone to connect - ApplicationIntentFilter = ApplicationIntentFilterType.All; - - // By default we don't turn on server name filter - ServerNameFilterType = Servers.ServerNameFilterType.None; - } + public string ApplicationNameFilter { get; set; } = string.Empty; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServer.cs deleted file mode 100644 index 40d4791f13..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServer.cs +++ /dev/null @@ -1,148 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System.Linq; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.FeatureExtAck; -using Microsoft.SqlServer.TDS.PreLogin; - -namespace Microsoft.SqlServer.TDS.Servers -{ - /// - /// TDS Server that generates invalid TDS scenarios according to the requested parameters - /// - public class FederatedAuthenticationNegativeTDSServer : GenericTDSServer - { - /// - /// Initialization constructor - /// - public FederatedAuthenticationNegativeTDSServer() : - this(new FederatedAuthenticationNegativeTDSServerArguments()) - { - } - - /// - /// Initialization constructor - /// - public FederatedAuthenticationNegativeTDSServer(FederatedAuthenticationNegativeTDSServerArguments arguments) : - base(arguments) - { - } - - /// - /// Handler for login request - /// - public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, TDSMessage request) - { - // Get the collection from a valid On PreLogin Request - TDSMessageCollection preLoginCollection = base.OnPreLoginRequest(session, request); - - // Check if arguments are of the Federated Authentication server - if (Arguments is FederatedAuthenticationNegativeTDSServerArguments) - { - // Cast to federated authentication server arguments - FederatedAuthenticationNegativeTDSServerArguments ServerArguments = Arguments as FederatedAuthenticationNegativeTDSServerArguments; - - // Find the is token carrying on TDSPreLoginToken - TDSPreLoginToken preLoginToken = preLoginCollection.Find(message => message.Exists(packetToken => packetToken is TDSPreLoginToken)). - Find(packetToken => packetToken is TDSPreLoginToken) as TDSPreLoginToken; - - switch (ServerArguments.Scenario) - { - case FederatedAuthenticationNegativeTDSScenarioType.NonceMissingInFedAuthPreLogin: - { - // If we have the prelogin token - if (preLoginToken != null && preLoginToken.Nonce != null) - { - // Nullify the nonce from the Token - preLoginToken.Nonce = null; - } - - break; - } - - case FederatedAuthenticationNegativeTDSScenarioType.InvalidB_FEDAUTHREQUIREDResponse: - { - // If we have the prelogin token - if (preLoginToken != null) - { - // Set an illegal value for B_FEDAUTHREQURED - preLoginToken.FedAuthRequired = TdsPreLoginFedAuthRequiredOption.Illegal; - } - - break; - } - } - } - - // Return the collection - return preLoginCollection; - } - - /// - /// Handler for login request - /// - public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) - { - // Get the collection from the normal behavior On Login7 Request - TDSMessageCollection login7Collection = base.OnLogin7Request(session, request); - - // Check if arguments are of the Federated Authentication server - if (Arguments is FederatedAuthenticationNegativeTDSServerArguments) - { - // Cast to federated authentication server arguments - FederatedAuthenticationNegativeTDSServerArguments ServerArguments = Arguments as FederatedAuthenticationNegativeTDSServerArguments; - - // Get the Federated Authentication ExtAck from Login 7 - TDSFeatureExtAckFederatedAuthenticationOption fedAutExtAct = GetFeatureExtAckFederatedAuthenticationOptionFromLogin7(login7Collection); - - // If not found, return the base collection intact - if (fedAutExtAct != null) - { - switch (ServerArguments.Scenario) - { - case FederatedAuthenticationNegativeTDSScenarioType.NonceMissingInFedAuthFEATUREXTACK: - { - // Delete the nonce from the Token - fedAutExtAct.ClientNonce = null; - - break; - } - case FederatedAuthenticationNegativeTDSScenarioType.FedAuthMissingInFEATUREEXTACK: - { - // Remove the Fed Auth Ext Ack from the options list in the FeatureExtAckToken - GetFeatureExtAckTokenFromLogin7(login7Collection).Options.Remove(fedAutExtAct); - - break; - } - case FederatedAuthenticationNegativeTDSScenarioType.SignatureMissingInFedAuthFEATUREXTACK: - { - // Delete the signature from the Token - fedAutExtAct.Signature = null; - - break; - } - } - } - } - - // Return the collection - return login7Collection; - } - - private TDSFeatureExtAckToken GetFeatureExtAckTokenFromLogin7(TDSMessageCollection login7Collection) - { - // Find the is token carrying on TDSFeatureExtAckToken - return login7Collection.Find(m => m.Exists(p => p is TDSFeatureExtAckToken)). - Find(t => t is TDSFeatureExtAckToken) as TDSFeatureExtAckToken; - } - - private TDSFeatureExtAckFederatedAuthenticationOption GetFeatureExtAckFederatedAuthenticationOptionFromLogin7(TDSMessageCollection login7Collection) - { - // Get the Fed Auth Ext Ack from the list of options in the feature ExtAck - return GetFeatureExtAckTokenFromLogin7(login7Collection).Options. - Where(o => o is TDSFeatureExtAckFederatedAuthenticationOption).FirstOrDefault() as TDSFeatureExtAckFederatedAuthenticationOption; - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSScenarioType.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsScenarioType.cs similarity index 94% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSScenarioType.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsScenarioType.cs index 11baa170d6..f35f69c22d 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSScenarioType.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsScenarioType.cs @@ -4,7 +4,7 @@ namespace Microsoft.SqlServer.TDS.Servers { - public enum FederatedAuthenticationNegativeTDSScenarioType : int + public enum FederatedAuthenticationNegativeTdsScenarioType : int { /// /// Valid Scenario. Do not perform negative activity. diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServer.cs new file mode 100644 index 0000000000..53a68c70ea --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServer.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Linq; +using Microsoft.SqlServer.TDS.EndPoint; +using Microsoft.SqlServer.TDS.FeatureExtAck; +using Microsoft.SqlServer.TDS.PreLogin; + +namespace Microsoft.SqlServer.TDS.Servers +{ + /// + /// TDS Server that generates invalid TDS scenarios according to the requested parameters + /// + public class FederatedAuthenticationNegativeTdsServer : GenericTdsServer + { + /// + /// Initialization constructor + /// + public FederatedAuthenticationNegativeTdsServer() : + this(new FederatedAuthenticationNegativeTdsServerArguments()) + { + } + + /// + /// Initialization constructor + /// + public FederatedAuthenticationNegativeTdsServer(FederatedAuthenticationNegativeTdsServerArguments arguments) : + base(arguments) + { + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, TDSMessage request) + { + // Get the collection from a valid On PreLogin Request + TDSMessageCollection preLoginCollection = base.OnPreLoginRequest(session, request); + + // Find the is token carrying on TDSPreLoginToken + TDSPreLoginToken preLoginToken = preLoginCollection.Find(message => message.Exists(packetToken => packetToken is TDSPreLoginToken)). + Find(packetToken => packetToken is TDSPreLoginToken) as TDSPreLoginToken; + + switch (Arguments.Scenario) + { + case FederatedAuthenticationNegativeTdsScenarioType.NonceMissingInFedAuthPreLogin: + { + // If we have the prelogin token + if (preLoginToken != null && preLoginToken.Nonce != null) + { + // Nullify the nonce from the Token + preLoginToken.Nonce = null; + } + + break; + } + + case FederatedAuthenticationNegativeTdsScenarioType.InvalidB_FEDAUTHREQUIREDResponse: + { + // If we have the prelogin token + if (preLoginToken != null) + { + // Set an illegal value for B_FEDAUTHREQURED + preLoginToken.FedAuthRequired = TdsPreLoginFedAuthRequiredOption.Illegal; + } + + break; + } + } + + // Return the collection + return preLoginCollection; + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) + { + // Get the collection from the normal behavior On Login7 Request + TDSMessageCollection login7Collection = base.OnLogin7Request(session, request); + + // Get the Federated Authentication ExtAck from Login 7 + TDSFeatureExtAckFederatedAuthenticationOption fedAutExtAct = GetFeatureExtAckFederatedAuthenticationOptionFromLogin7(login7Collection); + + // If not found, return the base collection intact + if (fedAutExtAct != null) + { + switch (Arguments.Scenario) + { + case FederatedAuthenticationNegativeTdsScenarioType.NonceMissingInFedAuthFEATUREXTACK: + { + // Delete the nonce from the Token + fedAutExtAct.ClientNonce = null; + + break; + } + case FederatedAuthenticationNegativeTdsScenarioType.FedAuthMissingInFEATUREEXTACK: + { + // Remove the Fed Auth Ext Ack from the options list in the FeatureExtAckToken + GetFeatureExtAckTokenFromLogin7(login7Collection).Options.Remove(fedAutExtAct); + + break; + } + case FederatedAuthenticationNegativeTdsScenarioType.SignatureMissingInFedAuthFEATUREXTACK: + { + // Delete the signature from the Token + fedAutExtAct.Signature = null; + + break; + } + } + } + + // Return the collection + return login7Collection; + } + + private TDSFeatureExtAckToken GetFeatureExtAckTokenFromLogin7(TDSMessageCollection login7Collection) + { + // Find the is token carrying on TDSFeatureExtAckToken + return login7Collection.Find(m => m.Exists(p => p is TDSFeatureExtAckToken)). + Find(t => t is TDSFeatureExtAckToken) as TDSFeatureExtAckToken; + } + + private TDSFeatureExtAckFederatedAuthenticationOption GetFeatureExtAckFederatedAuthenticationOptionFromLogin7(TDSMessageCollection login7Collection) + { + // Get the Fed Auth Ext Ack from the list of options in the feature ExtAck + return GetFeatureExtAckTokenFromLogin7(login7Collection).Options. + Where(o => o is TDSFeatureExtAckFederatedAuthenticationOption).FirstOrDefault() as TDSFeatureExtAckFederatedAuthenticationOption; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServerArguments.cs similarity index 56% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServerArguments.cs index 67143d645b..19fd43aab5 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServerArguments.cs @@ -7,18 +7,11 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Arguments for Fed Auth Negative TDS Server /// - public class FederatedAuthenticationNegativeTDSServerArguments : TDSServerArguments + public class FederatedAuthenticationNegativeTdsServerArguments : TdsServerArguments { /// /// Type of the Fed Auth Negative TDS Server /// - public FederatedAuthenticationNegativeTDSScenarioType Scenario { get; set; } - - /// - /// Initialization constructor - /// - public FederatedAuthenticationNegativeTDSServerArguments() - { - } + public FederatedAuthenticationNegativeTdsScenarioType Scenario { get; set; } = FederatedAuthenticationNegativeTdsScenarioType.ValidScenario; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs similarity index 92% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs index ac04fd2f57..c17726fb8b 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs @@ -4,6 +4,8 @@ using System; using System.Linq; +using System.Net; +using System.Runtime.CompilerServices; using System.Security.Cryptography; using System.Threading; using Microsoft.SqlServer.TDS.Authentication; @@ -25,7 +27,8 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Generic TDS server without specialization /// - public class GenericTDSServer : ITDSServer + public abstract class GenericTdsServer : ITDSServer, IDisposable + where T : TdsServerArguments { /// /// Delegate to be called when a LOGIN7 request has been received and is @@ -34,7 +37,6 @@ public class GenericTDSServer : ITDSServer /// public delegate void OnLogin7ValidatedDelegate( TDSLogin7Token login7Token); - public OnLogin7ValidatedDelegate OnLogin7Validated { private get; set; } /// /// Delegate to be called when authentication is completed and TDSResponse @@ -42,23 +44,12 @@ public delegate void OnLogin7ValidatedDelegate( /// public delegate void OnAuthenticationCompletedDelegate( TDSMessage response); - public OnAuthenticationCompletedDelegate OnAuthenticationResponseCompleted { private get; set; } /// /// Default feature extension version supported on the server for vector support. /// public const byte DefaultSupportedVectorFeatureExtVersion = 0x01; - /// - /// Property for setting server version for vector feature extension. - /// - public bool EnableVectorFeatureExt { get; set; } = false; - - /// - /// Property for setting server version for vector feature extension. - /// - public byte ServerSupportedVectorFeatureExtVersion { get; set; } = DefaultSupportedVectorFeatureExtVersion; - /// /// Client version for vector FeatureExtension. /// @@ -70,27 +61,16 @@ public delegate void OnAuthenticationCompletedDelegate( private int _sessionCount = 0; /// - /// Server configuration + /// Counts pre-login requests to the server. /// - protected TDSServerArguments Arguments { get; set; } + private int _preLoginCount = 0; - /// - /// Query engine instance - /// - protected QueryEngine Engine { get; set; } - - /// - /// Default constructor - /// - public GenericTDSServer() : - this(new TDSServerArguments()) - { - } + private TDSServerEndPoint _endpoint; /// /// Initialization constructor /// - public GenericTDSServer(TDSServerArguments arguments) : + public GenericTdsServer(T arguments) : this(arguments, new QueryEngine(arguments)) { } @@ -98,7 +78,7 @@ public GenericTDSServer(TDSServerArguments arguments) : /// /// Initialization constructor /// - public GenericTDSServer(TDSServerArguments arguments, QueryEngine queryEngine) + public GenericTdsServer(T arguments, QueryEngine queryEngine) { // Save arguments Arguments = arguments; @@ -110,6 +90,50 @@ public GenericTDSServer(TDSServerArguments arguments, QueryEngine queryEngine) Engine.Log = Arguments.Log; } + public IPEndPoint EndPoint => _endpoint.ServerEndPoint; + + /// + /// Server configuration + /// + protected T Arguments { get; set; } + + /// + /// Query engine instance + /// + protected QueryEngine Engine { get; set; } + + /// + /// Counts pre-login requests to the server. + /// + public int PreLoginCount => _preLoginCount; + + /// + /// Property for setting server version for vector feature extension. + /// + public bool EnableVectorFeatureExt { get; set; } = false; + + /// + /// Property for setting server version for vector feature extension. + /// + public byte ServerSupportedVectorFeatureExtVersion { get; set; } = DefaultSupportedVectorFeatureExtVersion; + + public OnAuthenticationCompletedDelegate OnAuthenticationResponseCompleted { private get; set; } + + public OnLogin7ValidatedDelegate OnLogin7Validated { private get; set; } + + + public void Start([CallerMemberName] string methodName = "") + { + if (_endpoint != null) + { + throw new InvalidOperationException("Server is already started"); + } + _endpoint = new TDSServerEndPoint(this) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; + _endpoint.EndpointName = methodName; + _endpoint.EventLog = Arguments.Log; + _endpoint.Start(); + } + /// /// Create a new session on the server /// @@ -120,7 +144,7 @@ public virtual ITDSServerSession OpenSession() Interlocked.Increment(ref _sessionCount); // Create a new session - GenericTDSServerSession session = new GenericTDSServerSession(this, (uint)_sessionCount); + GenericTdsServerSession session = new GenericTdsServerSession(this, (uint)_sessionCount); // Use configured encryption certificate and protocols session.EncryptionCertificate = Arguments.EncryptionCertificate; @@ -142,8 +166,11 @@ public virtual void CloseSession(ITDSServerSession session) /// public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, TDSMessage request) { + Interlocked.Increment(ref _preLoginCount); + // Inflate pre-login request from the message TDSPreLoginToken preLoginRequest = request[0] as TDSPreLoginToken; + GenericTdsServerSession genericTdsServerSession = session as GenericTdsServerSession; // Log request TDSUtilities.Log(Arguments.Log, "Request", preLoginRequest); @@ -158,7 +185,7 @@ public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, TDSPreLoginToken preLoginToken = new TDSPreLoginToken(Arguments.ServerVersion, serverResponse, false); // TDS server doesn't support MARS // Cache the received Nonce into the session - (session as GenericTDSServerSession).ClientNonce = preLoginRequest.Nonce; + genericTdsServerSession.ClientNonce = preLoginRequest.Nonce; // Check if the server has been started up as requiring FedAuth when choosing between SSPI and FedAuth if (Arguments.FedAuthRequiredPreLoginOption == TdsPreLoginFedAuthRequiredOption.FedAuthRequired) @@ -170,7 +197,7 @@ public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, } // Keep the federated authentication required flag in the server session - (session as GenericTDSServerSession).FedAuthRequiredPreLoginServerResponse = preLoginToken.FedAuthRequired; + genericTdsServerSession.FedAuthRequiredPreLoginServerResponse = preLoginToken.FedAuthRequired; if (preLoginRequest.Nonce != null) { @@ -180,7 +207,7 @@ public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, } // Cache the server Nonce in a session - (session as GenericTDSServerSession).ServerNonce = preLoginToken.Nonce; + genericTdsServerSession.ServerNonce = preLoginToken.Nonce; // Log response TDSUtilities.Log(Arguments.Log, "Response", preLoginToken); @@ -244,7 +271,7 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T TDSLogin7SessionRecoveryOptionToken sessionStateOption = option as TDSLogin7SessionRecoveryOptionToken; // Inflate session state - (session as GenericTDSServerSession).Inflate(sessionStateOption.Initial, sessionStateOption.Current); + (session as GenericTdsServerSession).Inflate(sessionStateOption.Initial, sessionStateOption.Current); break; } @@ -266,7 +293,7 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T } // Save the fed auth library to be used - (session as GenericTDSServerSession).FederatedAuthenticationLibrary = federatedAuthenticationOption.Library; + (session as GenericTdsServerSession).FederatedAuthenticationLibrary = federatedAuthenticationOption.Library; break; } @@ -542,7 +569,7 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi responseMessage.Add(infoToken); // Create new collation change token - envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.SQLCollation, (session as GenericTDSServerSession).Collation); + envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.SQLCollation, (session as GenericTdsServerSession).Collation); // Log response TDSUtilities.Log(Arguments.Log, "Response", envChange); @@ -551,7 +578,7 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi responseMessage.Add(envChange); // Create new language change token - envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.Language, LanguageString.ToString((session as GenericTDSServerSession).Language)); + envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.Language, LanguageString.ToString((session as GenericTdsServerSession).Language)); // Log response TDSUtilities.Log(Arguments.Log, "Response", envChange); @@ -593,7 +620,7 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi if (session.IsSessionRecoveryEnabled) { // Create Feature extension Ack token - TDSFeatureExtAckToken featureExtActToken = new TDSFeatureExtAckToken(new TDSFeatureExtAckSessionStateOption((session as GenericTDSServerSession).Deflate())); + TDSFeatureExtAckToken featureExtActToken = new TDSFeatureExtAckToken(new TDSFeatureExtAckSessionStateOption((session as GenericTdsServerSession).Deflate())); // Log response TDSUtilities.Log(Arguments.Log, "Response", featureExtActToken); @@ -654,6 +681,16 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi } } + if (!string.IsNullOrEmpty(Arguments.FailoverPartner)) + { + envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.RealTimeLogShipping, Arguments.FailoverPartner); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", envChange); + + responseMessage.Add(envChange); + } + // Create DONE token TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final); @@ -688,7 +725,7 @@ protected virtual TDSMessageCollection OnFederatedAuthenticationCompleted(ITDSSe try { // Get the Federated Authentication ticket using RPS - decryptedTicket = FederatedAuthenticationTicketService.DecryptTicket((session as GenericTDSServerSession).FederatedAuthenticationLibrary, ticket); + decryptedTicket = FederatedAuthenticationTicketService.DecryptTicket((session as GenericTdsServerSession).FederatedAuthenticationLibrary, ticket); if (decryptedTicket is RpsTicket) { @@ -719,17 +756,17 @@ protected virtual TDSMessageCollection OnFederatedAuthenticationCompleted(ITDSSe // Create federated authentication extension option TDSFeatureExtAckFederatedAuthenticationOption federatedAuthenticationOption; - if ((session as GenericTDSServerSession).FederatedAuthenticationLibrary == TDSFedAuthLibraryType.MSAL) + if ((session as GenericTdsServerSession).FederatedAuthenticationLibrary == TDSFedAuthLibraryType.MSAL) { // For the time being, fake fedauth tokens are used for ADAL, so decryptedTicket is null. federatedAuthenticationOption = - new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTDSServerSession).ClientNonce, null); + new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTdsServerSession).ClientNonce, null); } else { federatedAuthenticationOption = - new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTDSServerSession).ClientNonce, - decryptedTicket.GetSignature((session as GenericTDSServerSession).ClientNonce)); + new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTdsServerSession).ClientNonce, + decryptedTicket.GetSignature((session as GenericTdsServerSession).ClientNonce)); } // Look for feature extension token @@ -764,12 +801,12 @@ protected virtual TDSMessageCollection OnFederatedAuthenticationCompleted(ITDSSe protected virtual TDSMessageCollection CheckFederatedAuthenticationOption(ITDSServerSession session, TDSLogin7FedAuthOptionToken federatedAuthenticationOption) { // Check if server's prelogin response for FedAuthRequired prelogin option is echoed back correctly in FedAuth Feature Extenion Echo - if (federatedAuthenticationOption.Echo != (session as GenericTDSServerSession).FedAuthRequiredPreLoginServerResponse) + if (federatedAuthenticationOption.Echo != (session as GenericTdsServerSession).FedAuthRequiredPreLoginServerResponse) { // Create Error message string message = string.Format("FEDAUTHREQUIRED option in the prelogin response is not echoed back correctly: in prelogin response, it is {0} and in login, it is {1}: ", - (session as GenericTDSServerSession).FedAuthRequiredPreLoginServerResponse, + (session as GenericTdsServerSession).FedAuthRequiredPreLoginServerResponse, federatedAuthenticationOption.Echo); // Create errorToken token @@ -790,7 +827,7 @@ protected virtual TDSMessageCollection CheckFederatedAuthenticationOption(ITDSSe // Check if the nonce exists if ((federatedAuthenticationOption.Nonce == null && federatedAuthenticationOption.Library == TDSFedAuthLibraryType.IDCRL) - || !AreEqual((session as GenericTDSServerSession).ServerNonce, federatedAuthenticationOption.Nonce)) + || !AreEqual((session as GenericTdsServerSession).ServerNonce, federatedAuthenticationOption.Nonce)) { // Error message string message = string.Format("Unexpected NONCEOPT specified in the Federated authentication feature extension"); @@ -880,5 +917,11 @@ private bool AreEqual(byte[] left, byte[] right) return left.SequenceEqual(right); } + + public virtual void Dispose() + { + _endpoint?.Dispose(); + _endpoint = null; + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs similarity index 99% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs index e9e65d5f8f..2730fa02df 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs @@ -17,7 +17,7 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Generic session for TDS Server /// - public class GenericTDSServerSession : ITDSServerSession + public class GenericTdsServerSession : ITDSServerSession { /// /// Server that created the session @@ -259,7 +259,7 @@ public bool AnsiDefaults /// /// Initialization constructor /// - public GenericTDSServerSession(ITDSServer server, uint sessionID) : + public GenericTdsServerSession(ITDSServer server, uint sessionID) : this(server, sessionID, 4096) { } @@ -267,7 +267,7 @@ public GenericTDSServerSession(ITDSServer server, uint sessionID) : /// /// Initialization constructor /// - public GenericTDSServerSession(ITDSServer server, uint sessionID, uint packetSize) + public GenericTdsServerSession(ITDSServer server, uint sessionID, uint packetSize) { // Save the server Server = server; diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs index eb219f5dbc..579c47abcc 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs @@ -26,12 +26,12 @@ public class QueryEngine /// /// Server configuration /// - public TDSServerArguments ServerArguments { get; private set; } + public TdsServerArguments ServerArguments { get; private set; } /// /// Initialization constructor /// - public QueryEngine(TDSServerArguments arguments) + public QueryEngine(TdsServerArguments arguments) { ServerArguments = arguments; } @@ -1308,7 +1308,7 @@ private TDSMessage _PrepareAnsiDefaultsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiDefaults); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiDefaults); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1347,7 +1347,7 @@ private TDSMessage _PrepareAnsiNullDefaultOnResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiNullDefaultOn); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiNullDefaultOn); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1386,7 +1386,7 @@ private TDSMessage _PrepareAnsiNullsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiNulls); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiNulls); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1425,7 +1425,7 @@ private TDSMessage _PrepareAnsiPaddingResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiPadding); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiPadding); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1464,7 +1464,7 @@ private TDSMessage _PrepareAnsiWarningsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiWarnings); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiWarnings); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1503,7 +1503,7 @@ private TDSMessage _PrepareArithAbortResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).ArithAbort); + rowToken.Data.Add((session as GenericTdsServerSession).ArithAbort); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1542,7 +1542,7 @@ private TDSMessage _PrepareConcatNullYieldsNullResponse(ITDSServerSession sessio TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).ConcatNullYieldsNull); + rowToken.Data.Add((session as GenericTdsServerSession).ConcatNullYieldsNull); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1581,7 +1581,7 @@ private TDSMessage _PrepareDateFirstResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((short)(session as GenericTDSServerSession).DateFirst); + rowToken.Data.Add((short)(session as GenericTdsServerSession).DateFirst); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1622,7 +1622,7 @@ private TDSMessage _PrepareDateFormatResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Generate a date format string - rowToken.Data.Add(DateFormatString.ToString((session as GenericTDSServerSession).DateFormat)); + rowToken.Data.Add(DateFormatString.ToString((session as GenericTdsServerSession).DateFormat)); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1661,7 +1661,7 @@ private TDSMessage _PrepareDeadlockPriorityResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Serialize the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).DeadlockPriority); + rowToken.Data.Add((session as GenericTdsServerSession).DeadlockPriority); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1702,7 +1702,7 @@ private TDSMessage _PrepareLanguageResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Generate a date format string - rowToken.Data.Add(LanguageString.ToString((session as GenericTDSServerSession).Language)); + rowToken.Data.Add(LanguageString.ToString((session as GenericTdsServerSession).Language)); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1741,7 +1741,7 @@ private TDSMessage _PrepareLockTimeoutResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Serialize the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).LockTimeout); + rowToken.Data.Add((session as GenericTdsServerSession).LockTimeout); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1780,7 +1780,7 @@ private TDSMessage _PrepareQuotedIdentifierResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).QuotedIdentifier); + rowToken.Data.Add((session as GenericTdsServerSession).QuotedIdentifier); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1819,7 +1819,7 @@ private TDSMessage _PrepareTextSizeResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).TextSize); + rowToken.Data.Add((session as GenericTdsServerSession).TextSize); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1858,7 +1858,7 @@ private TDSMessage _PrepareTransactionIsolationLevelResponse(ITDSServerSession s TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((short)(session as GenericTDSServerSession).TransactionIsolationLevel); + rowToken.Data.Add((short)(session as GenericTdsServerSession).TransactionIsolationLevel); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1897,7 +1897,7 @@ private TDSMessage _PrepareOptionsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Convert to generic session - GenericTDSServerSession genericSession = session as GenericTDSServerSession; + GenericTdsServerSession genericSession = session as GenericTdsServerSession; // Serialize the options into the bit mask int options = 0; @@ -2029,13 +2029,13 @@ private TDSMessage _PrepareContextInfoResponse(ITDSServerSession session) byte[] contextInfo = null; // Check if session has a context info - if ((session as GenericTDSServerSession).ContextInfo != null) + if ((session as GenericTdsServerSession).ContextInfo != null) { // Allocate a container contextInfo = new byte[128]; // Copy context info into the container - Array.Copy((session as GenericTDSServerSession).ContextInfo, contextInfo, (session as GenericTDSServerSession).ContextInfo.Length); + Array.Copy((session as GenericTdsServerSession).ContextInfo, contextInfo, (session as GenericTdsServerSession).ContextInfo.Length); } // Set context info diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs similarity index 91% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServer.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs index 57596b24ac..8e119a54cd 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs @@ -16,20 +16,20 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// TDS Server that routes clients to the configured destination /// - public class RoutingTDSServer : GenericTDSServer + public class RoutingTdsServer : GenericTdsServer { /// /// Initialization constructor /// - public RoutingTDSServer() : - this(new RoutingTDSServerArguments()) + public RoutingTdsServer() : + this(new RoutingTdsServerArguments()) { } /// /// Initialization constructor /// - public RoutingTDSServer(RoutingTDSServerArguments arguments) : + public RoutingTdsServer(RoutingTdsServerArguments arguments) : base(arguments) { } @@ -43,10 +43,10 @@ public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session TDSMessageCollection response = base.OnPreLoginRequest(session, request); // Check if arguments are of the routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments serverArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments serverArguments = Arguments as RoutingTdsServerArguments; // Check if routing is configured during login if (serverArguments.RouteOnPacket == TDSMessageType.TDS7Login) @@ -78,10 +78,10 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; // Check if arguments are of the routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments ServerArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; // Check filter if (ServerArguments.RequireReadOnly && (loginRequest.TypeFlags.ReadOnlyIntent != TDSLogin7TypeFlagsReadOnlyIntent.ReadOnly)) @@ -136,10 +136,10 @@ public override TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session TDSMessageCollection batchResponse = base.OnSQLBatchRequest(session, request); // Check if arguments are of routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments ServerArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; // Check routing condition if (ServerArguments.RouteOnPacket == TDSMessageType.SQLBatch) @@ -188,10 +188,10 @@ protected override TDSMessageCollection OnAuthenticationCompleted(ITDSServerSess TDSMessageCollection responseMessageCollection = base.OnAuthenticationCompleted(session); // Check if arguments are of routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments serverArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments serverArguments = Arguments as RoutingTdsServerArguments; // Check routing condition if (serverArguments.RouteOnPacket == TDSMessageType.TDS7Login) @@ -233,7 +233,7 @@ protected override TDSMessageCollection OnAuthenticationCompleted(ITDSServerSess protected TDSPacketToken CreateRoutingToken() { // Cast to routing server arguments - RoutingTDSServerArguments ServerArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; // Construct routing token value TDSRoutingEnvChangeTokenValue routingInfo = new TDSRoutingEnvChangeTokenValue(); diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs similarity index 51% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs index 99cbd3baae..95fe97f4f2 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs @@ -7,43 +7,31 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Arguments for routing TDS Server /// - public class RoutingTDSServerArguments : TDSServerArguments + public class RoutingTdsServerArguments : TdsServerArguments { /// - /// Routing destination protocol + /// Routing destination protocol. /// - public int RoutingProtocol { get; set; } + public int RoutingProtocol { get; set; } = 0; /// /// Routing TCP port /// - public ushort RoutingTCPPort { get; set; } + public ushort RoutingTCPPort { get; set; } = 0; /// /// Routing TCP host name /// - public string RoutingTCPHost { get; set; } + public string RoutingTCPHost { get; set; } = string.Empty; /// /// Packet on which routing should occur /// - public TDSMessageType RouteOnPacket { get; set; } + public TDSMessageType RouteOnPacket { get; set; } = TDSMessageType.TDS7Login; /// /// Indicates that routing should only occur on read-only connections /// - public bool RequireReadOnly { get; set; } - - /// - /// Initialization constructor - /// - public RoutingTDSServerArguments() - { - // By default we route on login - RouteOnPacket = TDSMessageType.TDS7Login; - - // By default we reject non-read-only connections - RequireReadOnly = true; - } + public bool RequireReadOnly { get; set; } = true; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj index b7757b257b..c689554310 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj @@ -11,21 +11,24 @@ - - + + - - - - - + + + + + - - + + - - - + + + + + + diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServer.cs new file mode 100644 index 0000000000..d3bb1861ef --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServer.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.SqlServer.TDS.Servers +{ + public class TdsServer : GenericTdsServer + { + /// + /// Default constructor + /// + public TdsServer() : this(new TdsServerArguments()) + { + } + + /// + /// Constructor with arguments + /// + public TdsServer(TdsServerArguments arguments) : base(arguments) + { + } + + /// + /// Constructor with arguments and query engine + /// + /// Query engine + /// Server arguments + public TdsServer(QueryEngine queryEngine, TdsServerArguments arguments) : base(arguments, queryEngine) + { + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs similarity index 63% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs index 88e577ab68..7ceb2e0272 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs @@ -13,7 +13,7 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Common arguments for TDS Server /// - public class TDSServerArguments + public class TdsServerArguments { /// /// Service Principal Name, representing Azure SQL Database in Azure Active Directory. @@ -28,76 +28,56 @@ public class TDSServerArguments /// /// Log to which send TDS conversation /// - public TextWriter Log { get; set; } + public TextWriter Log { get; set; } = null; /// /// Server name /// - public string ServerName { get; set; } + public string ServerName { get; set; } = Environment.MachineName; /// /// Server version /// - public Version ServerVersion { get; set; } + public Version ServerVersion { get; set; } = new Version(11, 0, 1083); /// /// Server principal name /// - public string ServerPrincipalName { get; set; } + public string ServerPrincipalName { get; set; } = AzureADServicePrincipalName; /// /// Sts Url /// - public string StsUrl { get; set; } + public string StsUrl { get; set; } = AzureADProductionTokenEndpoint; /// /// Size of the TDS packet server should operate with /// - public int PacketSize { get; set; } + public int PacketSize { get; set; } = 4096; /// /// Transport encryption /// - public TDSPreLoginTokenEncryptionType Encryption { get; set; } + public TDSPreLoginTokenEncryptionType Encryption { get; set; } = TDSPreLoginTokenEncryptionType.NotSupported; /// /// Specifies the FedAuthRequired option /// - public TdsPreLoginFedAuthRequiredOption FedAuthRequiredPreLoginOption { get; set; } + public TdsPreLoginFedAuthRequiredOption FedAuthRequiredPreLoginOption { get; set; } = TdsPreLoginFedAuthRequiredOption.FedAuthNotRequired; /// /// Certificate to use for transport encryption /// - public X509Certificate EncryptionCertificate { get; set; } + public X509Certificate EncryptionCertificate { get; set; } = null; /// /// SSL/TLS protocols to use for transport encryption /// - public SslProtocols EncryptionProtocols { get; set; } + public SslProtocols EncryptionProtocols { get; set; } = SslProtocols.Tls12; /// - /// Initialization constructor + /// Routing destination protocol /// - public TDSServerArguments() - { - // Assign default server version - ServerName = Environment.MachineName; - ServerVersion = new Version(11, 0, 1083); - - // Default packet size - PacketSize = 4096; - - // By default we don't support encryption - Encryption = TDSPreLoginTokenEncryptionType.NotSupported; - - // By Default SQL authentication will be used. - FedAuthRequiredPreLoginOption = TdsPreLoginFedAuthRequiredOption.FedAuthNotRequired; - - EncryptionCertificate = null; - EncryptionProtocols = SslProtocols.Tls12; - - ServerPrincipalName = AzureADServicePrincipalName; - StsUrl = AzureADProductionTokenEndpoint; - } + public string FailoverPartner { get; set; } = string.Empty; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServer.cs new file mode 100644 index 0000000000..d0a15e90ae --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServer.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using Microsoft.SqlServer.TDS.EndPoint; + +namespace Microsoft.SqlServer.TDS.Servers +{ + /// + /// TDS Server that delays response to simulate transient network delays + /// + public class TransientDelayTdsServer : GenericTdsServer, IDisposable + { + private int RequestCounter = 0; + + public TransientDelayTdsServer(TransientDelayTdsServerArguments arguments) + : base(arguments) + { + } + + public TransientDelayTdsServer(TransientDelayTdsServerArguments arguments, QueryEngine queryEngine) + : base(arguments, queryEngine) + { + } + + /// + public override void Dispose() + { + base.Dispose(); + RequestCounter = 0; + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) + { + // Check if we're still going to raise transient error + if (Arguments.IsEnabledPermanentDelay || + (Arguments.IsEnabledTransientDelay && RequestCounter < Arguments.RepeatCount)) + { + Thread.Sleep(Arguments.DelayDuration); + + RequestCounter++; + } + + // Return login response from the base class + return base.OnLogin7Request(session, request); + } + + /// + public override TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session, TDSMessage message) + { + if (Arguments.IsEnabledPermanentDelay || + (Arguments.IsEnabledTransientDelay && RequestCounter < Arguments.RepeatCount)) + { + Thread.Sleep(Arguments.DelayDuration); + + RequestCounter++; + } + + return base.OnSQLBatchRequest(session, message); + } + + public void ResetRequestCounter() + { + RequestCounter = 0; + } + + public void SetTransientTimeoutBehavior(bool isEnabledTransientTimeout, TimeSpan sleepDuration) + { + SetTransientTimeoutBehavior(isEnabledTransientTimeout, false, sleepDuration); + } + + public void SetTransientTimeoutBehavior(bool isEnabledTransientTimeout, bool isEnabledPermanentTimeout, TimeSpan sleepDuration) + { + Arguments.IsEnabledTransientDelay = isEnabledTransientTimeout; + Arguments.IsEnabledPermanentDelay = isEnabledPermanentTimeout; + Arguments.DelayDuration = sleepDuration; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServerArguments.cs new file mode 100644 index 0000000000..d89ee7dfdc --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServerArguments.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.SqlServer.TDS.Servers +{ + public class TransientDelayTdsServerArguments : TdsServerArguments + { + /// + /// The duration for which the server should sleep before responding to a request. + /// + public TimeSpan DelayDuration { get; set; } = TimeSpan.FromSeconds(0); + + /// + /// Flag to consider when simulating a delay on the next request. + /// + public bool IsEnabledTransientDelay { get; set; } = false; + + /// + /// Flag to consider when simulating a delay on each request. + /// + public bool IsEnabledPermanentDelay { get; set; } = false; + + /// + /// The number of logins during which the delay should be applied. + /// + public int RepeatCount { get; set; } = 1; + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServer.cs deleted file mode 100644 index 1933444df6..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServer.cs +++ /dev/null @@ -1,153 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Runtime.CompilerServices; -using System.Threading; -using Microsoft.SqlServer.TDS.Done; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Error; -using Microsoft.SqlServer.TDS.Login7; - -namespace Microsoft.SqlServer.TDS.Servers -{ - /// - /// TDS Server that authenticates clients according to the requested parameters - /// - public class TransientFaultTDSServer : GenericTDSServer, IDisposable - { - private static int RequestCounter = 0; - - public int Port { get; set; } - - /// - /// Constructor - /// - public TransientFaultTDSServer() => new TransientFaultTDSServer(new TransientFaultTDSServerArguments()); - - /// - /// Constructor - /// - /// - public TransientFaultTDSServer(TransientFaultTDSServerArguments arguments) : - base(arguments) - { } - - /// - /// Constructor - /// - /// - /// - public TransientFaultTDSServer(QueryEngine engine, TransientFaultTDSServerArguments args) : base(args) - { - Engine = engine; - } - - private TDSServerEndPoint _endpoint = null; - - private static string GetErrorMessage(uint errorNumber) - { - switch (errorNumber) - { - case 40613: - return "Database on server is not currently available. Please retry the connection later. " + - "If the problem persists, contact customer support, and provide them the session tracing ID."; - case 42108: - return "Can not connect to the SQL pool since it is paused. Please resume the SQL pool and try again."; - case 42109: - return "The SQL pool is warming up. Please try again."; - } - return "Unknown server error occurred"; - } - - /// - /// Handler for login request - /// - public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) - { - // Inflate login7 request from the message - TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; - - // Check if arguments are of the transient fault TDS server - if (Arguments is TransientFaultTDSServerArguments) - { - // Cast to transient fault TDS server arguments - TransientFaultTDSServerArguments ServerArguments = Arguments as TransientFaultTDSServerArguments; - - // Check if we're still going to raise transient error - if (ServerArguments.IsEnabledTransientError && RequestCounter < 1) // Fail first time, then connect - { - uint errorNumber = ServerArguments.Number; - string errorMessage = ServerArguments.Message; - - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - - // Prepare ERROR token with the denial details - TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the error token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); - - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - RequestCounter++; - - // Put a single message into the collection and return it - return new TDSMessageCollection(responseMessage); - } - } - - // Return login response from the base class - return base.OnLogin7Request(session, request); - } - - public static TransientFaultTDSServer StartTestServer(bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "") - => StartServerWithQueryEngine(null, isEnabledTransientFault, enableLog, errorNumber, methodName); - - public static TransientFaultTDSServer StartServerWithQueryEngine(QueryEngine engine, bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "") - { - TransientFaultTDSServerArguments args = new TransientFaultTDSServerArguments() - { - Log = enableLog ? Console.Out : null, - IsEnabledTransientError = isEnabledTransientFault, - Number = errorNumber, - Message = GetErrorMessage(errorNumber) - }; - - TransientFaultTDSServer server = engine == null ? new TransientFaultTDSServer(args) : new TransientFaultTDSServer(engine, args); - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - server.Port = server._endpoint.ServerEndPoint.Port; - return server; - } - - public void Dispose() => Dispose(true); - - private void Dispose(bool isDisposing) - { - if (isDisposing) - { - _endpoint?.Stop(); - RequestCounter = 0; - } - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs new file mode 100644 index 0000000000..e5d2e52100 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.SqlServer.TDS.Done; +using Microsoft.SqlServer.TDS.EndPoint; +using Microsoft.SqlServer.TDS.Error; +using Microsoft.SqlServer.TDS.Login7; + +namespace Microsoft.SqlServer.TDS.Servers +{ + /// + /// TDS Server that returns TDS error token on login request for the specified number of times + /// + public class TransientTdsErrorTdsServer : GenericTdsServer, IDisposable + { + private int RequestCounter = 0; + + public void SetErrorBehavior(bool isEnabledTransientError, uint errorNumber, int repeatCount = 1, string message = null) + { + Arguments.IsEnabledTransientError = isEnabledTransientError; + Arguments.Number = errorNumber; + Arguments.Message = message; + Arguments.RepeatCount = repeatCount; + } + + public TransientTdsErrorTdsServer(TransientTdsErrorTdsServerArguments arguments) : base(arguments) + { + } + + public TransientTdsErrorTdsServer(TransientTdsErrorTdsServerArguments arguments, QueryEngine queryEngine) : base(arguments, queryEngine) + { + } + + private static string GetErrorMessage(uint errorNumber) + { + switch (errorNumber) + { + case 40613: + return "Database on server is not currently available. Please retry the connection later. " + + "If the problem persists, contact customer support, and provide them the session tracing ID."; + case 42108: + return "Can not connect to the SQL pool since it is paused. Please resume the SQL pool and try again."; + case 42109: + return "The SQL pool is warming up. Please try again."; + } + return "Unknown server error occurred"; + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) + { + // Inflate login7 request from the message + TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; + + // Check if we're still going to raise transient error + if (Arguments.IsEnabledTransientError && RequestCounter < Arguments.RepeatCount) + { + uint errorNumber = Arguments.Number; + string errorMessage = Arguments.Message ?? GetErrorMessage(errorNumber); + + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + + // Prepare ERROR token with the denial details + TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the error token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); + + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); + + RequestCounter++; + + // Put a single message into the collection and return it + return new TDSMessageCollection(responseMessage); + } + + // Return login response from the base class + return base.OnLogin7Request(session, request); + } + + public override void Dispose() { + base.Dispose(); + RequestCounter = 0; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServerArguments.cs similarity index 59% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServerArguments.cs index 77eec68c5f..6a61d9cc6f 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServerArguments.cs @@ -4,31 +4,26 @@ namespace Microsoft.SqlServer.TDS.Servers { - public class TransientFaultTDSServerArguments : TDSServerArguments + public class TransientTdsErrorTdsServerArguments : TdsServerArguments { /// /// Transient error number to be raised by server. /// - public uint Number { get; set; } + public uint Number { get; set; } = 0; /// /// Transient error message to be raised by server. /// - public string Message { get; set; } + public string Message { get; set; } = string.Empty; /// /// Flag to consider when raising Transient error. /// - public bool IsEnabledTransientError { get; set; } + public bool IsEnabledTransientError { get; set; } = false; /// - /// Constructor to initialize + /// The number of times the transient error should be raised. /// - public TransientFaultTDSServerArguments() - { - Number = 0; - Message = string.Empty; - IsEnabledTransientError = false; - } + public int RepeatCount { get; set; } = 1; } }