diff --git a/src/Common/Commands.Common.Test/Common/ProfileClientTests.cs b/src/Common/Commands.Common.Test/Common/ProfileClientTests.cs index 850e66b46cc2..222da7747089 100644 --- a/src/Common/Commands.Common.Test/Common/ProfileClientTests.cs +++ b/src/Common/Commands.Common.Test/Common/ProfileClientTests.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using Microsoft.Azure.Subscriptions.Models; using Microsoft.WindowsAzure.Commands.Common.Models; using Microsoft.WindowsAzure.Commands.Common.Test.Mocks; using Microsoft.WindowsAzure.Commands.Profile; @@ -46,6 +47,10 @@ public class ProfileClientTests private AzureSubscription azureSubscription3withoutUser; private AzureEnvironment azureEnvironment; private AzureAccount azureAccount; + private TenantIdDescription commonTenant; + private TenantIdDescription guestTenant; + private Subscriptions.Models.SubscriptionListOperationResponse.Subscription guestRdfeSubscription; + private Subscription guestCsmSubscription; public ProfileClientTests() { @@ -285,6 +290,80 @@ public void AddAzureAccountReturnsAccountWithAllSubscriptionsInCsmMode() Assert.True(account.GetSubscriptions(client.Profile).Any(s => s.Id == new Guid(csmSubscription1.SubscriptionId))); } + [Fact] + public void AddAzureAccountWithImpersonatedGuestWithNoSubscriptions() + { + SetMocks(new[] { rdfeSubscription1 }.ToList(), new List(), + new[] { commonTenant, guestTenant }.ToList(), + (userAccount, environment, tenant) => + { + var token = new MockAccessToken + { + UserId = tenant == commonTenant.TenantId ? userAccount.Id : "UserB", + AccessToken = "def", + LoginType = LoginType.OrgId + }; + userAccount.Id = token.UserId; + return token; + }); + MockDataStore dataStore = new MockDataStore(); + dataStore.VirtualStore[oldProfileDataPath] = oldProfileData; + ProfileClient.DataStore = dataStore; + ProfileClient client = new ProfileClient(); + + var account = client.AddAccountAndLoadSubscriptions(new AzureAccount { Id = "UserA", Type = AzureAccount.AccountType.User }, AzureEnvironment.PublicEnvironments[EnvironmentName.AzureCloud], null); + + Assert.Equal("UserA", account.Id); + Assert.Equal(1, account.GetSubscriptions(client.Profile).Count); + var subrdfe1 = account.GetSubscriptions(client.Profile).FirstOrDefault(s => s.Id == new Guid(rdfeSubscription1.SubscriptionId)); + var userA = client.GetAccount("UserA"); + var userB = client.GetAccount("UserB"); + Assert.NotNull(userA); + Assert.NotNull(userB); + Assert.Contains(rdfeSubscription1.SubscriptionId, userA.GetPropertyAsArray(AzureAccount.Property.Subscriptions), StringComparer.OrdinalIgnoreCase); + Assert.False(userB.HasSubscription(new Guid(rdfeSubscription1.SubscriptionId))); + Assert.NotNull(subrdfe1); + Assert.Equal("UserA", subrdfe1.Account); + } + + [Fact] + public void AddAzureAccountWithImpersonatedGuestWithSubscriptions() + { + SetMocks(new[] { rdfeSubscription1, guestRdfeSubscription }.ToList(), new List(), new[] { commonTenant, guestTenant }.ToList(), + (userAccount, environment, tenant) => + { + var token = new MockAccessToken + { + UserId = tenant == commonTenant.TenantId ? userAccount.Id : "UserB", + AccessToken = "def", + LoginType = LoginType.OrgId + }; + userAccount.Id = token.UserId; + return token; + }); + MockDataStore dataStore = new MockDataStore(); + dataStore.VirtualStore[oldProfileDataPath] = oldProfileData; + ProfileClient.DataStore = dataStore; + ProfileClient client = new ProfileClient(); + + var account = client.AddAccountAndLoadSubscriptions(new AzureAccount { Id = "UserA", Type = AzureAccount.AccountType.User }, AzureEnvironment.PublicEnvironments[EnvironmentName.AzureCloud], null); + + Assert.Equal("UserA", account.Id); + Assert.Equal(1, account.GetSubscriptions(client.Profile).Count); + var subrdfe1 = account.GetSubscriptions(client.Profile).FirstOrDefault(s => s.Id == new Guid(rdfeSubscription1.SubscriptionId)); + var userA = client.GetAccount("UserA"); + var userB = client.GetAccount("UserB"); + var subGuest = userB.GetSubscriptions(client.Profile).FirstOrDefault(s => s.Id == new Guid(guestRdfeSubscription.SubscriptionId)); + Assert.NotNull(userA); + Assert.NotNull(userB); + Assert.Contains(rdfeSubscription1.SubscriptionId, userA.GetPropertyAsArray(AzureAccount.Property.Subscriptions), StringComparer.OrdinalIgnoreCase); + Assert.Contains(guestRdfeSubscription.SubscriptionId, userB.GetPropertyAsArray(AzureAccount.Property.Subscriptions), StringComparer.OrdinalIgnoreCase); + Assert.NotNull(subrdfe1); + Assert.NotNull(subGuest); + Assert.Equal("UserA", subrdfe1.Account); + Assert.Equal("UserB", subGuest.Account); + } + [Fact] public void GetAzureAccountReturnsAccountWithSubscriptions() { @@ -1139,7 +1218,7 @@ public void SelectAzureSubscriptionByIdWorks() cmdlt.InvokeBeginProcessing(); cmdlt.ExecuteCmdlet(); cmdlt.InvokeEndProcessing(); - + Assert.Equal(tempSubscriptions[2].Id, AzureSession.CurrentContext.Subscription.Id); } @@ -1184,21 +1263,40 @@ public void ImportPublishSettingsAddsSecondCertificate() } private void SetMocks(List rdfeSubscriptions, - List csmSubscriptions) + List csmSubscriptions, + List tenants = null, + Func tokenProvider = null) { ClientMocks clientMocks = new ClientMocks(new Guid(defaultSubscription)); clientMocks.LoadRdfeSubscriptions(rdfeSubscriptions); clientMocks.LoadCsmSubscriptions(csmSubscriptions); + clientMocks.LoadTenants(tenants); AzureSession.ClientFactory = new MockClientFactory(new object[] { clientMocks.RdfeSubscriptionClientMock.Object, clientMocks.CsmSubscriptionClientMock.Object }); - AzureSession.AuthenticationFactory = new MockTokenAuthenticationFactory(); + var mockFactory = new MockTokenAuthenticationFactory(); + if (tokenProvider != null) + { + mockFactory.TokenProvider = tokenProvider; + } + + AzureSession.AuthenticationFactory = mockFactory; } private void SetMockData() { + commonTenant = new TenantIdDescription + { + Id = "Common", + TenantId = "Common" + }; + guestTenant = new TenantIdDescription + { + Id = "Guest", + TenantId = "Guest" + }; rdfeSubscription1 = new Subscriptions.Models.SubscriptionListOperationResponse.Subscription { SubscriptionId = "16E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1E", @@ -1213,6 +1311,13 @@ private void SetMockData() SubscriptionStatus = Subscriptions.Models.SubscriptionStatus.Active, ActiveDirectoryTenantId = "Common" }; + guestRdfeSubscription = new Subscriptions.Models.SubscriptionListOperationResponse.Subscription + { + SubscriptionId = "26E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1C", + SubscriptionName = "RdfeSub2", + SubscriptionStatus = Subscriptions.Models.SubscriptionStatus.Active, + ActiveDirectoryTenantId = "Guest" + }; csmSubscription1 = new Azure.Subscriptions.Models.Subscription { Id = "Subscriptions/36E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1E", @@ -1234,6 +1339,13 @@ private void SetMockData() State = "Active", SubscriptionId = "46E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1E" }; + guestCsmSubscription = new Azure.Subscriptions.Models.Subscription + { + Id = "Subscriptions/76E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1D", + DisplayName = "CsmGuestSub", + State = "Active", + SubscriptionId = "76E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1D" + }; azureSubscription1 = new AzureSubscription { Id = new Guid("56E3F6FD-A3AA-439A-8FC4-1F5C41D2AD1E"), diff --git a/src/Common/Commands.Common.Test/Mocks/ClientMocks.cs b/src/Common/Commands.Common.Test/Mocks/ClientMocks.cs index 9d47bf128a18..23adacee0364 100644 --- a/src/Common/Commands.Common.Test/Mocks/ClientMocks.cs +++ b/src/Common/Commands.Common.Test/Mocks/ClientMocks.cs @@ -58,7 +58,7 @@ private SubscriptionCloudCredentials CreateCredentials(Guid subscriptionId) } public static CloudException Make404Exception() - { + { return CloudException.Create( new HttpRequestMessage(), null, @@ -75,18 +75,7 @@ public void LoadCsmSubscriptions(List s Subscriptions = subscriptions })); - var tenantOperationsMock = new Mock(); - tenantOperationsMock.Setup(f => f.ListAsync(new CancellationToken())) - .Returns(Task.Factory.StartNew(() => new Azure.Subscriptions.Models.TenantListResult - { - TenantIds = new[] { new Azure.Subscriptions.Models.TenantIdDescription - { - Id = "1", TenantId = "1" - }}.ToList() - })); - CsmSubscriptionClientMock.Setup(f => f.Subscriptions).Returns(subscriptionOperationsMock.Object); - CsmSubscriptionClientMock.Setup(f => f.Tenants).Returns(tenantOperationsMock.Object); } public void LoadRdfeSubscriptions(List subscriptions) @@ -100,5 +89,26 @@ public void LoadRdfeSubscriptions(List f.Subscriptions).Returns(subscriptionOperationsMock.Object); } + + public void LoadTenants(List tenantIds = null) + { + + tenantIds = tenantIds ?? new[] + { + new Azure.Subscriptions.Models.TenantIdDescription + { + Id = "Common", + TenantId = "Common" + } + }.ToList(); + var tenantOperationsMock = new Mock(); + tenantOperationsMock.Setup(f => f.ListAsync(new CancellationToken())) + .Returns(Task.Factory.StartNew(() => new Azure.Subscriptions.Models.TenantListResult + { + TenantIds = tenantIds + })); + + CsmSubscriptionClientMock.Setup(f => f.Tenants).Returns(tenantOperationsMock.Object); + } } } diff --git a/src/Common/Commands.Common.Test/Mocks/MockTokenAuthenticationFactory.cs b/src/Common/Commands.Common.Test/Mocks/MockTokenAuthenticationFactory.cs index ee85746615f2..8269cdb7e422 100644 --- a/src/Common/Commands.Common.Test/Mocks/MockTokenAuthenticationFactory.cs +++ b/src/Common/Commands.Common.Test/Mocks/MockTokenAuthenticationFactory.cs @@ -12,6 +12,7 @@ // limitations under the License. // ---------------------------------------------------------------------------------- +using System; using System.Security; using System.Security.Cryptography.X509Certificates; using Microsoft.WindowsAzure.Commands.Common.Models; @@ -23,6 +24,8 @@ public class MockTokenAuthenticationFactory : IAuthenticationFactory { public IAccessToken Token { get; set; } + public Func TokenProvider { get; set; } + public MockTokenAuthenticationFactory() { Token = new MockAccessToken @@ -31,6 +34,13 @@ public MockTokenAuthenticationFactory() LoginType = LoginType.OrgId, AccessToken = "abc" }; + + TokenProvider = (account, environment, tenant) => Token = new MockAccessToken + { + UserId = account.Id, + LoginType = LoginType.OrgId, + AccessToken = Token.AccessToken + }; } public MockTokenAuthenticationFactory(string userId, string accessToken) @@ -50,14 +60,7 @@ public IAccessToken Authenticate(AzureAccount account, AzureEnvironment environm account.Id = "test"; } - Token = new MockAccessToken - { - UserId = account.Id, - LoginType = LoginType.OrgId, - AccessToken = Token.AccessToken - }; - - return Token; + return TokenProvider(account, environment, tenant); } public SubscriptionCloudCredentials GetSubscriptionCloudCredentials(AzureContext context) diff --git a/src/Common/Commands.Common/Common/ProfileClient.cs b/src/Common/Commands.Common/Common/ProfileClient.cs index 9e89e74ff103..14b612c9babb 100644 --- a/src/Common/Commands.Common/Common/ProfileClient.cs +++ b/src/Common/Commands.Common/Common/ProfileClient.cs @@ -80,7 +80,7 @@ private static void UpgradeProfile() } AzureProfile oldProfile = new AzureProfile(DataStore, oldProfilePath); - + if (DataStore.FileExists(newProfileFilePath)) { // Merge profile files @@ -111,7 +111,7 @@ private static void UpgradeProfile() // Save the profile to the disk oldProfile.Save(); - + // Rename WindowsAzureProfile.xml to WindowsAzureProfile.json DataStore.RenameFile(oldProfilePath, newProfileFilePath); @@ -173,17 +173,14 @@ public AzureAccount AddAccountAndLoadSubscriptions(AzureAccount account, AzureEn } var subscriptionsFromServer = ListSubscriptionsFromServer( - account, - environment, - password, + account, + environment, + password, password == null ? ShowDialog.Always : ShowDialog.Never).ToList(); - + // If account id is null the login failed if (account.Id != null) { - // Add the account to the profile - AddOrSetAccount(account); - // Update back Profile.Subscriptions foreach (var subscription in subscriptionsFromServer) { @@ -276,7 +273,7 @@ public AzureAccount GetAccountOrNull(string accountName) public AzureAccount GetAccount(string accountName) { var account = GetAccountOrNull(accountName); - + if (account == null) { throw new ArgumentException(string.Format("Account with name '{0}' does not exist.", accountName), "accountName"); @@ -288,7 +285,7 @@ public AzureAccount GetAccount(string accountName) public IEnumerable ListAccounts(string accountName) { List accounts = new List(); - + if (!string.IsNullOrEmpty(accountName)) { if (Profile.Accounts.ContainsKey(accountName)) @@ -684,11 +681,12 @@ private IEnumerable ListSubscriptionsFromServerForAllAccounts private IEnumerable ListSubscriptionsFromServer(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior) { + string[] tenants = null; try { if (!account.IsPropertySet(AzureAccount.Property.Tenants)) { - LoadAccountTenants(account, environment, password, promptBehavior); + tenants = LoadAccountTenants(account, environment, password, promptBehavior); } } catch (AadAuthenticationException aadEx) @@ -699,15 +697,14 @@ private IEnumerable ListSubscriptionsFromServer(AzureAccount try { + tenants = tenants ?? account.GetPropertyAsArray(AzureAccount.Property.Tenants); List mergedSubscriptions = MergeSubscriptions( - ListServiceManagementSubscriptions(account, environment, password, ShowDialog.Never).ToList(), - ListResourceManagerSubscriptions(account, environment, password, ShowDialog.Never).ToList()); + ListServiceManagementSubscriptions(account, environment, password, ShowDialog.Never, tenants).ToList(), + ListResourceManagerSubscriptions(account, environment, password, ShowDialog.Never, tenants).ToList()); // Set user ID foreach (var subscription in mergedSubscriptions) { - subscription.Environment = environment.Name; - subscription.Account = account.Id; account.SetOrAppendProperty(AzureAccount.Property.Subscriptions, subscription.Id.ToString()); } @@ -727,7 +724,7 @@ private IEnumerable ListSubscriptionsFromServer(AzureAccount } } - private void LoadAccountTenants(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior) + private string[] LoadAccountTenants(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior) { var commonTenantToken = AzureSession.AuthenticationFactory.Authenticate(account, environment, AuthenticationFactory.CommonAdTenant, password, promptBehavior); @@ -739,8 +736,7 @@ private void LoadAccountTenants(AzureAccount account, AzureEnvironment environme new TokenCloudCredentials(commonTenantToken.AccessToken), environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ResourceManager))) { - account.SetOrAppendProperty(AzureAccount.Property.Tenants, - subscriptionClient.Tenants.List().TenantIds.Select(ti => ti.TenantId).ToArray()); + return subscriptionClient.Tenants.List().TenantIds.Select(ti => ti.TenantId).ToArray(); } } else @@ -751,8 +747,7 @@ private void LoadAccountTenants(AzureAccount account, AzureEnvironment environme environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ServiceManagement))) { var subscriptionListResult = subscriptionClient.Subscriptions.List(); - account.SetOrAppendProperty(AzureAccount.Property.Tenants, - subscriptionListResult.Subscriptions.Select(s => s.ActiveDirectoryTenantId).Distinct().ToArray()); + return subscriptionListResult.Subscriptions.Select(s => s.ActiveDirectoryTenantId).Distinct().ToArray(); } } } @@ -907,7 +902,7 @@ private AzureAccount MergeAccountProperties(AzureAccount account1, AzureAccount return mergeAccount; } - private IEnumerable ListResourceManagerSubscriptions(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior) + private IEnumerable ListResourceManagerSubscriptions(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior, string[] tenants) { List result = new List(); @@ -916,11 +911,19 @@ private IEnumerable ListResourceManagerSubscriptions(AzureAcc return result; } - foreach (var tenant in account.GetPropertyAsArray(AzureAccount.Property.Tenants)) + foreach (var tenant in tenants) { try { - var tenantToken = AzureSession.AuthenticationFactory.Authenticate(account, environment, tenant, password, ShowDialog.Never); + var tenantAccount = new AzureAccount(); + CopyAccount(account, tenantAccount); + var tenantToken = AzureSession.AuthenticationFactory.Authenticate(tenantAccount, environment, tenant, password, ShowDialog.Never); + if (tenantAccount.Id == account.Id) + { + tenantAccount = account; + } + + tenantAccount.SetOrAppendProperty(AzureAccount.Property.Tenants, new string[] { tenant }); using (var subscriptionClient = AzureSession.ClientFactory.CreateCustomClient( new TokenCloudCredentials(tenantToken.AccessToken), @@ -937,10 +940,14 @@ private IEnumerable ListResourceManagerSubscriptions(AzureAcc }; psSubscription.SetProperty(AzureSubscription.Property.SupportedModes, AzureModule.AzureResourceManager.ToString()); psSubscription.SetProperty(AzureSubscription.Property.Tenants, tenant); - + psSubscription.Account = tenantAccount.Id; + tenantAccount.SetOrAppendProperty(AzureAccount.Property.Subscriptions, new string[] { psSubscription.Id.ToString() }); result.Add(psSubscription); } } + + AddOrSetAccount(tenantAccount); + } catch (CloudException cEx) { @@ -955,7 +962,13 @@ private IEnumerable ListResourceManagerSubscriptions(AzureAcc return result; } - private IEnumerable ListServiceManagementSubscriptions(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior) + private void CopyAccount(AzureAccount sourceAccount, AzureAccount targetAccount) + { + targetAccount.Id = sourceAccount.Id; + targetAccount.Type = sourceAccount.Type; + } + + private IEnumerable ListServiceManagementSubscriptions(AzureAccount account, AzureEnvironment environment, SecureString password, ShowDialog promptBehavior, string[] tenants) { List result = new List(); @@ -964,12 +977,19 @@ private IEnumerable ListServiceManagementSubscriptions(AzureA return result; } - foreach (var tenant in account.GetPropertyAsArray(AzureAccount.Property.Tenants)) + foreach (var tenant in tenants) { try { - var tenantToken = AzureSession.AuthenticationFactory.Authenticate(account, environment, tenant, password, ShowDialog.Never); + var tenantAccount = new AzureAccount(); + CopyAccount(account, tenantAccount); + var tenantToken = AzureSession.AuthenticationFactory.Authenticate(tenantAccount, environment, tenant, password, ShowDialog.Never); + if (tenantAccount.Id == account.Id) + { + tenantAccount = account; + } + tenantAccount.SetOrAppendProperty(AzureAccount.Property.Tenants, new string[] { tenant }); using (var subscriptionClient = AzureSession.ClientFactory.CreateCustomClient( new TokenCloudCredentials(tenantToken.AccessToken), environment.GetEndpointAsUri(AzureEnvironment.Endpoint.ServiceManagement))) @@ -977,18 +997,28 @@ private IEnumerable ListServiceManagementSubscriptions(AzureA var subscriptionListResult = subscriptionClient.Subscriptions.List(); foreach (var subscription in subscriptionListResult.Subscriptions) { - AzureSubscription psSubscription = new AzureSubscription + // only add the subscription if it's actually in this tenant + if (subscription.ActiveDirectoryTenantId == tenant) { - Id = new Guid(subscription.SubscriptionId), - Name = subscription.SubscriptionName, - Environment = environment.Name - }; - psSubscription.Properties[AzureSubscription.Property.SupportedModes] = AzureModule.AzureServiceManagement.ToString(); - psSubscription.SetProperty(AzureSubscription.Property.Tenants, subscription.ActiveDirectoryTenantId); - - result.Add(psSubscription); + AzureSubscription psSubscription = new AzureSubscription + { + Id = new Guid(subscription.SubscriptionId), + Name = subscription.SubscriptionName, + Environment = environment.Name + }; + psSubscription.Properties[AzureSubscription.Property.SupportedModes] = + AzureModule.AzureServiceManagement.ToString(); + psSubscription.SetProperty(AzureSubscription.Property.Tenants, + subscription.ActiveDirectoryTenantId); + psSubscription.Account = tenantAccount.Id; + tenantAccount.SetOrAppendProperty(AzureAccount.Property.Subscriptions, + new string[] { psSubscription.Id.ToString() }); + result.Add(psSubscription); + } } } + + AddOrSetAccount(tenantAccount); } catch (CloudException cEx) { @@ -1114,7 +1144,7 @@ public AzureEnvironment RemoveEnvironment(string name) { throw new ArgumentException(Resources.RemovingDefaultEnvironmentsNotSupported, "name"); } - + if (Profile.Environments.ContainsKey(name)) { var environment = Profile.Environments[name]; @@ -1158,8 +1188,8 @@ public AzureEnvironment AddOrSetEnvironment(AzureEnvironment environment) if (AzureSession.CurrentContext != null && AzureSession.CurrentContext.Environment != null && AzureSession.CurrentContext.Environment.Name == environment.Name) { - AzureSession.SetCurrentContext(AzureSession.CurrentContext.Subscription, - Profile.Environments[environment.Name], + AzureSession.SetCurrentContext(AzureSession.CurrentContext.Subscription, + Profile.Environments[environment.Name], AzureSession.CurrentContext.Account); }