Skip to content

Commit

Permalink
Refactor as suggested by creating hekper class for Federated Authenti…
Browse files Browse the repository at this point in the history
…cation Helper using reflection. Also, applied SOLID and DRY principles.
  • Loading branch information
arellegue committed Jan 24, 2024
1 parent bcb4936 commit fa913b1
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2259,7 +2259,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// Try adding this new _newDbConnectionPoolAuthenticationContext to the _dbConnectionPool's AuthenticationContextKeys if it is not in there yet.
// The DbConnectionPoolAuthenticationContextKeys collection is used to refresh a cached token just before it expires within 10 minutes.
_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
//_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2685,7 +2685,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// Try adding this new _newDbConnectionPoolAuthenticationContext to the _dbConnectionPool's AuthenticationContextKeys if it is not in there yet.
// The DbConnectionPoolAuthenticationContextKeys collection is used to refresh a cached token just before it expires within 10 minutes.
_dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
// _dbConnectionPool.AuthenticationContexts.TryAdd(new DbConnectionPoolAuthenticationContextKey(fedAuthInfo.stsurl, fedAuthInfo.spn), _newDbConnectionPoolAuthenticationContext);
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
<Compile Include="ProviderAgnostic\MultipleResultsTest\MultipleResultsTest.cs" />
<Compile Include="ProviderAgnostic\ReaderTest\ReaderTest.cs" />
<Compile Include="TracingTests\EventSourceTest.cs" />
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
<Compile Include="SQL\ConnectionPoolTest\ConnectionPoolTest.cs" />
<Compile Include="SQL\ConnectionPoolTest\PoolBlockPeriodTest.cs" />
<Compile Include="SQL\InstanceNameTest\InstanceNameTest.cs" />
Expand Down Expand Up @@ -267,7 +268,6 @@
<Compile Include="DataCommon\ProxyServer.cs" />
<Compile Include="DataCommon\SqlClientCustomTokenCredential.cs" />
<Compile Include="DataCommon\SystemDataResourceManager.cs" />
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
<Compile Include="SQL\Common\AsyncDebugScope.cs" />
<Compile Include="SQL\Common\ConnectionPoolWrapper.cs" />
<Compile Include="SQL\Common\InternalConnectionWrapper.cs" />
Expand All @@ -276,6 +276,7 @@
<Compile Include="SQL\Common\SystemDataInternals\ConnectionHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\ConnectionPoolHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\DataReaderHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\FedAuthTokenHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTest.cs" />
Expand Down Expand Up @@ -341,7 +342,7 @@
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="$(SystemIdentityModelTokensJwtVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netfx'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersionNet)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
<PackageReference Condition="'$(TargetGroup)'!='netfx'" Include="System.ServiceProcess.ServiceController" Version="$(SystemServiceProcessServiceControllerVersion)" />
</ItemGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
using System;
using System.Collections;
using System.Linq;
using System.Reflection;
using System.Security.Cryptography;
using System.Text;
using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals;
using Xunit;
using Xunit.Abstractions;

Expand All @@ -21,21 +17,19 @@ public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
public void FedAuthTokenRefreshTest()
{
string connStr = DataTestUtility.AADPasswordConnectionString;
string connectionString = DataTestUtility.AADPasswordConnectionString;

// Create a new connection object and open it
using (SqlConnection connection = new SqlConnection(connStr))
using (SqlConnection connection = new SqlConnection(connectionString))
{
connection.Open();

// Set the token expiry to expire in 1 minute from now to force token refresh
string tokenHash1 = "";
DateTime? oldExpiry = GetOrSetTokenExpiryDateTime(connection, true, out tokenHash1);
Assert.True(oldExpiry != null, "Failed to make token expiry to expire in one minute.");
string oldTokenHash = "";
DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash);
Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute.");

// Convert and display the old expiry into local time which should be in 1 minute from now
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiry, TimeZoneInfo.Local);
LogInfo($"Token: {tokenHash1} Old Expiry: {oldLocalExpiryTime}");
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local);
LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}");
TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
Assert.True(timeDiff.TotalSeconds <= 60, "Failed to set expiry after 1 minute from current time.");

