Skip to content

Commit

Permalink
Fix | Fix driver to not send expired token and refresh token first be…
Browse files Browse the repository at this point in the history
…fore sending it. (#2273)
  • Loading branch information
arellegue authored Jan 24, 2024
1 parent 9347412 commit 3487389
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -2256,6 +2256,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");

if (_newDbConnectionPoolAuthenticationContext != null)
{
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
else if (!attemptRefreshTokenLocked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2683,6 +2683,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
{
// GetFedAuthToken should have updated _newDbConnectionPoolAuthenticationContext.
Debug.Assert(_newDbConnectionPoolAuthenticationContext != null, "_newDbConnectionPoolAuthenticationContext should not be null.");

if (_newDbConnectionPoolAuthenticationContext != null)
{
_dbConnectionPool.AuthenticationContexts.TryAdd(_dbConnectionPoolAuthenticationContextKey, _newDbConnectionPoolAuthenticationContext);
}
}
}
else if (!attemptRefreshTokenLocked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@
<Compile Include="ProviderAgnostic\MultipleResultsTest\MultipleResultsTest.cs" />
<Compile Include="ProviderAgnostic\ReaderTest\ReaderTest.cs" />
<Compile Include="TracingTests\EventSourceTest.cs" />
<Compile Include="SQL\AADFedAuthTokenRefreshTest\AADFedAuthTokenRefreshTest.cs" />
<Compile Include="SQL\ConnectionPoolTest\ConnectionPoolTest.cs" />
<Compile Include="SQL\ConnectionPoolTest\PoolBlockPeriodTest.cs" />
<Compile Include="SQL\InstanceNameTest\InstanceNameTest.cs" />
Expand Down Expand Up @@ -279,6 +280,7 @@
<Compile Include="SQL\Common\SystemDataInternals\ConnectionHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\ConnectionPoolHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\DataReaderHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\FedAuthTokenHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserHelper.cs" />
<Compile Include="SQL\Common\SystemDataInternals\TdsParserStateObjectHelper.cs" />
<Compile Include="SQL\ConnectionTestWithSSLCert\CertificateTest.cs" />
Expand Down Expand Up @@ -342,7 +344,7 @@
<PackageReference Include="System.IdentityModel.Tokens.Jwt" Version="$(SystemIdentityModelTokensJwtVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netfx'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersionNet)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.DotNet.RemoteExecutor" Version="$(MicrosoftDotnetRemoteExecutorVersion)" />
<PackageReference Condition="'$(TargetGroup)'!='netfx'" Include="System.ServiceProcess.ServiceController" Version="$(SystemServiceProcessServiceControllerVersion)" />
</ItemGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals;
using Xunit;
using Xunit.Abstractions;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests
{
public class AADFedAuthTokenRefreshTest
{
private readonly ITestOutputHelper _testOutputHelper;

public AADFedAuthTokenRefreshTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
}

[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsAADPasswordConnStrSetup))]
public void FedAuthTokenRefreshTest()
{
string connectionString = DataTestUtility.AADPasswordConnectionString;

using (SqlConnection connection = new SqlConnection(connectionString))
{
connection.Open();

string oldTokenHash = "";
DateTime? oldExpiryDateTime = FedAuthTokenHelper.SetTokenExpiryDateTime(connection, minutesToExpire: 1, out oldTokenHash);
Assert.True(oldExpiryDateTime != null, "Failed to make token expiry to expire in one minute.");

// Convert and display the old expiry into local time which should be in 1 minute from now
DateTime oldLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)oldExpiryDateTime, TimeZoneInfo.Local);
LogInfo($"Token: {oldTokenHash} Old Expiry: {oldLocalExpiryTime}");
TimeSpan timeDiff = oldLocalExpiryTime - DateTime.Now;
Assert.InRange(timeDiff.TotalSeconds, 0, 60);

// Check if connection is still alive to continue further testing
string result = "";
SqlCommand cmd = connection.CreateCommand();
cmd.CommandText = "select @@version";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value");

// The new connection will use the same FedAuthToken but will refresh it first as it will expire in 1 minute.
using (SqlConnection connection2 = new SqlConnection(connectionString))
{
connection2.Open();

// Check if connection is alive
cmd = connection2.CreateCommand();
cmd.CommandText = "select 1";
result = $"{cmd.ExecuteScalar()}";
Assert.True(result != string.Empty, "The connection's command must return a value after a token refresh.");

string newTokenHash = "";
DateTime? newExpiryDateTime = FedAuthTokenHelper.GetTokenExpiryDateTime(connection2, out newTokenHash);
DateTime newLocalExpiryTime = TimeZoneInfo.ConvertTimeFromUtc((DateTime)newExpiryDateTime, TimeZoneInfo.Local);
LogInfo($"Token: {newTokenHash} New Expiry: {newLocalExpiryTime}");

Assert.True(oldTokenHash == newTokenHash, "The token's hash before and after token refresh must be identical.");
Assert.True(newLocalExpiryTime > oldLocalExpiryTime, "The refreshed token must have a new or later expiry time.");
}
}
}

