Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pooled connection re-use on access token expiry #635

Merged
merged 5 commits into from
Jul 21, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ virtual protected bool ReadyToPrepareTransaction
}
}

virtual internal bool IsAccessTokenExpired
{
get
{
return false;
}
}

abstract protected void Activate(Transaction transaction);

internal void ActivateConnection(Transaction transaction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
_waitHandles.CreationSemaphore.Release(1);
}
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if(null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
{
SqlConnectionString opt = (SqlConnectionString)options;
SqlConnectionPoolKey key = (SqlConnectionPoolKey)poolKey;
SqlInternalConnection result = null;
SessionData recoverySessionData = null;

SqlConnection sqlOwningConnection = (SqlConnection)owningConnection;
Expand Down Expand Up @@ -131,8 +130,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
}
result = new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
return result;
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool);
}

protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ sealed internal class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
internal bool _federatedAuthenticationAcknowledged;
internal bool _federatedAuthenticationInfoRequested; // Keep this distinct from _federatedAuthenticationRequested, since some fedauth library types may not need more info
internal bool _federatedAuthenticationInfoReceived;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
Expand Down Expand Up @@ -181,7 +184,7 @@ internal bool IsDNSCachingBeforeRedirectSupported
}
}

internal SQLDNSInfo pendingSQLDNSObject = null;
internal SQLDNSInfo pendingSQLDNSObject = null;

// TCE flags
internal byte _tceVersionSupported;
Expand Down Expand Up @@ -684,6 +687,16 @@ protected override bool UnbindOnTransactionCompletion
}
}

/// <summary>
/// Validates if federated authentication is used, Access Token used by this connection is active for the next 10 minutes.
/// </summary>
internal override bool IsAccessTokenExpired
{
get
{
return _federatedAuthenticationInfoRequested && DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime) < DateTime.UtcNow.AddMinutes(10d);
}
}

////////////////////////////////////////////////////////////////////////////////////////
// GENERAL METHODS
Expand Down Expand Up @@ -2091,8 +2104,6 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
// We want to refresh the token, if taking the lock on the authentication context is successful.
bool attemptRefreshTokenLocked = false;

// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
SqlFedAuthToken fedAuthToken = null;

if (_dbConnectionPool != null)
{
Expand Down Expand Up @@ -2127,7 +2138,7 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
}
else if (_forceExpiryLocked)
{
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);
}
#endif

Expand All @@ -2141,11 +2152,11 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)

// Call the function which tries to acquire a lock over the authentication context before trying to update.
// If the lock could not be obtained, it will return false, without attempting to fetch a new token.
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out fedAuthToken);
attemptRefreshTokenLocked = TryGetFedAuthTokenLocked(fedAuthInfo, dbConnectionPoolAuthenticationContext, out _fedAuthToken);

