Skip to content

Commit

Permalink
[Release/2.0] Fix pooled connection re-use on access token expiry
Browse files Browse the repository at this point in the history
Ports dotnet#635 to v2.0
  • Loading branch information
cheenamalhotra committed Jul 8, 2020
1 parent 8b8196a commit 251cbbc
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ virtual protected bool ReadyToPrepareTransaction
}
}

virtual internal bool IsAccessTokenExpired
{
get
{
return false;
}
}

abstract protected void Activate(Transaction transaction);

internal void ActivateConnection(Transaction transaction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
_waitHandles.CreationSemaphore.Release(1);
}
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if(null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
{
SqlConnectionString opt = (SqlConnectionString)options;
SqlConnectionPoolKey key = (SqlConnectionPoolKey)poolKey;
SqlInternalConnection result = null;
SessionData recoverySessionData = null;

SqlConnection sqlOwningConnection = (SqlConnection)owningConnection;
Expand Down Expand Up @@ -131,8 +130,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return result;
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
internal bool _federatedAuthenticationAcknowledged;
internal bool _federatedAuthenticationInfoRequested; // Keep this distinct from _federatedAuthenticationRequested, since some fedauth library types may not need more info
internal bool _federatedAuthenticationInfoReceived;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
Expand Down Expand Up @@ -181,7 +184,7 @@ internal bool IsDNSCachingBeforeRedirectSupported
}
}

internal SQLDNSInfo pendingSQLDNSObject = null;
internal SQLDNSInfo pendingSQLDNSObject = null;

// TCE flags
internal byte _tceVersionSupported;
Expand Down Expand Up @@ -684,6 +687,16 @@ protected override bool UnbindOnTransactionCompletion
}
}

/// <summary>
/// Validates if federated authentication is used, Access Token used by this connection is active for the next 10 minutes.
/// </summary>
internal override bool IsAccessTokenExpired
{
get
{
return _federatedAuthenticationInfoRequested && DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime) < DateTime.UtcNow.AddMinutes(10d);
}
}

////////////////////////////////////////////////////////////////////////////////////////
// GENERAL METHODS
Expand Down Expand Up @@ -2091,8 +2104,6 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
// We want to refresh the token, if taking the lock on the authentication context is successful.
bool attemptRefreshTokenLocked = false;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken fedAuthToken = null;

if (_dbConnectionPool != null)
{
Expand Down Expand Up @@ -2127,7 +2138,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
}
else if (_forceExpiryLocked)
{
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);
}
#endif

Expand All @@ -2141,11 +2152,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)

// Call the function which tries to acquire a lock over the authentication context before trying to update.
// If the lock could not be obtained, it will return false, without attempting to fetch a new token.
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);

// If TryGetFedAuthTokenLocked returns true, it means lock was obtained and fedAuthToken should not be null.
// If there was an exception in retrieving the new token, TryGetFedAuthTokenLocked should have thrown, so we won't be here.
Debug.Assert(!attemptRefreshTokenLocked || fedAuthToken != null, "Either Lock should not have been obtained or fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _fedAuthToken != null, "Either Lock should not have been obtained or fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _newDbConnectionPoolAuthenticationContext != null, "Either Lock should not have been obtained or _newDbConnectionPoolAuthenticationContext should not be null.");

