55using System ;
66using System . Diagnostics ;
77using System . Security ;
8+ using System . Threading ;
9+ using System . Threading . Tasks ;
10+ using Microsoft . Identity . Client ;
811using Xunit ;
912
1013namespace Microsoft . Data . SqlClient . ManualTesting . Tests
1114{
1215 public class AADConnectionsTest
1316 {
17+ class CustomSqlAuthenticationProvider : SqlAuthenticationProvider
18+ {
19+ string _appClientId ;
20+
21+ internal CustomSqlAuthenticationProvider ( string appClientId )
22+ {
23+ _appClientId = appClientId ;
24+ }
25+
26+ public override async Task < SqlAuthenticationToken > AcquireTokenAsync ( SqlAuthenticationParameters parameters )
27+ {
28+ string s_defaultScopeSuffix = "/.default" ;
29+ string scope = parameters . Resource . EndsWith ( s_defaultScopeSuffix ) ? parameters . Resource : parameters . Resource + s_defaultScopeSuffix ;
30+
31+ _ = parameters . ServerName ;
32+ _ = parameters . DatabaseName ;
33+ _ = parameters . ConnectionId ;
34+
35+ var cts = new CancellationTokenSource ( ) ;
36+ cts . CancelAfter ( parameters . ConnectionTimeout * 1000 ) ;
37+
38+ string [ ] scopes = new string [ ] { scope } ;
39+ SecureString password = new SecureString ( ) ;
40+ foreach ( char c in parameters . Password )
41+ password . AppendChar ( c ) ;
42+ password . MakeReadOnly ( ) ;
43+
44+ AuthenticationResult result = await PublicClientApplicationBuilder . Create ( _appClientId )
45+ . WithAuthority ( parameters . Authority )
46+ . Build ( ) . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
47+ . WithCorrelationId ( parameters . ConnectionId )
48+ . ExecuteAsync ( cancellationToken : cts . Token ) ;
49+
50+ return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
51+ }
52+
53+ public override bool IsSupported ( SqlAuthenticationMethod authenticationMethod )
54+ {
55+ return authenticationMethod . Equals ( SqlAuthenticationMethod . ActiveDirectoryPassword ) ;
56+ }
57+ }
58+
1459 private static void ConnectAndDisconnect ( string connectionString , SqlCredential credential = null )
1560 {
1661 using ( SqlConnection conn = new SqlConnection ( connectionString ) )
@@ -167,7 +212,6 @@ public static void AADPasswordWithWrongPassword()
167212 Assert . Contains ( expectedMessage , e . Message ) ;
168213 }
169214
170-
171215 [ ConditionalFact ( nameof ( IsAADConnStringsSetup ) ) ]
172216 public static void GetAccessTokenByPasswordTest ( )
173217 {
@@ -181,7 +225,7 @@ public static void GetAccessTokenByPasswordTest()
181225 }
182226
183227 [ ConditionalFact ( nameof ( IsAADConnStringsSetup ) ) ]
184- public static void testADPasswordAuthentication ( )
228+ public static void TestADPasswordAuthentication ( )
185229 {
186230 // Connect to Azure DB with password and retrieve user name.
187231 using ( SqlConnection conn = new SqlConnection ( DataTestUtility . AADPasswordConnectionString ) )
@@ -201,6 +245,30 @@ public static void testADPasswordAuthentication()
201245 }
202246 }
203247
248+ [ ConditionalFact ( nameof ( IsAADConnStringsSetup ) ) ]
249+ public static void TestCustomProviderAuthentication ( )
250+ {
251+ SqlAuthenticationProvider . SetProvider ( SqlAuthenticationMethod . ActiveDirectoryPassword , new CustomSqlAuthenticationProvider ( DataTestUtility . ApplicationClientId ) ) ;
252+ // Connect to Azure DB with password and retrieve user name using custom authentication provider
253+ using ( SqlConnection conn = new SqlConnection ( DataTestUtility . AADPasswordConnectionString ) )
254+ {
255+ conn . Open ( ) ;
256+ using ( SqlCommand sqlCommand = new SqlCommand
257+ (
258+ cmdText : $ "SELECT SUSER_SNAME();",
259+ connection : conn ,
260+ transaction : null
261+ ) )
262+ {
263+ string customerId = ( string ) sqlCommand . ExecuteScalar ( ) ;
264+ string expected = DataTestUtility . RetrieveValueFromConnStr ( DataTestUtility . AADPasswordConnectionString , new string [ ] { "User ID" , "UID" } ) ;
265+ Assert . Equal ( expected , customerId ) ;
266+ }
267+ }
268+ // Reset to driver internal provider.
269+ SqlAuthenticationProvider . SetProvider ( SqlAuthenticationMethod . ActiveDirectoryPassword , new ActiveDirectoryAuthenticationProvider ( DataTestUtility . ApplicationClientId ) ) ;
270+ }
271+
204272 [ ConditionalFact ( nameof ( IsAADConnStringsSetup ) ) ]
205273 public static void ActiveDirectoryPasswordWithNoAuthType ( )
206274 {
@@ -269,7 +337,7 @@ public static void EmptyCredInConnStrAADPasswordAnyUnix()
269337 string [ ] removeKeys = { "User ID" , "Password" , "UID" , "PWD" } ;
270338 string connStr = DataTestUtility . RemoveKeysInConnStr ( DataTestUtility . AADPasswordConnectionString , removeKeys ) + "User ID=; Password=;" ;
271339 SqlException e = Assert . Throws < SqlException > ( ( ) => ConnectAndDisconnect ( connStr ) ) ;
272-
340+
273341 string expectedMessage = "MSAL cannot determine the username (UPN) of the currently logged in user.For Integrated Windows Authentication and Username/Password flows, please use .WithUsername() before calling ExecuteAsync()." ;
274342 Assert . Contains ( expectedMessage , e . Message ) ;
275343 }
@@ -504,13 +572,13 @@ public static void ADInteractiveUsingSSPI()
504572 public static void ConnectionSpeed ( )
505573 {
506574 var connString = DataTestUtility . AADPasswordConnectionString ;
507-
575+
508576 //Ensure server endpoints are warm
509577 using ( var connectionDrill = new SqlConnection ( connString ) )
510578 {
511579 connectionDrill . Open ( ) ;
512580 }
513-
581+
514582 SqlConnection . ClearAllPools ( ) ;
515583 ActiveDirectoryAuthenticationProvider . ClearUserTokenCache ( ) ;
516584
@@ -529,7 +597,7 @@ public static void ConnectionSpeed()
529597 secondConnectionTime . Stop ( ) ;
530598 }
531599 }
532-
600+
533601 // Subsequent AAD connections within a short timeframe should use an auth token cached from the connection pool
534602 // Second connection speed in tests was typically 10-15% of the first connection time. Using 30% since speeds may vary.
535603 Assert . True ( ( ( double ) secondConnectionTime . ElapsedMilliseconds / firstConnectionTime . ElapsedMilliseconds ) < 0.30 , $ "Second AAD connection too slow ({ secondConnectionTime . ElapsedMilliseconds } ms)! (More than 30% of the first ({ firstConnectionTime . ElapsedMilliseconds } ms).)") ;
0 commit comments