Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Await when adding account + refresh account when needed #17664

Merged
merged 2 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions localization/xliff/enu/constants/localizedConstants.enu.xlf
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@
<trans-unit id="msgPromptFirewallRuleCreated">
<source xml:lang="en">Firewall rule successfully created.</source>
</trans-unit>
<trans-unit id="msgAuthTypeNotFound">
<source xml:lang="en">Failed to get authentication method, please remove and re-add the account.</source>
</trans-unit>
<trans-unit id="msgAccountNotFound">
<source xml:lang="en">Account not found</source>
</trans-unit>
Expand Down
5 changes: 4 additions & 1 deletion src/azure/adal/adalAzureController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import { AzureUserInteraction } from './azureUserInteraction';
import { StorageService } from './storageService';

export class AdalAzureController extends AzureController {

private _authMappings = new Map<AzureAuthType, AzureAuth>();
private cacheProvider: SimpleTokenCache;
private storageService: StorageService;
Expand Down Expand Up @@ -54,6 +53,10 @@ export class AdalAzureController extends AzureController {
return response ? response as IAccount : undefined;
}

public isAccountInCache(account: IAccount): Promise<boolean> {
throw new Error('Method not implemented.');
}

public async getAccountSecurityToken(account: IAccount, tenantId: string, settings: IAADResource): Promise<IToken | undefined> {
let token: IToken | undefined;
let azureAuth = await this.getAzureAuthInstance(getAzureActiveDirectoryConfig());
Expand Down
2 changes: 2 additions & 0 deletions src/azure/azureController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ export abstract class AzureController {
public abstract refreshAccessToken(account: IAccount, accountStore: AccountStore,
tenantId: string | undefined, settings: IAADResource): Promise<IToken | undefined>;

public abstract isAccountInCache(account: IAccount): Promise<boolean>;

public abstract removeAccount(account: IAccount): Promise<void>;

public abstract handleAuthMapping(): void;
Expand Down
47 changes: 26 additions & 21 deletions src/azure/msal/msalAzureAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,30 +166,35 @@ export abstract class MsalAzureAuth {
}

public async refreshAccessToken(account: IAccount, tenantId: string, settings: IAADResource): Promise<IAccount | undefined> {
try {
const tokenResult = await this.getToken(account, tenantId, settings);
if (!tokenResult) {
account.isStale = true;
return account;
}
if (account) {
try {
const tokenResult = await this.getToken(account, tenantId, settings);
if (!tokenResult) {
account.isStale = true;
return account;
}

const tokenClaims = this.getTokenClaims(tokenResult.accessToken);
if (!tokenClaims) {
account.isStale = true;
return account;
}
const tokenClaims = this.getTokenClaims(tokenResult.accessToken);
if (!tokenClaims) {
account.isStale = true;
return account;
}

const token: IToken = {
key: tokenResult.account!.homeAccountId,
token: tokenResult.accessToken,
tokenType: tokenResult.tokenType,
expiresOn: tokenResult.account!.idTokenClaims!.exp
};
const token: IToken = {
key: tokenResult.account!.homeAccountId,
token: tokenResult.accessToken,
tokenType: tokenResult.tokenType,
expiresOn: tokenResult.account!.idTokenClaims!.exp
};

return await this.hydrateAccount(token, tokenClaims);
} catch (ex) {
account.isStale = true;
throw ex;
return await this.hydrateAccount(token, tokenClaims);
} catch (ex) {
account.isStale = true;
throw ex;
}
} else {
this.logger.error(`refreshAccessToken: Account not received for refreshing access token.`);
throw Error(LocalizedConstants.msgAccountNotFound);
}
}

Expand Down
19 changes: 16 additions & 3 deletions src/azure/msal/msalAzureController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ export class MsalAzureController extends AzureController {
return response ? response as IAccount : undefined;
}

public async isAccountInCache(account: IAccount): Promise<boolean> {
let authType = getAzureActiveDirectoryConfig();
let azureAuth = await this.getAzureAuthInstance(authType!);
await this.clearOldCacheIfExists();
let accountInfo = await azureAuth.getAccountFromMsalCache(account.key.id);
return accountInfo !== undefined;
}

private async getAzureAuthInstance(authType: AzureAuthType): Promise<MsalAzureAuth | undefined> {
if (!this._authMappings.has(authType)) {
await this.handleAuthMapping();
Expand All @@ -113,9 +121,14 @@ export class MsalAzureController extends AzureController {
return token;
}
} else {
account.isStale = true;
this.logger.error(`_getAccountSecurityToken: Authentication method not found for account ${account.displayInfo.displayName}`);
throw Error('Failed to get authentication method, please remove and re-add the account');
if (account) {
account.isStale = true;
this.logger.error(`_getAccountSecurityToken: Authentication method not found for account ${account.displayInfo.displayName}`);
throw Error(LocalizedConstants.msgAuthTypeNotFound);
} else {
this.logger.error(`_getAccountSecurityToken: Authentication method not found as account not available.`);
throw Error(LocalizedConstants.msgAccountNotFound);
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/controllers/mainController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export default class MainController implements vscode.Disposable {
this.registerCommand(Constants.cmdAadRemoveAccount);
this._event.on(Constants.cmdAadRemoveAccount, () => this.removeAadAccount(this._prompter));
this.registerCommand(Constants.cmdAadAddAccount);
this._event.on(Constants.cmdAadAddAccount, () => this.addAddAccount());
this._event.on(Constants.cmdAadAddAccount, () => this.addAadAccount());

this.initializeObjectExplorer();

Expand Down Expand Up @@ -1241,7 +1241,7 @@ export default class MainController implements vscode.Disposable {
this.connectionManager.removeAccount(prompter);
}

public addAddAccount(): void {
public addAadAccount(): void {
this.connectionManager.addAccount();
}
}
38 changes: 21 additions & 17 deletions src/models/connectionProfile.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,24 +82,39 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
type: QuestionTypes.confirm,
name: LocalizedConstants.msgSavePassword,
message: LocalizedConstants.msgSavePassword,
shouldPrompt: (answers) => !profile.connectionString && ConnectionCredentials.isPasswordBasedCredential(profile),
shouldPrompt: () => !profile.connectionString && ConnectionCredentials.isPasswordBasedCredential(profile),
onAnswered: (value) => profile.savePassword = value
},
{
type: QuestionTypes.expand,
name: LocalizedConstants.aad,
message: LocalizedConstants.azureChooseAccount,
choices: azureAccountChoices,
shouldPrompt: (answers) => profile.isAzureActiveDirectory(),
onAnswered: (value) => {
shouldPrompt: () => profile.isAzureActiveDirectory(),
onAnswered: async (value) => {
accountAnswer = value;
if (value !== 'addAccount') {
let account: IAccount = value;
let account = value;
profile.accountId = account?.key.id;
tenantChoices.push(...account?.properties?.tenants.map(t => ({ name: t.displayName, value: t })));
tenantChoices.push(...account?.properties?.tenants!.map(t => ({ name: t.displayName, value: t })));
if (tenantChoices.length === 1) {
profile.tenantId = tenantChoices[0].value.id;
}
try {
profile = await azureController.refreshTokenWrapper(profile, accountStore, accountAnswer, providerSettings.resources.databaseResource);
} catch (error) {
console.log(`Refreshing tokens failed: ${error}`);
}
} else {
try {
profile = await azureController.populateAccountProperties(profile, accountStore, providerSettings.resources.databaseResource);
if (profile) {
vscode.window.showInformationMessage(utils.formatString(LocalizedConstants.accountAddedSuccessfully, profile.email));
}
} catch (e) {
console.error(`Could not add account: ${e}`);
vscode.window.showErrorMessage(e);
}
}
}
},
Expand All @@ -111,7 +126,7 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
default: defaultProfileValues ? defaultProfileValues.tenantId : undefined,
// Need not prompt for tenant question when 'Sql Authentication Provider' is enabled,
// since tenant information is received from Server with authority URI in the Login flow.
shouldPrompt: (answers) => profile.isAzureActiveDirectory() && tenantChoices.length > 1 && !getEnableSqlAuthenticationProviderConfig(),
shouldPrompt: () => profile.isAzureActiveDirectory() && tenantChoices.length > 1 && !getEnableSqlAuthenticationProviderConfig(),
onAnswered: (value: ITenant) => {
profile.tenantId = value.id;
}
Expand All @@ -130,17 +145,6 @@ export class ConnectionProfile extends ConnectionCredentials implements IConnect
});

return prompter.prompt(questions, true).then(async answers => {
if (answers?.authenticationType === 'AzureMFA') {
if (answers.AAD === 'addAccount') {
profile = await azureController.populateAccountProperties(profile, accountStore, providerSettings.resources.databaseResource);
} else {
try {
profile = await azureController.refreshTokenWrapper(profile, accountStore, accountAnswer, providerSettings.resources.databaseResource);
} catch (error) {
console.log(`Refreshing tokens failed: ${error}`);
}
}
}
if (answers && profile.isValidProfile()) {
return profile;
}
Expand Down
7 changes: 6 additions & 1 deletion src/objectExplorer/objectExplorerService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -459,11 +459,16 @@ export class ObjectExplorerService {
let azureController = this._connectionManager.azureController;
let account = this._connectionManager.accountStore.getAccount(connectionCredentials.accountId);
let profile = new ConnectionProfile(connectionCredentials);
let needsRefresh: boolean = false;
if (azureController.isSqlAuthProviderEnabled()) {
this._client.logger.verbose('SQL Authentication provider is enabled for Azure MFA connections, skipping token acquiry in extension.');
connectionCredentials.user = account.displayInfo.displayName;
connectionCredentials.email = account.displayInfo.email;
} else if (!connectionCredentials.azureAccountToken) {
if (!azureController.isAccountInCache(account)) {
needsRefresh = true;
}
}
if (!connectionCredentials.azureAccountToken && (!azureController.isSqlAuthProviderEnabled() || needsRefresh)) {
let azureAccountToken = await azureController.refreshAccessToken(
account, this._connectionManager.accountStore, connectionCredentials.tenantId, providerSettings.resources.databaseResource);
if (!azureAccountToken) {
Expand Down
4 changes: 2 additions & 2 deletions src/prompts/adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ export default class CodeAdapter implements IPrompter {
// }

if (!question.shouldPrompt || question.shouldPrompt(answers) === true) {
return prompt.render().then(result => {
return prompt.render().then(async result => {
answers[question.name] = result;

if (question.onAnswered) {
question.onAnswered(result);
await question.onAnswered(result);
}
return answers;
});
Expand Down
2 changes: 1 addition & 1 deletion src/prompts/question.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export interface IQuestion {
// Optional pre-prompt function. Takes in set of answers so far, and returns true if prompt should occur
shouldPrompt?: (answers: { [id: string]: any }) => boolean;
// Optional action to take on the question being answered
onAnswered?: (value: any) => void;
onAnswered?: (value: any) => void | Promise<void>;
// Optional set of options to support matching choices.
matchOptions?: vscode.QuickPickOptions;
}
Expand Down
37 changes: 18 additions & 19 deletions src/views/connectionUI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -615,27 +615,26 @@ export class ConnectionUI {
}

private async createFirewallRule(serverName: string, ipAddress: string): Promise<boolean> {
return this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptRetryFirewallRuleSignedIn,
LocalizedConstants.createFirewallRuleLabel).then(async (result) => {
if (result === LocalizedConstants.createFirewallRuleLabel) {
const firewallService = this.connectionManager.firewallService;
let ipRange = await this.promptForIpAddress(ipAddress);
if (ipRange) {
let firewallResult = await firewallService.createFirewallRule(serverName, ipRange.startIpAddress, ipRange.endIpAddress);
if (firewallResult.result) {
this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptFirewallRuleCreated);
return true;
} else {
Utils.showErrorMsg(firewallResult.errorMessage);
return false;
}
} else {
return false;
}
let result = await this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptRetryFirewallRuleSignedIn,
LocalizedConstants.createFirewallRuleLabel);
if (result === LocalizedConstants.createFirewallRuleLabel) {
const firewallService = this.connectionManager.firewallService;
let ipRange = await this.promptForIpAddress(ipAddress);
if (ipRange) {
let firewallResult = await firewallService.createFirewallRule(serverName, ipRange.startIpAddress, ipRange.endIpAddress);
if (firewallResult.result) {
this._vscodeWrapper.showInformationMessage(LocalizedConstants.msgPromptFirewallRuleCreated);
return true;
} else {
Utils.showErrorMsg(firewallResult.errorMessage);
return false;
}
});
} else {
return false;
}
} else {
return false;
}
}

private promptForRetryConnectWithDifferentCredentials(): PromiseLike<boolean> {
Expand Down Expand Up @@ -672,7 +671,7 @@ export class ConnectionUI {
}

public async addNewAccount(): Promise<IAccount> {
return this.connectionManager.azureController.addAccount(this._accountStore);
return await this.connectionManager.azureController.addAccount(this._accountStore);
}

// Prompts the user to pick a profile for removal, then removes from the global saved state
Expand Down