private void LogInfo(string message)
{
_testOutputHelper.WriteLine(message);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Linq;
using System.Reflection;

namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SQL.Common.SystemDataInternals
{
internal static class FedAuthTokenHelper
{
internal static DateTime? GetTokenExpiryDateTime(SqlConnection connection, out string tokenHash)
{
try
{
object authenticationContextValueObj = GetAuthenticationContextValue(connection);

DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

tokenHash = GetTokenHash(authenticationContextValueObj);

return expirationTimeProperty;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

internal static DateTime? SetTokenExpiryDateTime(SqlConnection connection, int minutesToExpire, out string tokenHash)
{
try
{
object authenticationContextValueObj = GetAuthenticationContextValue(connection);

DateTime expirationTimeProperty = (DateTime)authenticationContextValueObj.GetType().GetProperty("ExpirationTime", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

expirationTimeProperty = DateTime.UtcNow.AddMinutes(minutesToExpire);

FieldInfo expirationTimeInfo = authenticationContextValueObj.GetType().GetField("_expirationTime", BindingFlags.NonPublic | BindingFlags.Instance);
expirationTimeInfo.SetValue(authenticationContextValueObj, expirationTimeProperty);

tokenHash = GetTokenHash(authenticationContextValueObj);

return expirationTimeProperty;
}
catch (Exception)
{
tokenHash = "";
return null;
}
}

internal static string GetTokenHash(object authenticationContextValueObj)
{
try
{
Assembly sqlConnectionAssembly = Assembly.GetAssembly(typeof(SqlConnection));

Type sqlFedAuthTokenType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.SqlFedAuthToken");

Type[] sqlFedAuthTokenTypeArray = new Type[] { sqlFedAuthTokenType };

ConstructorInfo sqlFedAuthTokenConstructorInfo = sqlFedAuthTokenType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

Type activeDirectoryAuthenticationTimeoutRetryHelperType = sqlConnectionAssembly.GetType("Microsoft.Data.SqlClient.ActiveDirectoryAuthenticationTimeoutRetryHelper");

ConstructorInfo activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo = activeDirectoryAuthenticationTimeoutRetryHelperType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, Type.EmptyTypes, null);

object activeDirectoryAuthenticationTimeoutRetryHelperObj = activeDirectoryAuthenticationTimeoutRetryHelperConstructorInfo.Invoke(new object[] { });

MethodInfo tokenHashInfo = activeDirectoryAuthenticationTimeoutRetryHelperObj.GetType().GetMethod("GetTokenHash", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, sqlFedAuthTokenTypeArray, null);

byte[] tokenBytes = (byte[])authenticationContextValueObj.GetType().GetProperty("AccessToken", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(authenticationContextValueObj, null);

object sqlFedAuthTokenObj = sqlFedAuthTokenConstructorInfo.Invoke(new object[] { });
FieldInfo accessTokenInfo = sqlFedAuthTokenObj.GetType().GetField("accessToken", BindingFlags.NonPublic | BindingFlags.Instance);
accessTokenInfo.SetValue(sqlFedAuthTokenObj, tokenBytes);

string tokenHash = (string)tokenHashInfo.Invoke(activeDirectoryAuthenticationTimeoutRetryHelperObj, new object[] { sqlFedAuthTokenObj });

return tokenHash;
}
catch (Exception)
{
return "";
}
}

internal static object GetAuthenticationContextValue(SqlConnection connection)
{
object innerConnectionObj = connection.GetType().GetProperty("InnerConnection", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(connection);

object databaseConnectionPoolObj = innerConnectionObj.GetType().GetProperty("Pool", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(innerConnectionObj);

IEnumerable authenticationContexts = (IEnumerable)databaseConnectionPoolObj.GetType().GetProperty("AuthenticationContexts", BindingFlags.NonPublic | BindingFlags.Instance).GetValue(databaseConnectionPoolObj, null);

object authenticationContextObj = authenticationContexts.Cast<object>().FirstOrDefault();

object authenticationContextValueObj = authenticationContextObj.GetType().GetProperty("Value").GetValue(authenticationContextObj, null);

return authenticationContextValueObj;
}
}
}

0 comments on commit 3487389

Please sign in to comment.