Skip to content

Commit fa913b1

Browse files
committed
Refactor as suggested by creating hekper class for Federated Authentication Helper using reflection. Also, applied SOLID and DRY principles.
1 parent bcb4936 commit fa913b1

File tree

5 files changed

+126
-82
lines changed

5 files changed

+126
-82
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2259,7 +2259,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
22592259
{
22602260
// Try adding this new _newDbConnectionPoolAuthenticationContext to the _dbConnectionPool's AuthenticationContextKeys if it is not in there yet.
22612261
// The DbConnectionPoolAuthenticationContextKeys collection is used to refresh a cached token just before it expires within 10 minutes.
2262-
_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
2262+
//_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
2263+
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
22632264
}
22642265
}
22652266
}

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2685,7 +2685,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
26852685
{
26862686
// Try adding this new _newDbConnectionPoolAuthenticationContext to the _dbConnectionPool's AuthenticationContextKeys if it is not in there yet.
26872687
// The DbConnectionPoolAuthenticationContextKeys collection is used to refresh a cached token just before it expires within 10 minutes.
2688-
_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
2688+
// _dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
2689+
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
26892690
}
26902691
}
26912692
}

src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@
170170
<Compile Include="ProviderAgnostic\MultipleResultsTest\MultipleResultsTest.cs" />
171171
<Compile Include="ProviderAgnostic\ReaderTest\ReaderTest.cs" />
172172
<Compile Include="TracingTests\EventSourceTest.cs" />
173+
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
173174
<Compile Include="SQL\ConnectionPoolTest\ConnectionPoolTest.cs" />
174175
<Compile Include="SQL\ConnectionPoolTest\PoolBlockPeriodTest.cs" />
175176
<Compile Include="SQL\InstanceNameTest\InstanceNameTest.cs" />
@@ -267,7 +268,6 @@
267268
<Compile Include="DataCommon\ProxyServer.cs" />
268269
<Compile Include="DataCommon\SqlClientCustomTokenCredential.cs" />
269270
<Compile Include="DataCommon\SystemDataResourceManager.cs" />
270-
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
271271
<Compile Include="SQL\Common\AsyncDebugScope.cs" />
272272
<Compile Include="SQL\Common\ConnectionPoolWrapper.cs" />
273273
<Compile Include="SQL\Common\InternalConnectionWrapper.cs" />
@@ -276,6 +276,7 @@
276276
<Compile Include="SQL\Common\SystemDataInternals\ConnectionHelper.cs" />
277277
<Compile Include="SQL\Common\SystemDataInternals\ConnectionPoolHelper.cs" />
278278
<Compile Include="SQL\Common\SystemDataInternals\DataReaderHelper.cs" />
279+
<Compile Include="SQL\Common\SystemDataInternals\FedAuthTokenHelper.cs" />
279280
<Compile Include="SQL\Common\SystemDataInternals\TdsParserHelper.cs" />
280281
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
281282
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTest.cs" />
@@ -341,7 +342,7 @@
341342
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="$(SystemIdentityModelTokensJwtVersion)" />
342343
<PackageReference Condition="'$(TargetGroup)'=='netfx'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersion)" />
343344
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersionNet)" />
344-
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
345+
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
345346
<PackageReference Condition="'$(TargetGroup)'!='netfx'" Include="System.ServiceProcess.ServiceController" Version="$(SystemServiceProcessServiceControllerVersion)" />
346347
</ItemGroup>
347348
<ItemGroup>
Lines changed: 15 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
using System;
2-
using System.Collections;
3-
using System.Linq;
4-
using System.Reflection;
5-
using System.Security.Cryptography;
6-
using System.Text;
2+
using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals;
73
using Xunit;
84
using Xunit.Abstractions;
95

@@ -21,21 +17,19 @@ public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
2117
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
2218
public void FedAuthTokenRefreshTest()
2319
{
24-
string connStr = DataTestUtility.AADPasswordConnectionString;
20+
string connectionString = DataTestUtility.AADPasswordConnectionString;
2521

26-
// Create a new connection object and open it
27-
using (SqlConnection connection = new SqlConnection(connStr))
22+
using (SqlConnection connection = new SqlConnection(connectionString))
2823
{
2924
connection.Open();
3025

31-
// Set the token expiry to expire in 1 minute from now to force token refresh
32-
string tokenHash1 = "";
33-
DateTime? oldExpiry = GetOrSetTokenExpiryDateTime(connection, true, out tokenHash1);
34-
Assert.True(oldExpiry != null, "Failed to make token expiry to expire in one minute.");
26+
string oldTokenHash = "";
27+
DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash);
28+
Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute.");
3529

3630
// Convert and display the old expiry into local time which should be in 1 minute from now
37-
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiry, TimeZoneInfo.Local);
38-
LogInfo($"Token: {tokenHash1} Old Expiry: {oldLocalExpiryTime}");
31+
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local);
32+
LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}");
3933
TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
4034
Assert.True(timeDiff.TotalSeconds <= 60, "Failed to set expiry after 1 minute from current time.");
4135

@@ -47,24 +41,22 @@ public void FedAuthTokenRefreshTest()
4741
Assert.True(result != string.Empty, "The connection's command must return a value");
4842

