diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs index b92746098f..ad61422ee2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNICommon.cs @@ -3,9 +3,13 @@ // See the LICENSE file in the project root for more information. using System; +using System.Diagnostics; using System.Net; using System.Net.Security; using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -194,7 +198,7 @@ internal static bool ValidateSslServerCertificate(string targetServerName, X509C return true; } } - + /// /// We validate the provided certificate provided by the client with the one from the server to see if it matches. /// Certificate validation and chain trust validations are done by SSLStream class [System.Net.Security.SecureChannel.VerifyRemoteCertificate method] @@ -239,6 +243,24 @@ internal static bool ValidateSslServerCertificate(X509Certificate clientCert, X5 } } + internal static IPAddress[] GetDnsIpAddresses(string serverName, TimeoutTimer timeout) + { + using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) + { + int remainingTimeout = timeout.MillisecondsRemainingInt; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNICommon), EventType.INFO, + "Getting DNS host entries for serverName {0} within {1} milliseconds.", + args0: serverName, + args1: remainingTimeout); + using CancellationTokenSource cts = new CancellationTokenSource(remainingTimeout); + // using this overload to support netstandard + Task task = Dns.GetHostAddressesAsync(serverName); + task.ConfigureAwait(false); + task.Wait(cts.Token); + return task.Result; + } + } + internal static IPAddress[] GetDnsIpAddresses(string serverName) { using (TrySNIEventScope.Create(nameof(GetDnsIpAddresses))) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs index b48ea36958..2c3c2aeaf3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNINpHandle.cs @@ -10,6 +10,7 @@ using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -37,7 +38,7 @@ internal sealed class SNINpHandle : SNIPhysicalHandle private int _bufferSize = TdsEnums.DEFAULT_LOGIN_PACKET_SIZE; private readonly Guid _connectionId = Guid.NewGuid(); - public SNINpHandle(string serverName, string pipeName, long timerExpire, bool tlsFirst) + public SNINpHandle(string serverName, string pipeName, TimeoutTimer timeout, bool tlsFirst) { using (TrySNIEventScope.Create(nameof(SNINpHandle))) { @@ -54,17 +55,25 @@ public SNINpHandle(string serverName, string pipeName, long timerExpire, bool tl PipeDirection.InOut, PipeOptions.Asynchronous | PipeOptions.WriteThrough); - bool isInfiniteTimeOut = long.MaxValue == timerExpire; - if (isInfiniteTimeOut) + if (timeout.IsInfinite) { + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, + "Connection Id {0}, Setting server name = {1}, pipe name = {2}. Connecting with infinite timeout.", + args0: _connectionId, + args1: serverName, + args2: pipeName); _pipeStream.Connect(Timeout.Infinite); } else { - TimeSpan ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; - ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; - - _pipeStream.Connect((int)ts.TotalMilliseconds); + int timeoutMilliseconds = timeout.MillisecondsRemainingInt; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNINpHandle), EventType.INFO, + "Connection Id {0}, Setting server name = {1}, pipe name = {2}. Connecting within the {3} sepecified milliseconds.", + args0: _connectionId, + args1: serverName, + args2: pipeName, + args3: timeoutMilliseconds); + _pipeStream.Connect(timeoutMilliseconds); } } catch (TimeoutException te) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 8f101b7bdf..b4b3a37222 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -9,6 +9,7 @@ using System.Net.Security; using System.Net.Sockets; using System.Text; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -130,7 +131,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// Create a SNI connection handle /// /// Full server name from connection string - /// Timer expiration + /// Timer expiration /// Instance name /// SPN /// pre-defined SPN @@ -147,7 +148,7 @@ private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode) /// SNI handle internal static SNIHandle CreateConnectionHandle( string fullServerName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, string serverSPN, @@ -186,11 +187,11 @@ internal static SNIHandle CreateConnectionHandle( case DataSource.Protocol.Admin: case DataSource.Protocol.None: // default to using tcp if no protocol is provided case DataSource.Protocol.TCP: - sniHandle = CreateTcpHandle(details, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, + sniHandle = CreateTcpHandle(details, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); break; case DataSource.Protocol.NP: - sniHandle = CreateNpHandle(details, timerExpire, parallel, tlsFirst); + sniHandle = CreateNpHandle(details, timeout, parallel, tlsFirst); break; default: Debug.Fail($"Unexpected connection protocol: {details._connectionProtocol}"); @@ -279,7 +280,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr /// Creates an SNITCPHandle object /// /// Data source - /// Timer expiration + /// Timer expiration /// Should MultiSubnetFailover be used /// IP address preference /// Key for DNS Cache @@ -290,7 +291,7 @@ private static byte[][] GetSqlServerSPNs(string hostNameOrAddress, string portOr /// SNITCPHandle private static SNITCPHandle CreateTcpHandle( DataSource details, - long timerExpire, + TimeoutTimer timeout, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, @@ -317,8 +318,8 @@ private static SNITCPHandle CreateTcpHandle( try { port = isAdminConnection ? - SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference) : - SSRP.GetPortByInstanceName(hostName, details.InstanceName, timerExpire, parallel, ipPreference); + SSRP.GetDacPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference) : + SSRP.GetPortByInstanceName(hostName, details.InstanceName, timeout, parallel, ipPreference); } catch (SocketException se) { @@ -335,7 +336,7 @@ private static SNITCPHandle CreateTcpHandle( port = isAdminConnection ? DefaultSqlServerDacPort : DefaultSqlServerPort; } - return new SNITCPHandle(hostName, port, timerExpire, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, + return new SNITCPHandle(hostName, port, timeout, parallel, ipPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); } @@ -343,11 +344,11 @@ private static SNITCPHandle CreateTcpHandle( /// Creates an SNINpHandle object /// /// Data source - /// Timer expiration + /// Timer expiration /// Should MultiSubnetFailover be used. Only returns an error for named pipes. /// /// SNINpHandle - private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, bool parallel, bool tlsFirst) + private static SNINpHandle CreateNpHandle(DataSource details, TimeoutTimer timeout, bool parallel, bool tlsFirst) { if (parallel) { @@ -355,7 +356,7 @@ private static SNINpHandle CreateNpHandle(DataSource details, long timerExpire, SNICommon.ReportSNIError(SNIProviders.NP_PROV, 0, SNICommon.MultiSubnetFailoverWithNonTcpProtocol, Strings.SNI_ERROR_49); return null; } - return new SNINpHandle(details.PipeHostName, details.PipeName, timerExpire, tlsFirst); + return new SNINpHandle(details.PipeHostName, details.PipeName, timeout, tlsFirst); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs index 14243e98d3..d12e91ad62 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNITcpHandle.cs @@ -16,6 +16,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -118,7 +119,7 @@ public override int ProtocolVersion /// /// Server name /// TCP port number - /// Connection timer expiration + /// Connection timer expiration /// Parallel executions /// IP address preference /// Key for DNS Cache @@ -129,7 +130,7 @@ public override int ProtocolVersion public SNITCPHandle( string serverName, int port, - long timerExpire, + TimeoutTimer timeout, bool parallel, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, @@ -153,17 +154,6 @@ public SNITCPHandle( try { - TimeSpan ts = default(TimeSpan); - - // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count - // The infinite Timeout is a function of ConnectionString Timeout=0 - bool isInfiniteTimeOut = long.MaxValue == timerExpire; - if (!isInfiniteTimeOut) - { - ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; - ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; - } - bool reportError = true; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Connecting to serverName {1} and port {2}", args0: _connectionId, args1: serverName, args2: port); @@ -174,15 +164,19 @@ public SNITCPHandle( { if (parallel) { - _socket = TryConnectParallel(serverName, port, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(serverName, port, timeout, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(serverName, port, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(serverName, port, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } } catch (Exception ex) { + if (timeout.IsExpired) + { + throw; + } // Retry with cached IP address if (ex is SocketException || ex is ArgumentException || ex is AggregateException) { @@ -214,26 +208,30 @@ public SNITCPHandle( { if (parallel) { - _socket = TryConnectParallel(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(firstCachedIP, portRetry, timeout, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(firstCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(firstCachedIP, portRetry, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } } catch (Exception exRetry) { + if (timeout.IsExpired) + { + throw; + } if (exRetry is SocketException || exRetry is ArgumentNullException || exRetry is ArgumentException || exRetry is ArgumentOutOfRangeException || exRetry is AggregateException) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Retrying exception {1}", args0: _connectionId, args1: exRetry?.Message); if (parallel) { - _socket = TryConnectParallel(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ref reportError, cachedFQDN, ref pendingDNSInfo); + _socket = TryConnectParallel(secondCachedIP, portRetry, timeout, ref reportError, cachedFQDN, ref pendingDNSInfo); } else { - _socket = Connect(secondCachedIP, portRetry, ts, isInfiniteTimeOut, ipPreference, cachedFQDN, ref pendingDNSInfo); + _socket = Connect(secondCachedIP, portRetry, timeout, ipPreference, cachedFQDN, ref pendingDNSInfo); } } else @@ -300,12 +298,15 @@ public SNITCPHandle( // Connect to server with hostName and port in parellel mode. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. - private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool isInfiniteTimeOut, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + private Socket TryConnectParallel(string hostName, int port, TimeoutTimer timeout, ref bool callerReportError, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { Socket availableSocket = null; Task connectTask; + bool isInfiniteTimeOut = timeout.IsInfinite; - IPAddress[] serverAddresses = SNICommon.GetDnsIpAddresses(hostName); + IPAddress[] serverAddresses = isInfiniteTimeOut + ? SNICommon.GetDnsIpAddresses(hostName) + : SNICommon.GetDnsIpAddresses(hostName, timeout); if (serverAddresses.Length > MaxParallelIpAddresses) { @@ -338,7 +339,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i connectTask = ParallelConnectAsync(serverAddresses, port); - if (!(isInfiniteTimeOut ? connectTask.Wait(-1) : connectTask.Wait(ts))) + if (!(connectTask.Wait(isInfiniteTimeOut ? -1: timeout.MillisecondsRemainingInt))) { callerReportError = false; SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "Connection Id {0} Connection timed out, Exception: {1}", args0: _connectionId, args1: Strings.SNI_ERROR_40); @@ -349,7 +350,7 @@ private Socket TryConnectParallel(string hostName, int port, TimeSpan ts, bool i availableSocket = connectTask.Result; return availableSocket; } - + /// /// Returns array of IP addresses for the given server name, sorted according to the given preference. /// @@ -389,15 +390,14 @@ private static IEnumerable GetHostAddressesSortedByPreference(string } } } - + // Connect to server with hostName and port. // The IP information will be collected temporarily as the pendingDNSInfo but is not stored in the DNS cache at this point. // Only write to the DNS cache when we receive IsSupported flag as true in the Feature Ext Ack from server. - private static Socket Connect(string serverName, int port, TimeSpan timeout, bool isInfiniteTimeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) + private static Socket Connect(string serverName, int port, TimeoutTimer timeout, SqlConnectionIPAddressPreference ipPreference, string cachedFQDN, ref SQLDNSInfo pendingDNSInfo) { SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "IP preference : {0}", Enum.GetName(typeof(SqlConnectionIPAddressPreference), ipPreference)); - - Stopwatch timeTaken = Stopwatch.StartNew(); + bool isInfiniteTimeout = timeout.IsInfinite; IEnumerable ipAddresses = GetHostAddressesSortedByPreference(serverName, ipPreference); @@ -422,26 +422,44 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo port, ipAddress.AddressFamily, isInfiniteTimeout); - + bool isConnected; try // catching SocketException with SocketErrorCode == WouldBlock to run Socket.Select { - socket.Connect(ipAddress, port); - if (!isInfiniteTimeout) + if (isInfiniteTimeout) { + socket.Connect(ipAddress, port); + } + else + { + if (timeout.IsExpired) + { + return null; + } + // Socket.Connect does not support infinite timeouts, so we use Task to simulate it + Task socketConnectTask = new Task(() => socket.Connect(ipAddress, port)); + socketConnectTask.ConfigureAwait(false); + socketConnectTask.Start(); + int remainingTimeout = timeout.MillisecondsRemainingInt; + if (!socketConnectTask.Wait(remainingTimeout)) + { + throw ADP.TimeoutException($"The socket couldn't connect during the expected {remainingTimeout} remaining time."); + } throw SQL.SocketDidNotThrow(); } - + isConnected = true; } - catch (SocketException socketException) when (!isInfiniteTimeout && - socketException.SocketErrorCode == - SocketError.WouldBlock) + catch (AggregateException aggregateException) when (!isInfiniteTimeout + && aggregateException.InnerException is SocketException socketException + && socketException.SocketErrorCode == SocketError.WouldBlock) { // https://github.com/dotnet/SqlClient/issues/826#issuecomment-736224118 // Socket.Select is used because it supports timeouts, while Socket.Connect does not - List checkReadLst; List checkWriteLst; List checkErrorLst; + List checkReadLst; + List checkWriteLst; + List checkErrorLst; // Repeating Socket.Select several times if our timeout is greater // than int.MaxValue microseconds because of @@ -449,18 +467,22 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo // which states that Socket.Select can't handle timeouts greater than int.MaxValue microseconds do { - TimeSpan timeLeft = timeout - timeTaken.Elapsed; - - if (timeLeft <= TimeSpan.Zero) + if (timeout.IsExpired) + { return null; + } int socketSelectTimeout = - checked((int)(Math.Min(timeLeft.TotalMilliseconds, int.MaxValue / 1000) * 1000)); + checked((int)(Math.Min(timeout.MillisecondsRemainingInt, int.MaxValue / 1000) * 1000)); checkReadLst = new List(1) { socket }; checkWriteLst = new List(1) { socket }; checkErrorLst = new List(1) { socket }; + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, + "Determining the status of the socket during the remaining timeout of {0} microseconds.", + socketSelectTimeout); + Socket.Select(checkReadLst, checkWriteLst, checkErrorLst, socketSelectTimeout); // nothing selected means timeout } while (checkReadLst.Count == 0 && checkWriteLst.Count == 0 && checkErrorLst.Count == 0); @@ -487,11 +509,11 @@ private static Socket Connect(string serverName, int port, TimeSpan timeout, boo return socket; } } - catch (SocketException e) + catch (AggregateException aggregateException) when (aggregateException.InnerException is SocketException socketException) { - SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: e?.Message); + SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.ERR, "THIS EXCEPTION IS BEING SWALLOWED: {0}", args0: socketException?.Message); SqlClientEventSource.Log.TryAdvancedTraceEvent( - $"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {e}"); + $"{nameof(SNITCPHandle)}.{nameof(Connect)}{EventType.ERR}THIS EXCEPTION IS BEING SWALLOWED: {socketException}"); } finally { @@ -675,7 +697,7 @@ private bool ValidateServerCertificate(object sender, X509Certificate serverCert SqlClientEventSource.Log.TrySNITraceEvent(nameof(SNITCPHandle), EventType.INFO, "Connection Id {0}, Certificate will not be validated.", args0: _connectionId); return true; } - + string serverNameToValidate; if (!string.IsNullOrEmpty(_hostNameInCertificate)) { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index 0348c227e0..3cad605caa 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -10,6 +10,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -29,11 +30,11 @@ internal sealed class SSRP /// /// SQL Sever Browser hostname /// instance name to find port number - /// Connection timer expiration + /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// port number for given instance name - internal static int GetPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static int GetPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); @@ -43,7 +44,7 @@ internal static int GetPortByInstanceName(string browserHostName, string instanc byte[] responsePacket = null; try { - responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timerExpire, allIPsInParallel, ipPreference); + responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, instanceInfoRequest, timeout, allIPsInParallel, ipPreference); } catch (SocketException se) { @@ -104,17 +105,17 @@ private static byte[] CreateInstanceInfoRequest(string instanceName) /// /// SQL Sever Browser hostname /// instance name to lookup DAC port - /// Connection timer expiration + /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// DAC port for given instance name - internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + internal static int GetDacPortByInstanceName(string browserHostName, string instanceName, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { Debug.Assert(!string.IsNullOrWhiteSpace(browserHostName), "browserHostName should not be null, empty, or whitespace"); Debug.Assert(!string.IsNullOrWhiteSpace(instanceName), "instanceName should not be null, empty, or whitespace"); byte[] dacPortInfoRequest = CreateDacPortInfoRequest(instanceName); - byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timerExpire, allIPsInParallel, ipPreference); + byte[] responsePacket = SendUDPRequest(browserHostName, SqlServerBrowserPort, dacPortInfoRequest, timeout, allIPsInParallel, ipPreference); const byte SvrResp = 0x05; const byte ProtocolVersion = 0x01; @@ -163,11 +164,11 @@ private class SsrpResult /// UDP server hostname /// UDP server port /// request packet - /// Connection timer expiration + /// Connection timer expiration /// query all resolved IP addresses in parallel /// IP address preference /// response packet from UDP server - private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, long timerExpire, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) + private static byte[] SendUDPRequest(string browserHostname, int port, byte[] requestPacket, TimeoutTimer timeout, bool allIPsInParallel, SqlConnectionIPAddressPreference ipPreference) { using (TrySNIEventScope.Create(nameof(SSRP))) { @@ -186,16 +187,10 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re return null; } - TimeSpan ts = default; - // In case the Timeout is Infinite, we will receive the max value of Int64 as the tick count - // The infinite Timeout is a function of ConnectionString Timeout=0 - if (long.MaxValue != timerExpire) - { - ts = DateTime.FromFileTime(timerExpire) - DateTime.Now; - ts = ts.Ticks < 0 ? TimeSpan.FromTicks(0) : ts; - } + IPAddress[] ipAddresses = timeout.IsInfinite + ? SNICommon.GetDnsIpAddresses(browserHostname) + : SNICommon.GetDnsIpAddresses(browserHostname, timeout); - IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); IPAddress[] ipv4Addresses = null; IPAddress[] ipv6Addresses = null; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 0b745ae01e..6a0ee2e0e0 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1918,7 +1918,7 @@ private void AttemptOneLogin( _parser.Connect(serverInfo, this, - timeout.LegacyTimerExpire, + timeout, ConnectionOptions, withFailover); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index 899ea5fa7f..c2b8f4eeca 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -16,6 +16,7 @@ using System.Threading.Tasks; using System.Xml; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; using Microsoft.Data.Sql; using Microsoft.Data.SqlClient.DataClassification; using Microsoft.Data.SqlClient.Server; @@ -361,7 +362,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) internal void Connect( ServerInfo serverInfo, SqlInternalConnectionTds connHandler, - long timerExpire, + TimeoutTimer timeout, SqlConnectionString connectionOptions, bool withFailover) { @@ -444,7 +445,7 @@ internal void Connect( // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, + timeout, out instanceName, ref _sniSpnBuffer, false, @@ -487,7 +488,7 @@ internal void Connect( } _state = TdsParserState.OpenNotLoggedIn; _physicalStateObj.SniContext = SniContext.Snix_PreLoginBeforeSuccessfulWrite; - _physicalStateObj.TimeoutTime = timerExpire; + _physicalStateObj.TimeoutTime = timeout.LegacyTimerExpire; bool marsCapable = false; @@ -542,7 +543,7 @@ internal void Connect( _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, out instanceName, + timeout, out instanceName, ref _sniSpnBuffer, true, true, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index f97dfec553..6c73917d40 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -198,7 +199,7 @@ private void ResetCancelAndProcessAttention() internal abstract void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index 9665d8f188..1e0141dd58 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -13,6 +13,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient.SNI { @@ -82,7 +83,7 @@ protected override uint SNIPacketGetData(PacketHandle packet, byte[] inBuff, ref internal override void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, @@ -97,7 +98,7 @@ internal override void CreatePhysicalSNIHandle( string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timerExpire, out instanceName, ref spnBuffer, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spnBuffer, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index bf8337cacb..59776956a1 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -11,6 +11,7 @@ using Microsoft.Data.Common; using System.Net; using System.Text; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -140,7 +141,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) internal override void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, ref byte[][] spnBuffer, bool flushCache, @@ -175,30 +176,10 @@ internal override void CreatePhysicalSNIHandle( } SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); - - // Translate to SNI timeout values (Int32 milliseconds) - long timeout; - if (long.MaxValue == timerExpire) - { - timeout = int.MaxValue; - } - else - { - timeout = ADP.TimerRemainingMilliseconds(timerExpire); - if (timeout > int.MaxValue) - { - timeout = int.MaxValue; - } - else if (0 > timeout) - { - timeout = 0; - } - } - SQLDNSInfo cachedDNSInfo; bool ret = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], checked((int)timeout), out instanceName, + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer[0], timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index aa6a670021..068b37dc71 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -2300,7 +2300,7 @@ private void AttemptOneLogin(ServerInfo serverInfo, string newPassword, SecureSt _parser.Connect(serverInfo, this, - timeout.LegacyTimerExpire, + timeout, ConnectionOptions, withFailover, isFirstTransparentAttempt, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 21abbab757..2bdca1a9e0 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -24,6 +24,7 @@ using Microsoft.Data.SqlClient.Server; using Microsoft.Data.SqlTypes; using Microsoft.SqlServer.Server; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -494,7 +495,7 @@ internal void ProcessPendingAck(TdsParserStateObject stateObj) internal void Connect(ServerInfo serverInfo, SqlInternalConnectionTds connHandler, - long timerExpire, + TimeoutTimer timeout, SqlConnectionString connectionOptions, bool withFailover, bool isFirstTransparentAttempt, @@ -639,7 +640,7 @@ internal void Connect(ServerInfo serverInfo, _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, + timeout, out instanceName, _sniSpnBuffer, false, @@ -679,7 +680,7 @@ internal void Connect(ServerInfo serverInfo, } _state = TdsParserState.OpenNotLoggedIn; _physicalStateObj.SniContext = SniContext.Snix_PreLoginBeforeSuccessfulWrite; // SQL BU DT 376766 - _physicalStateObj.TimeoutTime = timerExpire; + _physicalStateObj.TimeoutTime = timeout.LegacyTimerExpire; bool marsCapable = false; @@ -744,7 +745,7 @@ internal void Connect(ServerInfo serverInfo, _physicalStateObj.SniContext = SniContext.Snix_Connect; _physicalStateObj.CreatePhysicalSNIHandle( serverInfo.ExtendedServerName, - timerExpire, + timeout, out instanceName, _sniSpnBuffer, true, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs index 596373121a..27c38631af 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs @@ -12,6 +12,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Data.Common; +using Microsoft.Data.ProviderBase; namespace Microsoft.Data.SqlClient { @@ -279,7 +280,7 @@ private SNINativeMethodWrapper.ConsumerInfo CreateConsumerInfo(bool async) internal void CreatePhysicalSNIHandle( string serverName, - long timerExpire, + TimeoutTimer timeout, out byte[] instanceName, byte[] spnBuffer, bool flushCache, @@ -293,31 +294,12 @@ internal void CreatePhysicalSNIHandle( { SNINativeMethodWrapper.ConsumerInfo myInfo = CreateConsumerInfo(async); - // Translate to SNI timeout values (Int32 milliseconds) - long timeout; - if (long.MaxValue == timerExpire) - { - timeout = int.MaxValue; - } - else - { - timeout = ADP.TimerRemainingMilliseconds(timerExpire); - if (timeout > int.MaxValue) - { - timeout = int.MaxValue; - } - else if (0 > timeout) - { - timeout = 0; - } - } - // serverName : serverInfo.ExtendedServerName // may not use this serverName as key _ = SQLFallbackDNSCache.Instance.GetDNSInfo(cachedFQDN, out SQLDNSInfo cachedDNSInfo); - _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, checked((int)timeout), + _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, transparentNetworkResolutionState, totalTimeout, ipPreference, cachedDNSInfo, hostNameInCertificate); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs index 9948b223d1..37c94fe355 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/ProviderBase/TimeoutTimer.cs @@ -138,7 +138,7 @@ internal bool IsInfinite } // Special accessor for TimerExpire for use when thunking to legacy timeout methods. - internal long LegacyTimerExpire + public long LegacyTimerExpire { get { @@ -180,6 +180,42 @@ internal long MillisecondsRemaining return milliseconds; } } + + // Returns milliseconds remaining trimmed to zero for none remaining + internal int MillisecondsRemainingInt + { + get + { + //------------------- + // Method Body + int milliseconds; + if (_isInfiniteTimeout) + { + milliseconds = int.MaxValue; + } + else + { + long longMilliseconds = ADP.TimerRemainingMilliseconds(_timerExpire); + if (0 > longMilliseconds) + { + milliseconds = 0; + } + else if (longMilliseconds > int.MaxValue) + { + milliseconds = int.MaxValue; + } + else + { + milliseconds = checked((int)longMilliseconds); + } + } + + //-------------------- + // Postconditions + Debug.Assert(0 <= milliseconds); + + return milliseconds; + } + } } } - diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs index 66dc223c4b..01c2f2c050 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs @@ -5,6 +5,7 @@ using System; using System.Data; using System.Data.Common; +using System.Diagnostics; using System.Globalization; using System.Reflection; using System.Security; @@ -269,28 +270,26 @@ public void ConnectionTimeoutTest(int timeout) server.Dispose(); // Measure the actual time it took to timeout and compare it with configured timeout - var start = DateTime.Now; - var end = start; + Stopwatch timer = new(); + Exception ex = null; // Open a connection with the server disposed. try { + timer.Start(); connection.Open(); } - catch (Exception) + catch (Exception e) { - end = DateTime.Now; + timer.Stop(); + ex = e; } - // Calculate actual duration of timeout - TimeSpan s = end - start; - // Did not time out? - if (s.TotalSeconds == 0) - Assert.True(s.TotalSeconds == 0); - - // Is actual time out the same as configured timeout or within an additional 3 second threshold because of overhead? - if (s.TotalSeconds > 0) - Assert.True(s.TotalSeconds <= timeout + 3); + Assert.False(timer.IsRunning, "Timer must be stopped."); + Assert.NotNull(ex); + Assert.True(timer.Elapsed.TotalSeconds <= timeout + 3, + $"The actual timeout {timer.Elapsed.TotalSeconds} is expected to be less than {timeout} plus 3 seconds additional threshold." + + $"{Environment.NewLine}{ex}"); } [Theory] @@ -310,28 +309,28 @@ public async void ConnectionTimeoutTestAsync(int timeout) server.Dispose(); // Measure the actual time it took to timeout and compare it with configured timeout - var start = DateTime.Now; - var end = start; + Stopwatch timer = new(); + Exception ex = null; // Open a connection with the server disposed. try { - await connection.OpenAsync(); + //an asyn call with a timeout token to cancel the operation after the specific time + using CancellationTokenSource cts = new CancellationTokenSource(timeout * 1000); + timer.Start(); + await connection.OpenAsync(cts.Token).ConfigureAwait(false); } - catch (Exception) + catch (Exception e) { - end = DateTime.Now; + timer.Stop(); + ex = e; } - // Calculate actual duration of timeout - TimeSpan s = end - start; - // Did not time out? - if (s.TotalSeconds == 0) - Assert.True(s.TotalSeconds == 0); - - // Is actual time out the same as configured timeout or within an additional 3 second threshold because of overhead? - if (s.TotalSeconds > 0) - Assert.True(s.TotalSeconds <= timeout + 3); + Assert.False(timer.IsRunning, "Timer must be stopped."); + Assert.NotNull(ex); + Assert.True(timer.Elapsed.TotalSeconds <= timeout + 3, + $"The actual timeout {timer.Elapsed.TotalSeconds} is expected to be less than {timeout} plus 3 seconds additional threshold." + + $"{Environment.NewLine}{ex}"); } [Fact]