diff --git a/src/Accounts/Accounts.Test/AzureRMProfileTests.cs b/src/Accounts/Accounts.Test/AzureRMProfileTests.cs index 97a4efca724c..468ac8331e2d 100644 --- a/src/Accounts/Accounts.Test/AzureRMProfileTests.cs +++ b/src/Accounts/Accounts.Test/AzureRMProfileTests.cs @@ -662,6 +662,79 @@ public void GetAzureRmSubscriptionByNameMultiplePages() Assert.Equal(tenants[1], resultSubscription.TenantId); } + [Fact] + [Trait(Category.AcceptanceType, Category.CheckIn)] + public void GetAzureRmSubscriptionManagedService() + { + var tenants = new List { Guid.NewGuid().ToString(), DefaultTenant.ToString() }; + var firstTenantSubscriptions = new List { Guid.NewGuid().ToString(), + Guid.NewGuid().ToString(), + Guid.NewGuid().ToString(), + Guid.NewGuid().ToString() }; + var secondTenantSubscriptions = new List { Guid.NewGuid().ToString(), + Guid.NewGuid().ToString(), + Guid.NewGuid().ToString(), + Guid.NewGuid().ToString() }; + + var firstList = new List { firstTenantSubscriptions[0], firstTenantSubscriptions[1] }; + var secondList = new List { firstTenantSubscriptions[2], firstTenantSubscriptions[3] }; + + var thirdList = new List { secondTenantSubscriptions[0], secondTenantSubscriptions[1] }; + var fourthList = new List { secondTenantSubscriptions[2], secondTenantSubscriptions[3] }; + + var client = SetupTestEnvironment(tenants, firstList, secondList, thirdList, fourthList); + + // TEST WITH USER TYPE + var dataStore = new MemoryDataStore(); + AzureSession.Instance.DataStore = dataStore; + var commandRuntimeMock = new MockCommandRuntime(); + AzureSession.Instance.AuthenticationFactory = new MockTokenAuthenticationFactory(); + var profile = new AzureRmProfile(); + profile.EnvironmentTable.Add("foo", new AzureEnvironment(AzureEnvironment.PublicEnvironments.Values.FirstOrDefault())); + profile.DefaultContext = Context; + profile.DefaultContext.Account = new AzureAccount(); + profile.DefaultContext.Tenant.Id = DefaultTenant.ToString(); + + profile.DefaultContext.Account.Type = "User"; + var cmdlt = new GetAzureRMSubscriptionCommand(); + // Setup + cmdlt.DefaultProfile = profile; + cmdlt.CommandRuntime = commandRuntimeMock; + Assert.Null(cmdlt.TenantId); + // Act + cmdlt.InvokeBeginProcessing(); + cmdlt.ExecuteCmdlet(); + cmdlt.InvokeEndProcessing(); + Assert.Null(cmdlt.TenantId); + Assert.True(commandRuntimeMock.OutputPipeline.Count == 8); + + // TEST WITH MANAGEDSERVICE + client = SetupTestEnvironment(tenants, firstList, secondList, thirdList, fourthList); + + dataStore = new MemoryDataStore(); + AzureSession.Instance.DataStore = dataStore; + commandRuntimeMock = new MockCommandRuntime(); + AzureSession.Instance.AuthenticationFactory = new MockTokenAuthenticationFactory(); + profile = new AzureRmProfile(); + profile.EnvironmentTable.Add("foo", new AzureEnvironment(AzureEnvironment.PublicEnvironments.Values.FirstOrDefault())); + profile.DefaultContext = Context; + profile.DefaultContext.Account = new AzureAccount(); + profile.DefaultContext.Tenant.Id = DefaultTenant.ToString(); + + profile.DefaultContext.Account.Type = "ManagedService"; + cmdlt = new GetAzureRMSubscriptionCommand(); + // Setup + cmdlt.DefaultProfile = profile; + cmdlt.CommandRuntime = commandRuntimeMock; + Assert.Null(cmdlt.TenantId); + // Act + cmdlt.InvokeBeginProcessing(); + cmdlt.ExecuteCmdlet(); + cmdlt.InvokeEndProcessing(); + Assert.NotNull(cmdlt.TenantId); + Assert.True(commandRuntimeMock.OutputPipeline.Count == 4); + } + #if NETSTANDARD [Fact(Skip = "ConcurrentDictionary is not marked as Serializable")] [Trait(Category.RunType, Category.DesktopOnly)] diff --git a/src/Accounts/Accounts/Subscription/GetAzureRMSubscription.cs b/src/Accounts/Accounts/Subscription/GetAzureRMSubscription.cs index 012f28f28cdd..855f7507dd80 100644 --- a/src/Accounts/Accounts/Subscription/GetAzureRMSubscription.cs +++ b/src/Accounts/Accounts/Subscription/GetAzureRMSubscription.cs @@ -63,13 +63,12 @@ protected override void BeginProcessing() public override void ExecuteCmdlet() { - var tenant = TenantId; if (!string.IsNullOrWhiteSpace(this.SubscriptionName)) { IAzureSubscription result; try { - if (!this._client.TryGetSubscriptionByName(tenant, this.SubscriptionName, out result)) + if (!this._client.TryGetSubscriptionByName(TenantId, this.SubscriptionName, out result)) { ThrowSubscriptionNotFoundError(this.TenantId, this.SubscriptionName); } @@ -78,7 +77,7 @@ public override void ExecuteCmdlet() } catch (AadAuthenticationException exception) { - ThrowTenantAuthenticationError(tenant, exception); + ThrowTenantAuthenticationError(TenantId, exception); throw; } @@ -88,7 +87,7 @@ public override void ExecuteCmdlet() IAzureSubscription result; try { - if (!this._client.TryGetSubscriptionById(tenant, this.SubscriptionId, out result)) + if (!this._client.TryGetSubscriptionById(TenantId, this.SubscriptionId, out result)) { ThrowSubscriptionNotFoundError(this.TenantId, this.SubscriptionId); } @@ -97,7 +96,7 @@ public override void ExecuteCmdlet() } catch (AadAuthenticationException exception) { - ThrowTenantAuthenticationError(tenant, exception); + ThrowTenantAuthenticationError(TenantId, exception); throw; } @@ -108,26 +107,26 @@ public override void ExecuteCmdlet() { if (DefaultContext.Account.Type.Equals("ManagedService")) { - if (tenant == null) + if (TenantId == null) { - tenant = DefaultContext.Tenant.Id; + TenantId = DefaultContext.Tenant.Id; } - if (tenant.Equals(DefaultContext.Tenant.Id)) + if (TenantId.Equals(DefaultContext.Tenant.Id)) { - var subscriptions = _client.ListSubscriptions(tenant); + var subscriptions = _client.ListSubscriptions(TenantId); WriteObject(subscriptions.Select((s) => new PSAzureSubscription(s)), enumerateCollection: true); } } else { - var subscriptions = _client.ListSubscriptions(tenant); + var subscriptions = _client.ListSubscriptions(TenantId); WriteObject(subscriptions.Select((s) => new PSAzureSubscription(s)), enumerateCollection: true); } } catch (AadAuthenticationException exception) { - ThrowTenantAuthenticationError(tenant, exception); + ThrowTenantAuthenticationError(TenantId, exception); throw; } }