diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index 8aaa67e1c4..97c1347b10 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -172,7 +172,7 @@ public enum SqlAuthenticationMethod SqlPassword = 1 } /// - public partial class SqlAuthenticationParameters + public class SqlAuthenticationParameters { /// protected SqlAuthenticationParameters(Microsoft.Data.SqlClient.SqlAuthenticationMethod authenticationMethod, string serverName, string databaseName, string resource, string authority, string userId, string password, System.Guid connectionId, int connectionTimeout) { } diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index c8bcfd40f7..e078631f4e 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -128,7 +128,7 @@ public enum SqlAuthenticationMethod SqlPassword = 1 } /// - public partial class SqlAuthenticationParameters + public class SqlAuthenticationParameters { /// protected SqlAuthenticationParameters(Microsoft.Data.SqlClient.SqlAuthenticationMethod authenticationMethod, string serverName, string databaseName, string resource, string authority, string userId, string password, System.Guid connectionId, int connectionTimeout) { } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs index 745ba716e3..9c74b937b8 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs @@ -38,7 +38,7 @@ public class SqlAuthenticationParameters public string DatabaseName { get; } /// - public int ConnectionTimeout = ADP.DefaultConnectionTimeout; + public int ConnectionTimeout { get; } = ADP.DefaultConnectionTimeout; /// protected SqlAuthenticationParameters( diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs index 2d809934e1..27acea362d 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs @@ -64,6 +64,7 @@ public static class DataTestUtility public static string AADAccessToken = null; public static string AADSystemIdentityAccessToken = null; public static string AADUserIdentityAccessToken = null; + public const string ApplicationClientId = "2fd908ad-0664-4344-b9be-cd3e8b574c38"; public const string UdtTestDbName = "UdtTestDb"; public const string AKVKeyName = "TestSqlClientAzureKeyVaultProvider"; public const string EventSourcePrefix = "Microsoft.Data.SqlClient"; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs index e9849cccd9..6833668b2a 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs @@ -5,12 +5,57 @@ using System; using System.Diagnostics; using System.Security; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Client; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests { public class AADConnectionsTest { + class CustomSqlAuthenticationProvider : SqlAuthenticationProvider + { + string _appClientId; + + internal CustomSqlAuthenticationProvider(string appClientId) + { + _appClientId = appClientId; + } + + public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters) + { + string s_defaultScopeSuffix = "/.default"; + string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix; + + _ = parameters.ServerName; + _ = parameters.DatabaseName; + _ = parameters.ConnectionId; + + var cts = new CancellationTokenSource(); + cts.CancelAfter(parameters.ConnectionTimeout * 1000); + + string[] scopes = new string[] { scope }; + SecureString password = new SecureString(); + foreach (char c in parameters.Password) + password.AppendChar(c); + password.MakeReadOnly(); + + AuthenticationResult result = await PublicClientApplicationBuilder.Create(_appClientId) + .WithAuthority(parameters.Authority) + .Build().AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) + .WithCorrelationId(parameters.ConnectionId) + .ExecuteAsync(cancellationToken: cts.Token); + + return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn); + } + + public override bool IsSupported(SqlAuthenticationMethod authenticationMethod) + { + return authenticationMethod.Equals(SqlAuthenticationMethod.ActiveDirectoryPassword); + } + } + private static void ConnectAndDisconnect(string connectionString, SqlCredential credential = null) { using (SqlConnection conn = new SqlConnection(connectionString)) @@ -167,7 +212,6 @@ public static void AADPasswordWithWrongPassword() Assert.Contains(expectedMessage, e.Message); } - [ConditionalFact(nameof(IsAADConnStringsSetup))] public static void GetAccessTokenByPasswordTest() { @@ -181,7 +225,7 @@ public static void GetAccessTokenByPasswordTest() } [ConditionalFact(nameof(IsAADConnStringsSetup))] - public static void testADPasswordAuthentication() + public static void TestADPasswordAuthentication() { // Connect to Azure DB with password and retrieve user name. using (SqlConnection conn = new SqlConnection(DataTestUtility.AADPasswordConnectionString)) @@ -201,6 +245,30 @@ public static void testADPasswordAuthentication() } } + [ConditionalFact(nameof(IsAADConnStringsSetup))] + public static void TestCustomProviderAuthentication() + { + SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryPassword, new CustomSqlAuthenticationProvider(DataTestUtility.ApplicationClientId)); + // Connect to Azure DB with password and retrieve user name using custom authentication provider + using (SqlConnection conn = new SqlConnection(DataTestUtility.AADPasswordConnectionString)) + { + conn.Open(); + using (SqlCommand sqlCommand = new SqlCommand + ( + cmdText: $"SELECT SUSER_SNAME();", + connection: conn, + transaction: null + )) + { + string customerId = (string)sqlCommand.ExecuteScalar(); + string expected = DataTestUtility.RetrieveValueFromConnStr(DataTestUtility.AADPasswordConnectionString, new string[] { "User ID", "UID" }); + Assert.Equal(expected, customerId); + } + } + // Reset to driver internal provider. + SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryPassword, new ActiveDirectoryAuthenticationProvider(DataTestUtility.ApplicationClientId)); + } + [ConditionalFact(nameof(IsAADConnStringsSetup))] public static void ActiveDirectoryPasswordWithNoAuthType() { @@ -269,7 +337,7 @@ public static void EmptyCredInConnStrAADPasswordAnyUnix() string[] removeKeys = { "User ID", "Password", "UID", "PWD" }; string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, removeKeys) + "User ID=; Password=;"; SqlException e = Assert.Throws(() => ConnectAndDisconnect(connStr)); - + string expectedMessage = "MSAL cannot determine the username (UPN) of the currently logged in user.For Integrated Windows Authentication and Username/Password flows, please use .WithUsername() before calling ExecuteAsync()."; Assert.Contains(expectedMessage, e.Message); } @@ -504,13 +572,13 @@ public static void ADInteractiveUsingSSPI() public static void ConnectionSpeed() { var connString = DataTestUtility.AADPasswordConnectionString; - + //Ensure server endpoints are warm using (var connectionDrill = new SqlConnection(connString)) { connectionDrill.Open(); } - + SqlConnection.ClearAllPools(); ActiveDirectoryAuthenticationProvider.ClearUserTokenCache(); @@ -529,7 +597,7 @@ public static void ConnectionSpeed() secondConnectionTime.Stop(); } } - + // Subsequent AAD connections within a short timeframe should use an auth token cached from the connection pool // Second connection speed in tests was typically 10-15% of the first connection time. Using 30% since speeds may vary. Assert.True(((double)secondConnectionTime.ElapsedMilliseconds / firstConnectionTime.ElapsedMilliseconds) < 0.30, $"Second AAD connection too slow ({secondConnectionTime.ElapsedMilliseconds}ms)! (More than 30% of the first ({firstConnectionTime.ElapsedMilliseconds}ms).)");