Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Active Directory Default authentication improvements #1360

Merged
merged 5 commits into from
Oct 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,28 +118,46 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
string[] scopes = new string[] { scope };
TokenRequestContext tokenRequestContext = new(scopes);

/* We split audience from Authority URL here. Audience can be one of the following:
* The Azure AD authority audience enumeration
* The tenant ID, which can be:
* - A GUID (the ID of your Azure AD instance), for single-tenant applications
* - A domain name associated with your Azure AD instance (also for single-tenant applications)
* One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration:
* - `organizations` for a multitenant application
* - `consumers` to sign in users only with their personal accounts
* - `common` to sign in users with their work and school accounts or their personal Microsoft accounts
*
* MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID.
* If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.)
* More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration
**/

int seperatorIndex = parameters.Authority.LastIndexOf('/');
string tenantId = parameters.Authority.Substring(seperatorIndex + 1);
string authority = parameters.Authority.Remove(seperatorIndex + 1);

TokenRequestContext tokenRequestContext = new TokenRequestContext(scopes);
string audience = parameters.Authority.Substring(seperatorIndex + 1);
string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId;

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault)
{
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new DefaultAzureCredentialOptions()
DefaultAzureCredentialOptions defaultAzureCredentialOptions = new()
{
AuthorityHost = new Uri(authority),
ManagedIdentityClientId = clientId,
InteractiveBrowserTenantId = tenantId,
SharedTokenCacheTenantId = tenantId,
SharedTokenCacheUsername = clientId,
VisualStudioCodeTenantId = tenantId,
VisualStudioTenantId = tenantId,
SharedTokenCacheTenantId = audience,
VisualStudioCodeTenantId = audience,
VisualStudioTenantId = audience,
ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications.
};
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);

// Optionally set clientId when available
if (clientId is not null)
{
defaultAzureCredentialOptions.ManagedIdentityClientId = clientId;
defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId;
}
AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}
Expand All @@ -148,15 +166,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI)
{
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}

AuthenticationResult result;
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
{
AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token);
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn);
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
}
Expand Down Expand Up @@ -194,13 +212,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
else
{
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
Expand All @@ -213,14 +233,15 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti

result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync()).GetEnumerator();
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();

IAccount account = default;
if (accounts.MoveNext())
Expand Down Expand Up @@ -250,22 +271,22 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token);
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
}
else
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}
}
Expand Down Expand Up @@ -304,7 +325,8 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
.WithCorrelationId(connectionId)
.WithCustomWebUi(_customWebUI)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token);
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
}
else
{
Expand All @@ -328,15 +350,17 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
return await app.AcquireTokenInteractive(scopes)
.WithCorrelationId(connectionId)
.WithLoginHint(userId)
.ExecuteAsync(ctsInteractive.Token);
.ExecuteAsync(ctsInteractive.Token)
.ConfigureAwait(false);
}
}
else
{
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
.WithCorrelationId(connectionId)
.ExecuteAsync(cancellationToken: cts.Token);
.ExecuteAsync(cancellationToken: cts.Token)
.ConfigureAwait(false);
return result;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ private static void ConnectAndDisconnect(string connectionString, SqlCredential
private static bool IsAADConnStringsSetup() => DataTestUtility.IsAADPasswordConnStrSetup();
private static bool IsManagedIdentitySetup() => DataTestUtility.ManagedIdentitySupported;

[PlatformSpecific(TestPlatforms.Windows)]
[ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))]
public static void KustoDatabaseTest()
{
// This is a sample Kusto database that can be connected by any AD account.
using SqlConnection connection = new SqlConnection("Data Source=help.kusto.windows.net; Authentication=Active Directory Default;Trust Server Certificate=True;");
connection.Open();
Assert.True(connection.State == System.Data.ConnectionState.Open);
}

[ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))]
public static void AccessTokenTest()
{
Expand Down