// Indicate in EventSource Trace that we are successful with the update.
Expand All @@ -2162,8 +2173,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
if (dbConnectionPoolAuthenticationContext == null || attemptRefreshTokenUnLocked)
{
// Get the Federated Authentication Token.
fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
_fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(_fedAuthToken != null, "fedAuthToken should not be null.");

if (_dbConnectionPool != null)
{
Expand All @@ -2174,18 +2185,19 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
else if (!attemptRefreshTokenLocked)
{
Debug.Assert(dbConnectionPoolAuthenticationContext != null, "dbConnectionPoolAuthenticationContext should not be null.");
Debug.Assert(fedAuthToken == null, "fedAuthToken should be null in this case.");
Debug.Assert(_fedAuthToken == null, "fedAuthToken should be null in this case.");
Debug.Assert(_newDbConnectionPoolAuthenticationContext == null, "_newDbConnectionPoolAuthenticationContext should be null.");

fedAuthToken = new SqlFedAuthToken();
_fedAuthToken = new SqlFedAuthToken();

// If the code flow is here, then we are re-using the context from the cache for this connection attempt and not
// generating a new access token on this thread.
fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.expirationFileTime = dbConnectionPoolAuthenticationContext.ExpirationTime.ToFileTime();
}

Debug.Assert(fedAuthToken != null && fedAuthToken.accessToken != null, "fedAuthToken and fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(fedAuthToken);
Debug.Assert(_fedAuthToken != null && _fedAuthToken.accessToken != null, "fedAuthToken and fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(_fedAuthToken);
}

/// <summary>
Expand Down Expand Up @@ -2260,7 +2272,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
int numberOfAttempts = 0;

// Object that will be returned to the caller, containing all required data about the token.
SqlFedAuthToken fedAuthToken = new SqlFedAuthToken();
_fedAuthToken = new SqlFedAuthToken();

// Username to use in error messages.
string username = null;
Expand Down Expand Up @@ -2300,54 +2312,54 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)

if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
case SqlAuthenticationMethod.ActiveDirectoryInteractive:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
authParamsBuilder.WithUserId(ConnectionOptions.UserID);
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
case SqlAuthenticationMethod.ActiveDirectoryPassword:
case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
if (_credential != null)
{
username = _credential.UserId;
authParamsBuilder.WithUserId(username).WithPassword(_credential.Password);
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
}
else
{
username = ConnectionOptions.UserID;
authParamsBuilder.WithUserId(username).WithPassword(ConnectionOptions.Password);
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
}
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
default:
throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
}

Debug.Assert(fedAuthToken.accessToken != null, "AccessToken should not be null.");
Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null.");
#if DEBUG
if (_forceMsalRetry)
{
Expand Down Expand Up @@ -2417,28 +2429,29 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
}

Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
Debug.Assert(fedAuthToken.accessToken != null && fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty.");
Debug.Assert(_fedAuthToken != null, "fedAuthToken should not be null.");
Debug.Assert(_fedAuthToken.accessToken != null && _fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty.");

// Store the newly generated token in _newDbConnectionPoolAuthenticationContext, only if using pooling.
if (_dbConnectionPool != null)
{
DateTime expirationTime = DateTime.FromFileTimeUtc(fedAuthToken.expirationFileTime);
_newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(fedAuthToken.accessToken, expirationTime);
DateTime expirationTime = DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime);
_newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(_fedAuthToken.accessToken, expirationTime);
}
SqlClientEventSource.Log.TraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken> {0}, Finished generating federated authentication token.", ObjectID);
return fedAuthToken;
return _fedAuthToken;
}

internal void OnFeatureExtAck(int featureId, byte[] data)
{
if (RoutingInfo != null)
{
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId) {
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId)
{
return;
}
}

switch (featureId)
{
case TdsEnums.FEATUREEXT_SRECOVERY:
Expand Down Expand Up @@ -2636,16 +2649,18 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream);
}

if (1 == data[0]) {
if (1 == data[0])
{
IsSQLDNSCachingSupported = true;
_cleanSQLDNSCaching = false;

if (RoutingInfo != null)
{
IsDNSCachingBeforeRedirectSupported = true;
}
}
else {
else
{
// we receive the IsSupported whose value is 0
IsSQLDNSCachingSupported = false;
_cleanSQLDNSCaching = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,13 @@ public ConnectionState State
return _state;
}
}
virtual internal bool IsAccessTokenExpired
{
get
{
return false;
}
}

abstract protected void Activate(SysTx.Transaction transaction);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
{
Marshal.ThrowExceptionForHR(releaseSemaphoreResult); // will only throw if (hresult < 0)
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if (null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Loading

0 comments on commit 251cbbc

Please sign in to comment.