4943
// The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
50-
using (SqlConnection connection2 = new SqlConnection(connStr))
44+
using (SqlConnection connection2 = new SqlConnection(connectionString))
5145
{
5246
connection2.Open();
5347

54-
// Check again if connection is alive
48+
// Check if connection is alive
5549
cmd = connection2.CreateCommand();
5650
cmd.CommandText = "select 1";
5751
result = $"{cmd.ExecuteScalar()}";
5852
Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");
5953

60-
// Get the refreshed token expiry
61-
string tokenHash2 = "";
62-
DateTime? newExpiry = GetOrSetTokenExpiryDateTime(connection2, false, out tokenHash2);
63-
// Display new expiry in local time
64-
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiry, TimeZoneInfo.Local);
65-
LogInfo($"Token: {tokenHash2} New Expiry: {newLocalExpiryTime}");
54+
string newTokenHash = "";
55+
DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash);
56+
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local);
57+
LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}");
6658

67-
Assert.True(tokenHash1 == tokenHash2, "The token's hash before and after token refresh must be identical.");
59+
Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical.");
6860
Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");
6961
}
7062
}
@@ -74,60 +66,5 @@ private void LogInfo(string message)
7466
{
7567
_testOutputHelper.WriteLine(message);
7668
}
77-
78-
private DateTime? GetOrSetTokenExpiryDateTime(SqlConnection connection, bool setExpiry, out string tokenHash)
79-
{
80-
try
81-
{
82-
// Get the inner connection
83-
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);
84-
85-
// Get the db connection pool
86-
object poolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);
87-
88-
// Get the Authentication Contexts
89-
IEnumerable authContextCollection = (IEnumerable)poolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(poolObj, null);
90-
91-
// Get the first authentication context
92-
object authContextObj = authContextCollection.Cast<object>().FirstOrDefault();
93-
94-
// Get the token object from the authentication context
95-
object tokenObj = authContextObj.GetType().GetProperty("Value").GetValue(authContextObj, null);
96-
97-
DateTime expiry = (DateTime)tokenObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(tokenObj, null);
98-
99-
if (setExpiry)
100-
{
101-
// Forcing 1 minute expiry to trigger token refresh.
102-
expiry = DateTime.UtcNow.AddMinutes(1);
103-
104-
// Apply the expiry to the token object
105-
FieldInfo expirationTime = tokenObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
106-
expirationTime.SetValue(tokenObj, expiry);
107-
}
108-
109-
byte[] tokenBytes = (byte[])tokenObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(tokenObj, null);
110-
111-
tokenHash = GetTokenHash(tokenBytes);
112-
113-
return expiry;
114-
}
115-
catch (Exception)
116-
{
117-
tokenHash = "";
118-
return null;
119-
}
120-
}
121-
122-
private string GetTokenHash(byte[] tokenBytes)
123-
{
124-
string token = Encoding.Unicode.GetString(tokenBytes);
125-
var bytesInUtf8 = Encoding.UTF8.GetBytes(token);
126-
using (var sha256 = SHA256.Create())
127-
{
128-
var hash = sha256.ComputeHash(bytesInUtf8);
129-
return Convert.ToBase64String(hash);
130-
}
131-
}
13269
}
13370
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
using System;
2+
using System.Collections;
3+
using System.Linq;
4+
using System.Reflection;
5+
6+
namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals
7+
{
8+
internal static class FedAuthTokenHelper
9+
{
10+
internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash)
11+
{
12+
try
13+
{
14+
object authenticationContextValueObj = GetAuthenticationContextValue(connection);
15+
16+
DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
17+
18+
tokenHash = GetTokenHash(authenticationContextValueObj);
19+
20+
return expirationTimeProperty;
21+
}
22+
catch (Exception)
23+
{
24+
tokenHash = "";
25+
return null;
26+
}
27+
}
28+
29+
internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash)
30+
{
31+
try
32+
{
33+
object authenticationContextValueObj = GetAuthenticationContextValue(connection);
34+
35+
DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
36+
37+
expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire);
38+
39+
FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
40+
expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty);
41+
42+
tokenHash = GetTokenHash(authenticationContextValueObj);
43+
44+
return expirationTimeProperty;
45+
}
46+
catch (Exception)
47+
{
48+
tokenHash = "";
49+
return null;
50+
}
51+
}
52+
53+
internal static string GetTokenHash(object authenticationContextValueObj)
54+
{
55+
try
56+
{
57+
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));
58+
59+
Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken");
60+
61+
Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType };
62+
63+
ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
64+
65+
Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper");
66+
67+
ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);
68+
69+
object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { });
70+
71+
MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null);
72+
73+
byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);
74+
75+
object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { });
76+
FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance);
77+
accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes);
78+
79+
string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj });
80+
81+
return tokenHash;
82+
}
83+
catch (Exception)
84+
{
85+
return "";
86+
}
87+
}
88+
89+
internal static object GetAuthenticationContextValue(SqlConnection connection)
90+
{
91+
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);
92+
93+
object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);
94+
95+
IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null);
96+
97+
object authenticationContextObj = authenticationContexts.Cast<object>().FirstOrDefault();
98+
99+
object authenticationContextValueObj = authenticationContextObj.GetType().GetProperty("Value").GetValue(authenticationContextObj, null);
100+
101+
return authenticationContextValueObj;
102+
}
103+
}
104+
}

0 commit comments

Comments
 (0)