Expand All @@ -47,24 +41,22 @@ public void FedAuthTokenRefreshTest()
Assert.True(result != string.Empty, "The connection's command must return a value");

// The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
using (SqlConnection connection2 = new SqlConnection(connStr))
using (SqlConnection connection2 = new SqlConnection(connectionString))
{
connection2.Open();

// Check again if connection is alive
// Check if connection is alive
cmd = connection2.CreateCommand();
cmd.CommandText = "select 1";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");

// Get the refreshed token expiry
string tokenHash2 = "";
DateTime? newExpiry = GetOrSetTokenExpiryDateTime(connection2, false, out tokenHash2);
// Display new expiry in local time
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiry, TimeZoneInfo.Local);
LogInfo($"Token: {tokenHash2} New Expiry: {newLocalExpiryTime}");
string newTokenHash = "";
DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash);
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local);
LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}");

Assert.True(tokenHash1 == tokenHash2, "The token's hash before and after token refresh must be identical.");
Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical.");
Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");
}
}
Expand All @@ -74,60 +66,5 @@ private void LogInfo(string message)
{
_testOutputHelper.WriteLine(message);
}

private DateTime? GetOrSetTokenExpiryDateTime(SqlConnection connection, bool setExpiry, out string tokenHash)
{
try
{
// Get the inner connection
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);

// Get the db connection pool
object poolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);

// Get the Authentication Contexts
IEnumerable authContextCollection = (IEnumerable)poolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(poolObj, null);

// Get the first authentication context
object authContextObj = authContextCollection.Cast<object>().FirstOrDefault();

// Get the token object from the authentication context
object tokenObj = authContextObj.GetType().GetProperty("Value").GetValue(authContextObj, null);

DateTime expiry = (DateTime)tokenObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(tokenObj, null);

if (setExpiry)
{
// Forcing 1 minute expiry to trigger token refresh.
expiry = DateTime.UtcNow.AddMinutes(1);

// Apply the expiry to the token object
FieldInfo expirationTime = tokenObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
expirationTime.SetValue(tokenObj, expiry);
}

byte[] tokenBytes = (byte[])tokenObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(tokenObj, null);

tokenHash = GetTokenHash(tokenBytes);

return expiry;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

private string GetTokenHash(byte[] tokenBytes)
{
string token = Encoding.Unicode.GetString(tokenBytes);
var bytesInUtf8 = Encoding.UTF8.GetBytes(token);
using (var sha256 = SHA256.Create())
{
var hash = sha256.ComputeHash(bytesInUtf8);
return Convert.ToBase64String(hash);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using System;
using System.Collections;
using System.Linq;
using System.Reflection;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals
{
internal static class FedAuthTokenHelper
{
internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash)
{
try
{
object authenticationContextValueObj = GetAuthenticationContextValue(connection);

DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

tokenHash = GetTokenHash(authenticationContextValueObj);

return expirationTimeProperty;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash)
{
try
{
object authenticationContextValueObj = GetAuthenticationContextValue(connection);

DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire);

FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty);

tokenHash = GetTokenHash(authenticationContextValueObj);

return expirationTimeProperty;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

internal static string GetTokenHash(object authenticationContextValueObj)
{
try
{
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));

Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken");

Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType };

ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper");

ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { });

MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null);

byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { });
FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance);
accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes);

string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj });

return tokenHash;
}
catch (Exception)
{
return "";
}
}

internal static object GetAuthenticationContextValue(SqlConnection connection)
{
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);

object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);

IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null);

object authenticationContextObj = authenticationContexts.Cast<object>().FirstOrDefault();

object authenticationContextValueObj = authenticationContextObj.GetType().GetProperty("Value").GetValue(authenticationContextObj, null);

return authenticationContextValueObj;
}
}
}

0 comments on commit fa913b1

Please sign in to comment.