diff --git a/.ci-config.json b/.ci-config.json index 98304dca0627..3a0ad5c881fd 100644 --- a/.ci-config.json +++ b/.ci-config.json @@ -323,7 +323,7 @@ "Compute", "Functions", "KeyVault", - "KubernetersConfiguration", + "KubernetesConfiguration", "Network", "PostgreSql", "Purview", diff --git a/tools/Modules/TestFx-Tasks.psm1 b/tools/Modules/TestFx-Tasks.psm1 index a6a8ffc5ce8a..8b15900bff90 100644 --- a/tools/Modules/TestFx-Tasks.psm1 +++ b/tools/Modules/TestFx-Tasks.psm1 @@ -17,7 +17,7 @@ $script:TestFxEnvExtraPropKeys = @( ) function Set-TestFxEnvironment { - [CmdletBinding(DefaultParameterSetName = "NewServicePrincipal")] + [CmdletBinding(DefaultParameterSetName = "UserAccount")] param( [Parameter(Mandatory)] [ValidateNotNullOrEmpty()] @@ -27,6 +27,10 @@ function Set-TestFxEnvironment { [ValidateNotNullOrEmpty()] [guid] $TenantId, + [Parameter(Mandatory, ParameterSetName = "UserAccount")] + [ValidateNotNullOrEmpty()] + [guid] $UserId, + [Parameter(Mandatory, ParameterSetName = "NewServicePrincipal")] [ValidateNotNullOrEmpty()] [string] $ServicePrincipalDisplayName, @@ -108,11 +112,23 @@ function Set-TestFxEnvironment { } } + $testFxEnvProps = [PSCustomObject]@{ + Environment = $TargetEnvironment + SubscriptionId = $SubscriptionId + TenantId = $TenantId + HttpRecorderMode = $RecorderMode + } + switch ($PSCmdlet.ParameterSetName) { + "UserAccount" { + $testFxEnvProps | Add-Member -NotePropertyName UserId -NotePropertyValue $UserId + } "NewServicePrincipal" { $sp = New-TestFxServicePrincipal -SubscriptionId $SubscriptionId -ServicePrincipalDisplayName $ServicePrincipalDisplayName -Force:$Force $spAppId = $sp.AppId $spSecret = $sp.PasswordCredentials.SecretText + $testFxEnvProps | Add-Member -NotePropertyName ServicePrincipal -NotePropertyValue $spAppId + $testFxEnvProps | Add-Member -NotePropertyName ServicePrincipalSecret -NotePropertyValue $spSecret } "ExistingServicePrincipal" { $sp = Get-AzADServicePrincipal -ApplicationId $ServicePrincipalId @@ -122,18 +138,11 @@ function Set-TestFxEnvironment { $spAppId = $ServicePrincipalId $spSecret = $ServicePrincipalSecret + $testFxEnvProps | Add-Member -NotePropertyName ServicePrincipal -NotePropertyValue $spAppId + $testFxEnvProps | Add-Member -NotePropertyName ServicePrincipalSecret -NotePropertyValue $spSecret } } - $testFxEnvProps = [PSCustomObject]@{ - Environment = $TargetEnvironment - SubscriptionId = $SubscriptionId - TenantId = $TenantId - ServicePrincipal = $spAppId - ServicePrincipalSecret = $spSecret - HttpRecorderMode = $RecorderMode - } - $script:testFxEnvExtraPropKeys | ForEach-Object { if ($PSBoundParameters.ContainsKey($_)) { $testFxEnvProps | Add-Member -NotePropertyName $_ -NotePropertyValue $PSBoundParameters[$_] diff --git a/tools/TestFx/ConnectionString.cs b/tools/TestFx/ConnectionString.cs index f0dd34050b4d..8b6e4f52e780 100644 --- a/tools/TestFx/ConnectionString.cs +++ b/tools/TestFx/ConnectionString.cs @@ -26,7 +26,6 @@ public class ConnectionString private Dictionary _keyValuePairs; private string _connString; private StringBuilder _parseErrorSb; - private string DEFAULT_TENANTID = "72f988bf-86f1-41af-91ab-2d7cd011db47"; public Dictionary KeyValuePairs { @@ -75,46 +74,13 @@ public ConnectionString(string connString) : this() { _connString = connString; Parse(_connString); //Keyvalue pairs are normalized and is called from Parse(string) function - NormalizeKeyValuePairs(); - } - - private void NormalizeKeyValuePairs() - { - string clientId, spn, password, spnSecret, userId, aadTenantId; - KeyValuePairs.TryGetValue(ConnectionStringKeys.AADClientIdKey, out clientId); - KeyValuePairs.TryGetValue(ConnectionStringKeys.ServicePrincipalKey, out spn); - - KeyValuePairs.TryGetValue(ConnectionStringKeys.UserIdKey, out userId); - KeyValuePairs.TryGetValue(ConnectionStringKeys.PasswordKey, out password); - KeyValuePairs.TryGetValue(ConnectionStringKeys.ServicePrincipalSecretKey, out spnSecret); - KeyValuePairs.TryGetValue(ConnectionStringKeys.TenantIdKey, out aadTenantId); - - //ClientId was provided and servicePrincipal was empty, we want ServicePrincipal to be initialized - //At some point we will deprecate ClientId keyName - if (!string.IsNullOrEmpty(clientId) && (string.IsNullOrEmpty(spn))) - { - KeyValuePairs[ConnectionStringKeys.ServicePrincipalKey] = clientId; - } - - //Set the value of PasswordKey to ServicePrincipalSecret ONLY if userId is empty - //If UserId is not empty, we are not sure if it's a password for inter active login or ServicePrincipal SecretKey - if (!string.IsNullOrEmpty(password) && (string.IsNullOrEmpty(spnSecret)) && (string.IsNullOrEmpty(userId))) - { - KeyValuePairs[ConnectionStringKeys.ServicePrincipalSecretKey] = password; - } - - //Initialize default value for AADTenent - if (string.IsNullOrEmpty(aadTenantId)) - { - KeyValuePairs[ConnectionStringKeys.TenantIdKey] = DEFAULT_TENANTID; - } } public void Parse(string connString) { string parseRegEx = @"(?[^=]+)=(?.+)"; - if (_parseErrorSb != null) _parseErrorSb.Clear(); + _parseErrorSb?.Clear(); if (string.IsNullOrEmpty(connString)) { @@ -161,10 +127,6 @@ public void Parse(string connString) ParseErrors = string.Format("Incorrect '{0}' keyValue pair format", pair); } } - - //Adjust key-value pairs and normalize values across multiple keys - //We need to do this here because Connection string can be parsed multiple time within same instance - NormalizeKeyValuePairs(); } } diff --git a/tools/TestFx/DelegatingHandlers/HttpMockServer.cs b/tools/TestFx/DelegatingHandlers/HttpMockServer.cs index b530de01c4c4..e4584fc5d8a8 100644 --- a/tools/TestFx/DelegatingHandlers/HttpMockServer.cs +++ b/tools/TestFx/DelegatingHandlers/HttpMockServer.cs @@ -35,7 +35,7 @@ public class HttpMockServer : DelegatingHandler public static string CallerIdentity { get; set; } public static string TestIdentity { get; set; } - public static HttpRecorderMode Mode { get; set; } + public static HttpRecorderMode Mode { get; internal set; } public static IRecordMatcher Matcher { get; set; } public static string RecordsDirectory { get; set; } public static Dictionary Variables { get; private set; } diff --git a/tools/TestFx/EnvironmentSetupHelper.cs b/tools/TestFx/EnvironmentSetupHelper.cs index 448aa2b7e66d..f4f9497f5713 100644 --- a/tools/TestFx/EnvironmentSetupHelper.cs +++ b/tools/TestFx/EnvironmentSetupHelper.cs @@ -26,7 +26,6 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; -using System.Diagnostics; using System.IO; using System.Linq; using System.Management.Automation; @@ -40,9 +39,9 @@ namespace Microsoft.Azure.Commands.TestFx { public class EnvironmentSetupHelper { - private const string TestEnvironmentName = "__test-environment"; + private const string TestFxEnvironmentName = "__testfx-environment"; - private const string TestSubscriptionName = "__test-subscriptions"; + private const string TestFxSubscriptionName = "__testfx-subscription"; private static string PackageDirectoryFromCommon { get; } = GetConfigDirectory(); @@ -134,9 +133,6 @@ public EnvironmentSetupHelper() // Ignore SSL errors System.Net.ServicePointManager.ServerCertificateValidationCallback += (se, cert, chain, sslerror) => true; - // Set RunningMocked - TestMockSupport.RunningMocked = HttpMockServer.GetCurrentMode() == HttpRecorderMode.Playback; - if (File.Exists(Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), Resources.AzureDirectoryName, "testcredentials.json"))) { SetEnvironmentVariableFromCredentialFile(); @@ -388,62 +384,67 @@ public void SetupAzureEnvironmentFromEnvironmentVariables(AzureModule mode) throw new NotSupportedException("RDFE environment is not supported in .Net Core"); } - if (currentEnvironment.UserName == null) - { - currentEnvironment.UserName = "fakeuser@microsoft.com"; - } - SetAuthenticationFactory(currentEnvironment); - AzureEnvironment environment = new AzureEnvironment { Name = TestEnvironmentName }; - Debug.Assert(currentEnvironment != null); - environment.ActiveDirectoryAuthority = currentEnvironment.Endpoints.AADAuthUri.AbsoluteUri; - environment.GalleryUrl = currentEnvironment.Endpoints.GalleryUri?.AbsoluteUri; - environment.ServiceManagementUrl = currentEnvironment.BaseUri.AbsoluteUri; - environment.ResourceManagerUrl = currentEnvironment.Endpoints.ResourceManagementUri.AbsoluteUri; - environment.GraphUrl = currentEnvironment.Endpoints.GraphUri.AbsoluteUri; - environment.AzureDataLakeAnalyticsCatalogAndJobEndpointSuffix = currentEnvironment.Endpoints.DataLakeAnalyticsJobAndCatalogServiceUri.OriginalString.Replace("https://", ""); // because it is just a sufix - environment.AzureDataLakeStoreFileSystemEndpointSuffix = currentEnvironment.Endpoints.DataLakeStoreServiceUri.OriginalString.Replace("https://", ""); // because it is just a sufix - environment.StorageEndpointSuffix = AzureEnvironmentConstants.AzureStorageEndpointSuffix; - environment.AzureKeyVaultDnsSuffix = AzureEnvironmentConstants.AzureKeyVaultDnsSuffix; - environment.AzureKeyVaultServiceEndpointResourceId = AzureEnvironmentConstants.AzureKeyVaultServiceEndpointResourceId; - environment.ExtendedProperties.SetProperty(AzureEnvironment.ExtendedEndpoint.MicrosoftGraphUrl, currentEnvironment.Endpoints.GraphUri.AbsoluteUri); - environment.ExtendedProperties.SetProperty(AzureEnvironment.ExtendedEndpoint.OperationalInsightsEndpoint, "https://api.loganalytics.io/v1"); - environment.ExtendedProperties.SetProperty(AzureEnvironment.ExtendedEndpoint.OperationalInsightsEndpointResourceId, "https://api.loganalytics.io"); - if (!AzureRmProfileProvider.Instance.GetProfile().EnvironmentTable.ContainsKey(TestEnvironmentName)) - { - AzureRmProfileProvider.Instance.GetProfile().EnvironmentTable[TestEnvironmentName] = environment; - } - - if (currentEnvironment.SubscriptionId != null) - { - var testSubscription = new AzureSubscription - { - Id = currentEnvironment.SubscriptionId, - Name = TestSubscriptionName, - }; - testSubscription.SetEnvironment(TestEnvironmentName); - testSubscription.SetAccount(currentEnvironment.UserName); + AzureEnvironment testEnvironment = new AzureEnvironment + { + Name = TestFxEnvironmentName, + ActiveDirectoryAuthority = currentEnvironment.Endpoints.AADAuthUri.AbsoluteUri, + ActiveDirectoryServiceEndpointResourceId = currentEnvironment.Endpoints.AADTokenAudienceUri.AbsoluteUri, + GraphUrl = currentEnvironment.Endpoints.GraphUri.AbsoluteUri, + GraphEndpointResourceId = currentEnvironment.Endpoints.GraphTokenAudienceUri.AbsoluteUri, + ResourceManagerUrl = currentEnvironment.Endpoints.ResourceManagementUri.AbsoluteUri, + ServiceManagementUrl = currentEnvironment.Endpoints.ServiceManagementUri.AbsoluteUri, + GalleryUrl = currentEnvironment.Endpoints.GalleryUri?.AbsoluteUri, + AzureDataLakeAnalyticsCatalogAndJobEndpointSuffix = currentEnvironment.Endpoints.DataLakeAnalyticsJobAndCatalogServiceUri.OriginalString.Replace("https://", ""), // because it is just a sufix + AzureDataLakeStoreFileSystemEndpointSuffix = currentEnvironment.Endpoints.DataLakeStoreServiceUri.OriginalString.Replace("https://", ""), // because it is just a sufix + StorageEndpointSuffix = AzureEnvironmentConstants.AzureStorageEndpointSuffix, + AzureKeyVaultDnsSuffix = AzureEnvironmentConstants.AzureKeyVaultDnsSuffix, + AzureKeyVaultServiceEndpointResourceId = AzureEnvironmentConstants.AzureKeyVaultServiceEndpointResourceId + }; + testEnvironment.ExtendedProperties.SetProperty(AzureEnvironment.ExtendedEndpoint.MicrosoftGraphUrl, currentEnvironment.Endpoints.GraphUri.AbsoluteUri); + testEnvironment.ExtendedProperties.SetProperty(AzureEnvironment.ExtendedEndpoint.OperationalInsightsEndpoint, "https://api.loganalytics.io/v1"); + testEnvironment.ExtendedProperties.SetProperty(AzureEnvironment.ExtendedEndpoint.OperationalInsightsEndpointResourceId, "https://api.loganalytics.io"); + if (!AzureRmProfileProvider.Instance.GetProfile().EnvironmentTable.ContainsKey(TestFxEnvironmentName)) + { + AzureRmProfileProvider.Instance.GetProfile().EnvironmentTable[TestFxEnvironmentName] = testEnvironment; + } + + AzureSubscription testSubscription = new AzureSubscription(); + if (!string.IsNullOrEmpty(currentEnvironment.SubscriptionId)) + { + testSubscription.Id = currentEnvironment.SubscriptionId; + testSubscription.Name = TestFxSubscriptionName; + testSubscription.SetEnvironment(TestFxEnvironmentName); + testSubscription.SetTenant(currentEnvironment.TenantId); + testSubscription.SetAccount(currentEnvironment.UserId); testSubscription.SetDefault(); testSubscription.SetStorageAccount(Environment.GetEnvironmentVariable("AZURE_STORAGE_ACCOUNT")); + } - var testAccount = new AzureAccount() - { - Id = currentEnvironment.UserName, - Type = AzureAccount.AccountType.User, - }; + AzureTenant testTenant = new AzureTenant(); + if (!string.IsNullOrEmpty(currentEnvironment.TenantId)) + { + testTenant.Id = currentEnvironment.TenantId; + } - testAccount.SetSubscriptions(currentEnvironment.SubscriptionId); - var testTenant = new AzureTenant() { Id = Guid.NewGuid().ToString() }; - if (!string.IsNullOrEmpty(currentEnvironment.TenantId)) - { - if (Guid.TryParse(currentEnvironment.TenantId, out _)) - { - testTenant.Id = currentEnvironment.TenantId; - } - } - AzureRmProfileProvider.Instance.Profile.DefaultContext = new AzureContext(testSubscription, testAccount, environment, testTenant); + AzureAccount testAccount = new AzureAccount(); + if (!string.IsNullOrEmpty(currentEnvironment.UserId)) + { + testAccount.Id = currentEnvironment.UserId; + testAccount.Type = AzureAccount.AccountType.User; } + else if (!string.IsNullOrEmpty(currentEnvironment.ServicePrincipalClientId) && !string.IsNullOrEmpty(currentEnvironment.ServicePrincipalSecret)) + { + testAccount.Id = currentEnvironment.ServicePrincipalClientId; + testAccount.Type = AzureAccount.AccountType.ServicePrincipal; + } + + testAccount.SetAccessToken(string.Empty); + testAccount.SetSubscriptions(currentEnvironment.SubscriptionId); + testAccount.SetTenants(currentEnvironment.TenantId); + + AzureRmProfileProvider.Instance.Profile.DefaultContext = new AzureContext(testSubscription, testAccount, testEnvironment, testTenant); } private void SetAuthenticationFactory(TestEnvironment environment) @@ -457,7 +458,7 @@ private void SetAuthenticationFactory(TestEnvironment environment) .GetAwaiter() .GetResult(); - AzureSession.Instance.AuthenticationFactory = new MockTokenAuthenticationFactory(environment.UserName, httpMessage.Headers.Authorization.Parameter); + AzureSession.Instance.AuthenticationFactory = new MockTokenAuthenticationFactory(environment.UserId, httpMessage.Headers.Authorization.Parameter); } } @@ -524,6 +525,7 @@ public virtual Collection RunPowerShellTest(params string[] scripts) Collection output = null; foreach (var script in scripts) { + Console.WriteLine($"Executing test: {script}"); TracingInterceptor?.Information(script); powershell.AddScript(script); } diff --git a/tools/TestFx/Mocks/MockClientFactory.cs b/tools/TestFx/Mocks/MockClientFactory.cs index 219711084f81..3d0d97700721 100644 --- a/tools/TestFx/Mocks/MockClientFactory.cs +++ b/tools/TestFx/Mocks/MockClientFactory.cs @@ -17,7 +17,6 @@ using Microsoft.Azure.Commands.Common.Authentication.Factories; using Microsoft.Azure.Commands.Common.Authentication.Models; using Microsoft.Rest.Azure; -using Microsoft.WindowsAzure.Commands.Utilities.Common; using System; using System.Collections.Generic; using System.Diagnostics; @@ -28,7 +27,6 @@ using System.Reflection; using System.Threading; using System.Threading.Tasks; -using Microsoft.Azure.Commands.TestFx.DelegatingHandlers; using Microsoft.Rest.ClientRuntime.Azure.TestFramework; using Microsoft.Azure.Test.HttpRecorder; using Microsoft.Azure.Commands.Common.MSGraph.Version1_0; @@ -104,7 +102,7 @@ public TClient CreateCustomArmClient(params object[] parameters) where client = realClientFactory.CreateCustomArmClient(newParameters); } - if (TestMockSupport.RunningMocked && HttpMockServer.GetCurrentMode() != HttpRecorderMode.Record) + if (HttpMockServer.Mode == HttpRecorderMode.Playback) { if (client is IAzureClient azureClient) { diff --git a/tools/TestFx/Mocks/MockContext.cs b/tools/TestFx/Mocks/MockContext.cs index e244b782e2e1..9d132f7ad3c1 100644 --- a/tools/TestFx/Mocks/MockContext.cs +++ b/tools/TestFx/Mocks/MockContext.cs @@ -16,6 +16,7 @@ using Microsoft.Azure.Commands.TestFx.DelegatingHandlers; using Microsoft.Azure.Commands.TestFx.Recorder; using Microsoft.Azure.Test.HttpRecorder; +using Microsoft.WindowsAzure.Commands.Utilities.Common; using System; using System.Collections.Generic; using System.Net.Http; @@ -66,6 +67,8 @@ public static MockContext Start( } HttpMockServer.Initialize(className, methodName); + TestMockSupport.RunningMocked = HttpMockServer.Mode == HttpRecorderMode.Playback; + return context; } diff --git a/tools/TestFx/TestEnvironment.cs b/tools/TestFx/TestEnvironment.cs index 16db63a76324..d9f1ec226a41 100644 --- a/tools/TestFx/TestEnvironment.cs +++ b/tools/TestFx/TestEnvironment.cs @@ -12,19 +12,21 @@ // limitations under the License. // ---------------------------------------------------------------------------------- +using Microsoft.Azure.Commands.Common.Authentication; using Microsoft.Azure.Test.HttpRecorder; using Microsoft.Identity.Client; +using Microsoft.Identity.Client.Extensions.Msal; using Microsoft.Rest; using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; +using System.IO; using System.Linq; using System.Net.Http; using System.Text; using System.Threading; -using System.Threading.Tasks; namespace Microsoft.Azure.Commands.TestFx { @@ -69,7 +71,7 @@ public class TestEnvironment /// /// UserName used by the Test Environment /// - public string UserName { get; set; } + public string UserId { get; private set; } /// /// Active TestEndpoint being used by the Test Environment @@ -101,12 +103,17 @@ public TestEnvironment(string connectionString) SubscriptionId = ConnectionString.GetValue(ConnectionStringKeys.SubscriptionIdKey); TenantId = ConnectionString.GetValue(ConnectionStringKeys.TenantIdKey); + UserId = ConnectionString.GetValue(ConnectionStringKeys.UserIdKey); ServicePrincipalClientId = ConnectionString.GetValue(ConnectionStringKeys.ServicePrincipalKey); ServicePrincipalSecret = ConnectionString.GetValue(ConnectionStringKeys.ServicePrincipalSecretKey); - UserName = ConnectionString.GetValue(ConnectionStringKeys.UserIdKey); OptimizeRecordedFile = ConnectionString.GetValue(ConnectionStringKeys.OptimizeRecordedFileKey); - if (string.IsNullOrEmpty(ConnectionString.GetValue(ConnectionStringKeys.BaseUriKey))) + if (string.IsNullOrWhiteSpace(ServicePrincipalClientId) && string.IsNullOrWhiteSpace(UserId)) + { + UserId = "fakeuser"; + } + + if (string.IsNullOrWhiteSpace(ConnectionString.GetValue(ConnectionStringKeys.BaseUriKey))) { BaseUri = Endpoints.ResourceManagementUri; } @@ -124,8 +131,8 @@ public TestEnvironment(string connectionString) GraphUri = new Uri(ConnectionString.GetValue(ConnectionStringKeys.GraphUriKey)); } - InitTokenDictionary(); SetupHttpRecorderMode(); + SetupTokenDictionary(); } private void InitTestEndPoints() @@ -153,35 +160,12 @@ private void LoadDefaultEnvironmentEndpoints() }; } - private void InitTokenDictionary() - { - TokenInfo = new Dictionary(); - - ConnectionString.KeyValuePairs.TryGetValue(ConnectionStringKeys.RawTokenKey, out string rawToken); - ConnectionString.KeyValuePairs.TryGetValue(ConnectionStringKeys.RawGraphTokenKey, out string rawGraphToken); - - // We need TokenInfo to be non-empty as there are cases where have taken dependency on non-empty TokenInfo in MockContext - if (string.IsNullOrEmpty(rawToken)) - { - rawToken = ConnectionStringKeys.RawTokenKey; - } - - if (string.IsNullOrEmpty(rawGraphToken)) - { - rawGraphToken = ConnectionStringKeys.RawGraphTokenKey; - } - - TokenInfo[TokenAudience.Management] = new TokenCredentials(rawToken); - TokenInfo[TokenAudience.Graph] = new TokenCredentials(rawGraphToken); - } - private void SetupHttpRecorderMode() { - string testMode = Environment.GetEnvironmentVariable(ConnectionStringKeys.AZURE_TEST_MODE_ENVKEY); - + string testMode = ConnectionString.GetValue(ConnectionStringKeys.HttpRecorderModeKey); if (string.IsNullOrEmpty(testMode)) { - testMode = ConnectionString.GetValue(ConnectionStringKeys.HttpRecorderModeKey); + testMode = Environment.GetEnvironmentVariable(ConnectionStringKeys.AZURE_TEST_MODE_ENVKEY); } // Ideally we should be throwing when incompatible environment (e.g. Environment=Foo) is provided in connection string @@ -198,48 +182,107 @@ private void SetupHttpRecorderMode() } } - private void Login() + private void SetupTokenDictionary() { - UpdateTokenInfo(TokenAudience.Management, new[] { "https://management.azure.com/.default" }); + TokenInfo = new Dictionary(); + UpdateTokenInfo(TokenAudience.Management, new[] { "https://management.core.windows.net/.default" }); UpdateTokenInfo(TokenAudience.Graph, new[] { "https://graph.microsoft.com/.default" }); + if (HttpMockServer.Mode == HttpRecorderMode.Record) + { + VerifyAuthTokens(); + } } - private void UpdateTokenInfo(TokenAudience tokenAudience, IEnumerable scopes) + private void UpdateTokenInfo(TokenAudience audience, IEnumerable scopes, string cloudInstanceUri = null) { - var accessToken = GetServicePrincipalAccessToken(scopes); - TokenInfo[tokenAudience] = new TokenCredentials(accessToken); + string tokenKey = string.Empty; + switch (audience) + { + case TokenAudience.Management: + tokenKey = ConnectionStringKeys.RawTokenKey; + break; + case TokenAudience.Graph: + tokenKey = ConnectionStringKeys.RawGraphTokenKey; + break; + } + + if (!ConnectionString.KeyValuePairs.TryGetValue(tokenKey, out var token) || string.IsNullOrWhiteSpace(token)) + { + token = GetAccessToken(scopes, cloudInstanceUri) ?? tokenKey; + } + + TokenInfo[audience] = new TokenCredentials(token); } - public string GetServicePrincipalAccessToken(IEnumerable scopes, string cloudInstanceUri = null) + public string GetAccessToken(IEnumerable scopes, string cloudInstanceUri = null) + { + string accessToken = null; + if (HttpMockServer.Mode == HttpRecorderMode.Record) + { + if (ConnectionString.KeyValuePairs.ContainsKey(ConnectionStringKeys.UserIdKey)) + { + accessToken = GetUserAccessToken(scopes, cloudInstanceUri); + } + else if (ConnectionString.KeyValuePairs.ContainsKey(ConnectionStringKeys.ServicePrincipalKey)) + { + accessToken = GetServicePrincipalAccessToken(scopes, cloudInstanceUri); + } + } + + return accessToken; + } + + private string GetUserAccessToken(IEnumerable scopes, string cloudInstanceUri = null) { if (string.IsNullOrWhiteSpace(cloudInstanceUri)) { cloudInstanceUri = MicrosoftLoginUrl; } - var spn = ConfidentialClientApplicationBuilder - .Create(ServicePrincipalClientId) - .WithClientSecret(ServicePrincipalSecret) + var pubApp = PublicClientApplicationBuilder.Create(Constants.PowerShellClientId) .WithAuthority(cloudInstanceUri, TenantId) + .WithDefaultRedirectUri() .Build(); - var authResult = Task.Run(async () => await spn.AcquireTokenForClient(scopes).ExecuteAsync().ConfigureAwait(false)); - return authResult.Result.AccessToken; + RegisterTokenCache(pubApp.UserTokenCache); + + AuthenticationResult authResult; + try + { + var userAccounts = pubApp.GetAccountsAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + var userAccount = userAccounts.FirstOrDefault(); + authResult = pubApp.AcquireTokenSilent(scopes, userAccount).ExecuteAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + } + catch (MsalUiRequiredException) + { + authResult = pubApp.AcquireTokenInteractive(scopes).ExecuteAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + } + return authResult.AccessToken; } - public string GetServicePrincipalAccessToken(IEnumerable scopes, Uri authorityUri) + private string GetServicePrincipalAccessToken(IEnumerable scopes, string cloudInstanceUri = null) { - if (authorityUri == null) + if (string.IsNullOrWhiteSpace(cloudInstanceUri)) { - throw new ArgumentNullException(nameof(authorityUri)); + cloudInstanceUri = MicrosoftLoginUrl; } - var spn = ConfidentialClientApplicationBuilder + var confApp = ConfidentialClientApplicationBuilder .Create(ServicePrincipalClientId) .WithClientSecret(ServicePrincipalSecret) - .WithAuthority(authorityUri) + .WithAuthority(cloudInstanceUri, TenantId) .Build(); - var authResult = Task.Run(async () => await spn.AcquireTokenForClient(scopes).ExecuteAsync().ConfigureAwait(false)); - return authResult.Result.AccessToken; + RegisterTokenCache(confApp.AppTokenCache); + + var authResult = confApp.AcquireTokenForClient(scopes).ExecuteAsync().ConfigureAwait(false).GetAwaiter().GetResult(); + return authResult.AccessToken; + } + + private void RegisterTokenCache(ITokenCache tokenCache) + { + var msalCacheFile = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), ".IdentityService", "msal.cache.plaintext"); + var options = new StorageCreationPropertiesBuilder(Path.GetFileName(msalCacheFile), Path.GetDirectoryName(msalCacheFile)).WithUnprotectedFile().Build(); + var helper = MsalCacheHelper.CreateAsync(options).ConfigureAwait(false).GetAwaiter().GetResult(); + helper.RegisterCache(tokenCache); } private void VerifyAuthTokens() @@ -374,15 +417,6 @@ internal void SetEnvironmentVariables() { HttpMockServer.Variables.Add(ConnectionStringKeys.SubscriptionIdKey, SubscriptionId); } - - // If User has provided Access Token in RawToken/GraphToken Key-Value, we don't need to authenticate - // We currently only check for RawToken and do not check if GraphToken is provided - if (string.IsNullOrEmpty(ConnectionString.GetValue(ConnectionStringKeys.RawTokenKey))) - { - Login(); - } - - VerifyAuthTokens(); } else if (HttpMockServer.Mode == HttpRecorderMode.Playback) {