// If TryGetFedAuthTokenLocked returns true, it means lock was obtained and fedAuthToken should not be null.
// If there was an exception in retrieving the new token, TryGetFedAuthTokenLocked should have thrown, so we won't be here.
Debug.Assert(!attemptRefreshTokenLocked || fedAuthToken != null, "Either Lock should not have been obtained or fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _fedAuthToken != null, "Either Lock should not have been obtained or fedAuthToken should not be null.");
Debug.Assert(!attemptRefreshTokenLocked || _newDbConnectionPoolAuthenticationContext != null, "Either Lock should not have been obtained or _newDbConnectionPoolAuthenticationContext should not be null.");

// Indicate in EventSource Trace that we are successful with the update.
Expand All @@ -2162,8 +2173,8 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
if (dbConnectionPoolAuthenticationContext == null || attemptRefreshTokenUnLocked)
{
// Get the Federated Authentication Token.
fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
_fedAuthToken = GetFedAuthToken(fedAuthInfo);
Debug.Assert(_fedAuthToken != null, "fedAuthToken should not be null.");

if (_dbConnectionPool != null)
{
Expand All @@ -2174,18 +2185,19 @@ internal void OnFedAuthInfo(SqlFedAuthInfo fedAuthInfo)
else if (!attemptRefreshTokenLocked)
{
Debug.Assert(dbConnectionPoolAuthenticationContext != null, "dbConnectionPoolAuthenticationContext should not be null.");
Debug.Assert(fedAuthToken == null, "fedAuthToken should be null in this case.");
Debug.Assert(_fedAuthToken == null, "fedAuthToken should be null in this case.");
Debug.Assert(_newDbConnectionPoolAuthenticationContext == null, "_newDbConnectionPoolAuthenticationContext should be null.");

fedAuthToken = new SqlFedAuthToken();
_fedAuthToken = new SqlFedAuthToken();

// If the code flow is here, then we are re-using the context from the cache for this connection attempt and not
// generating a new access token on this thread.
fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.accessToken = dbConnectionPoolAuthenticationContext.AccessToken;
_fedAuthToken.expirationFileTime = dbConnectionPoolAuthenticationContext.ExpirationTime.ToFileTime();
}

Debug.Assert(fedAuthToken != null && fedAuthToken.accessToken != null, "fedAuthToken and fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(fedAuthToken);
Debug.Assert(_fedAuthToken != null && _fedAuthToken.accessToken != null, "fedAuthToken and fedAuthToken.accessToken cannot be null.");
_parser.SendFedAuthToken(_fedAuthToken);
}

/// <summary>
Expand Down Expand Up @@ -2260,7 +2272,7 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
int numberOfAttempts = 0;

// Object that will be returned to the caller, containing all required data about the token.
SqlFedAuthToken fedAuthToken = new SqlFedAuthToken();
_fedAuthToken = new SqlFedAuthToken();

// Username to use in error messages.
string username = null;
Expand Down Expand Up @@ -2300,54 +2312,54 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)

if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
case SqlAuthenticationMethod.ActiveDirectoryInteractive:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
authParamsBuilder.WithUserId(ConnectionOptions.UserID);
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
case SqlAuthenticationMethod.ActiveDirectoryPassword:
case SqlAuthenticationMethod.ActiveDirectoryServicePrincipal:
if (_activeDirectoryAuthTimeoutRetryHelper.State == ActiveDirectoryAuthenticationTimeoutRetryState.Retrying)
{
fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
_fedAuthToken = _activeDirectoryAuthTimeoutRetryHelper.CachedToken;
}
else
{
if (_credential != null)
{
username = _credential.UserId;
authParamsBuilder.WithUserId(username).WithPassword(_credential.Password);
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
}
else
{
username = ConnectionOptions.UserID;
authParamsBuilder.WithUserId(username).WithPassword(ConnectionOptions.Password);
Task.Run(() => fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
Task.Run(() => _fedAuthToken = authProvider.AcquireTokenAsync(authParamsBuilder).Result.ToSqlFedAuthToken()).Wait();
}
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = fedAuthToken;
_activeDirectoryAuthTimeoutRetryHelper.CachedToken = _fedAuthToken;
}
break;
default:
throw SQL.UnsupportedAuthenticationSpecified(ConnectionOptions.Authentication);
}

Debug.Assert(fedAuthToken.accessToken != null, "AccessToken should not be null.");
Debug.Assert(_fedAuthToken.accessToken != null, "AccessToken should not be null.");
#if DEBUG
if (_forceMsalRetry)
{
Expand Down Expand Up @@ -2417,28 +2429,29 @@ internal SqlFedAuthToken GetFedAuthToken(SqlFedAuthInfo fedAuthInfo)
}
}

Debug.Assert(fedAuthToken != null, "fedAuthToken should not be null.");
Debug.Assert(fedAuthToken.accessToken != null && fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty.");
Debug.Assert(_fedAuthToken != null, "fedAuthToken should not be null.");
Debug.Assert(_fedAuthToken.accessToken != null && _fedAuthToken.accessToken.Length > 0, "fedAuthToken.accessToken should not be null or empty.");

// Store the newly generated token in _newDbConnectionPoolAuthenticationContext, only if using pooling.
if (_dbConnectionPool != null)
{
DateTime expirationTime = DateTime.FromFileTimeUtc(fedAuthToken.expirationFileTime);
_newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(fedAuthToken.accessToken, expirationTime);
DateTime expirationTime = DateTime.FromFileTimeUtc(_fedAuthToken.expirationFileTime);
_newDbConnectionPoolAuthenticationContext = new DbConnectionPoolAuthenticationContext(_fedAuthToken.accessToken, expirationTime);
}
SqlClientEventSource.Log.TraceEvent("<sc.SqlInternalConnectionTds.GetFedAuthToken> {0}, Finished generating federated authentication token.", ObjectID);
return fedAuthToken;
return _fedAuthToken;
}

internal void OnFeatureExtAck(int featureId, byte[] data)
{
if (RoutingInfo != null)
{
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId) {
if (TdsEnums.FEATUREEXT_SQLDNSCACHING != featureId)
{
return;
}
}

switch (featureId)
{
case TdsEnums.FEATUREEXT_SRECOVERY:
Expand Down Expand Up @@ -2636,16 +2649,18 @@ internal void OnFeatureExtAck(int featureId, byte[] data)
throw SQL.ParsingError(ParsingErrorState.CorruptedTdsStream);
}

if (1 == data[0]) {
if (1 == data[0])
{
IsSQLDNSCachingSupported = true;
_cleanSQLDNSCaching = false;

if (RoutingInfo != null)
{
IsDNSCachingBeforeRedirectSupported = true;
}
}
else {
else
{
// we receive the IsSupported whose value is 0
IsSQLDNSCachingSupported = false;
_cleanSQLDNSCaching = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,13 @@ public ConnectionState State
return _state;
}
}
virtual internal bool IsAccessTokenExpired
{
get
{
return false;
}
}

abstract protected void Activate(SysTx.Transaction transaction);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,13 @@ private bool TryGetConnection(DbConnection owningObject, uint waitForMultipleObj
{
Marshal.ThrowExceptionForHR(releaseSemaphoreResult); // will only throw if (hresult < 0)
}

// Do not use this pooled connection if access token is about to expire soon before we can connect.
if (null != obj && obj.IsAccessTokenExpired)
{
DestroyObject(obj);
obj = null;
}
} while (null == obj);
}

Expand Down
Loading