diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
index 8aaa67e1c4..97c1347b10 100644
--- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs
@@ -172,7 +172,7 @@ public enum SqlAuthenticationMethod
SqlPassword = 1
}
///
- public partial class SqlAuthenticationParameters
+ public class SqlAuthenticationParameters
{
///
protected SqlAuthenticationParameters(Microsoft.Data.SqlClient.SqlAuthenticationMethod authenticationMethod, string serverName, string databaseName, string resource, string authority, string userId, string password, System.Guid connectionId, int connectionTimeout) { }
diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
index c8bcfd40f7..e078631f4e 100644
--- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
+++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs
@@ -128,7 +128,7 @@ public enum SqlAuthenticationMethod
SqlPassword = 1
}
///
- public partial class SqlAuthenticationParameters
+ public class SqlAuthenticationParameters
{
///
protected SqlAuthenticationParameters(Microsoft.Data.SqlClient.SqlAuthenticationMethod authenticationMethod, string serverName, string databaseName, string resource, string authority, string userId, string password, System.Guid connectionId, int connectionTimeout) { }
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs
index 745ba716e3..9c74b937b8 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationParameters.cs
@@ -38,7 +38,7 @@ public class SqlAuthenticationParameters
public string DatabaseName { get; }
///
- public int ConnectionTimeout = ADP.DefaultConnectionTimeout;
+ public int ConnectionTimeout { get; } = ADP.DefaultConnectionTimeout;
///
protected SqlAuthenticationParameters(
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
index 2d809934e1..27acea362d 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs
@@ -64,6 +64,7 @@ public static class DataTestUtility
public static string AADAccessToken = null;
public static string AADSystemIdentityAccessToken = null;
public static string AADUserIdentityAccessToken = null;
+ public const string ApplicationClientId = "2fd908ad-0664-4344-b9be-cd3e8b574c38";
public const string UdtTestDbName = "UdtTestDb";
public const string AKVKeyName = "TestSqlClientAzureKeyVaultProvider";
public const string EventSourcePrefix = "Microsoft.Data.SqlClient";
diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs
index e9849cccd9..6833668b2a 100644
--- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs
+++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs
@@ -5,12 +5,57 @@
using System;
using System.Diagnostics;
using System.Security;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Identity.Client;
using Xunit;
namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public class AADConnectionsTest
{
+ class CustomSqlAuthenticationProvider : SqlAuthenticationProvider
+ {
+ string _appClientId;
+
+ internal CustomSqlAuthenticationProvider(string appClientId)
+ {
+ _appClientId = appClientId;
+ }
+
+ public override async Task AcquireTokenAsync(SqlAuthenticationParameters parameters)
+ {
+ string s_defaultScopeSuffix = "/.default";
+ string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
+
+ _ = parameters.ServerName;
+ _ = parameters.DatabaseName;
+ _ = parameters.ConnectionId;
+
+ var cts = new CancellationTokenSource();
+ cts.CancelAfter(parameters.ConnectionTimeout * 1000);
+
+ string[] scopes = new string[] { scope };
+ SecureString password = new SecureString();
+ foreach (char c in parameters.Password)
+ password.AppendChar(c);
+ password.MakeReadOnly();
+
+ AuthenticationResult result = await PublicClientApplicationBuilder.Create(_appClientId)
+ .WithAuthority(parameters.Authority)
+ .Build().AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
+ .WithCorrelationId(parameters.ConnectionId)
+ .ExecuteAsync(cancellationToken: cts.Token);
+
+ return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
+ }
+
+ public override bool IsSupported(SqlAuthenticationMethod authenticationMethod)
+ {
+ return authenticationMethod.Equals(SqlAuthenticationMethod.ActiveDirectoryPassword);
+ }
+ }
+
private static void ConnectAndDisconnect(string connectionString, SqlCredential credential = null)
{
using (SqlConnection conn = new SqlConnection(connectionString))
@@ -167,7 +212,6 @@ public static void AADPasswordWithWrongPassword()
Assert.Contains(expectedMessage, e.Message);
}
-
[ConditionalFact(nameof(IsAADConnStringsSetup))]
public static void GetAccessTokenByPasswordTest()
{
@@ -181,7 +225,7 @@ public static void GetAccessTokenByPasswordTest()
}
[ConditionalFact(nameof(IsAADConnStringsSetup))]
- public static void testADPasswordAuthentication()
+ public static void TestADPasswordAuthentication()
{
// Connect to Azure DB with password and retrieve user name.
using (SqlConnection conn = new SqlConnection(DataTestUtility.AADPasswordConnectionString))
@@ -201,6 +245,30 @@ public static void testADPasswordAuthentication()
}
}
+ [ConditionalFact(nameof(IsAADConnStringsSetup))]
+ public static void TestCustomProviderAuthentication()
+ {
+ SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryPassword, new CustomSqlAuthenticationProvider(DataTestUtility.ApplicationClientId));
+ // Connect to Azure DB with password and retrieve user name using custom authentication provider
+ using (SqlConnection conn = new SqlConnection(DataTestUtility.AADPasswordConnectionString))
+ {
+ conn.Open();
+ using (SqlCommand sqlCommand = new SqlCommand
+ (
+ cmdText: $"SELECT SUSER_SNAME();",
+ connection: conn,
+ transaction: null
+ ))
+ {
+ string customerId = (string)sqlCommand.ExecuteScalar();
+ string expected = DataTestUtility.RetrieveValueFromConnStr(DataTestUtility.AADPasswordConnectionString, new string[] { "User ID", "UID" });
+ Assert.Equal(expected, customerId);
+ }
+ }
+ // Reset to driver internal provider.
+ SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryPassword, new ActiveDirectoryAuthenticationProvider(DataTestUtility.ApplicationClientId));
+ }
+
[ConditionalFact(nameof(IsAADConnStringsSetup))]
public static void ActiveDirectoryPasswordWithNoAuthType()
{
@@ -269,7 +337,7 @@ public static void EmptyCredInConnStrAADPasswordAnyUnix()
string[] removeKeys = { "User ID", "Password", "UID", "PWD" };
string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, removeKeys) + "User ID=; Password=;";
SqlException e = Assert.Throws(() => ConnectAndDisconnect(connStr));
-
+
string expectedMessage = "MSAL cannot determine the username (UPN) of the currently logged in user.For Integrated Windows Authentication and Username/Password flows, please use .WithUsername() before calling ExecuteAsync().";
Assert.Contains(expectedMessage, e.Message);
}
@@ -504,13 +572,13 @@ public static void ADInteractiveUsingSSPI()
public static void ConnectionSpeed()
{
var connString = DataTestUtility.AADPasswordConnectionString;
-
+
//Ensure server endpoints are warm
using (var connectionDrill = new SqlConnection(connString))
{
connectionDrill.Open();
}
-
+
SqlConnection.ClearAllPools();
ActiveDirectoryAuthenticationProvider.ClearUserTokenCache();
@@ -529,7 +597,7 @@ public static void ConnectionSpeed()
secondConnectionTime.Stop();
}
}
-
+
// Subsequent AAD connections within a short timeframe should use an auth token cached from the connection pool
// Second connection speed in tests was typically 10-15% of the first connection time. Using 30% since speeds may vary.
Assert.True(((double)secondConnectionTime.ElapsedMilliseconds / firstConnectionTime.ElapsedMilliseconds) < 0.30, $"Second AAD connection too slow ({secondConnectionTime.ElapsedMilliseconds}ms)! (More than 30% of the first ({firstConnectionTime.ElapsedMilliseconds}ms).)");