Skip to content

Handle UI Freezes when IMGR is called from WPF #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 133 additions & 88 deletions src/Authentication/Authentication/Cmdlets/ConnectMgGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
// ------------------------------------------------------------------------------
namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets
{
using Microsoft.Graph.Auth;
using Microsoft.Graph.PowerShell.Authentication.Helpers;
using Microsoft.Graph.PowerShell.Authentication.Models;
using Microsoft.Identity.Client;
using System;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -16,10 +12,20 @@ namespace Microsoft.Graph.PowerShell.Authentication.Cmdlets
using System.Threading.Tasks;
using System.Net;
using System.Globalization;
using Microsoft.Graph.PowerShell.Authentication.Interfaces;
using Microsoft.Graph.PowerShell.Authentication.Common;
using System.Collections;
using System.Security.Cryptography.X509Certificates;

using Identity.Client;

using Microsoft.Graph.Auth;
using Microsoft.Graph.PowerShell.Authentication.Helpers;
using Microsoft.Graph.PowerShell.Authentication.Models;

using Interfaces;
using Common;

using static Helpers.AsyncHelpers;

[Cmdlet(VerbsCommunications.Connect, "MgGraph", DefaultParameterSetName = Constants.UserParameterSet)]
[Alias("Connect-Graph")]
public class ConnectMgGraph : PSCmdlet, IModuleAssemblyInitializer, IModuleAssemblyCleanup
Expand Down Expand Up @@ -70,10 +76,12 @@ public class ConnectMgGraph : PSCmdlet, IModuleAssemblyInitializer, IModuleAssem
[Alias("EnvironmentName", "NationalCloud")]
public string Environment { get; set; }

[Parameter(ParameterSetName = Constants.AppParameterSet, Mandatory = false, HelpMessage = "An x509 Certificate supplied during invocation")]
[Parameter(Mandatory = false,
ParameterSetName = Constants.AppParameterSet,
HelpMessage = "An x509 Certificate supplied during invocation")]
public X509Certificate2 Certificate { get; set; }

private CancellationTokenSource cancellationTokenSource;
private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();

private IGraphEnvironment environment;

Expand Down Expand Up @@ -104,66 +112,102 @@ protected override void EndProcessing()
protected override void ProcessRecord()
{
base.ProcessRecord();
IAuthContext authContext = new AuthContext { TenantId = TenantId };
cancellationTokenSource = new CancellationTokenSource();
// Set selected environment to the session object.
GraphSession.Instance.Environment = environment;

switch (ParameterSetName)
try
{
case Constants.UserParameterSet:
{
// 2 mins timeout. 1 min < HTTP timeout.
TimeSpan authTimeout = new TimeSpan(0, 0, Constants.MaxDeviceCodeTimeOut);
cancellationTokenSource = new CancellationTokenSource(authTimeout);
authContext.AuthType = AuthenticationType.Delegated;
string[] processedScopes = ProcessScopes(Scopes);
authContext.Scopes = processedScopes.Length == 0 ? new string[] { "User.Read" } : processedScopes;
// Default to CurrentUser but allow the customer to change this via `ContextScope` param.
authContext.ContextScope = this.IsParameterBound(nameof(ContextScope)) ? ContextScope : ContextScope.CurrentUser;
}
break;
case Constants.AppParameterSet:
using (var asyncCommandRuntime = new CustomAsyncCommandRuntime(this, _cancellationTokenSource.Token))
{
asyncCommandRuntime.Wait(ProcessRecordAsync(), _cancellationTokenSource.Token);
}
}
catch (AggregateException aggregateException)
{
// unroll the inner exceptions to get the root cause
foreach (var innerException in aggregateException.Flatten().InnerExceptions)
{
var errorRecords = innerException.Data;
if (errorRecords.Count < 1)
{
authContext.AuthType = AuthenticationType.AppOnly;
authContext.ClientId = ClientId;
authContext.CertificateThumbprint = CertificateThumbprint;
authContext.CertificateName = CertificateName;
authContext.Certificate = Certificate;
// Default to Process but allow the customer to change this via `ContextScope` param.
authContext.ContextScope = this.IsParameterBound(nameof(ContextScope)) ? ContextScope : ContextScope.Process;
foreach (DictionaryEntry dictionaryEntry in errorRecords)
{
WriteError((ErrorRecord)dictionaryEntry.Value);
}
}
break;
case Constants.AccessTokenParameterSet:
else
{
authContext.AuthType = AuthenticationType.UserProvidedAccessToken;
authContext.ContextScope = ContextScope.Process;
// Store user provided access token to a session object.
GraphSession.Instance.UserProvidedToken = new NetworkCredential(string.Empty, AccessToken).SecurePassword;
WriteError(new ErrorRecord(innerException, string.Empty, ErrorCategory.NotSpecified, null));
}
break;
}
}
catch (Exception exception) when (exception as PipelineStoppedException == null ||
(exception as PipelineStoppedException).InnerException != null)
{
// Write exception out to error channel.
WriteError(new ErrorRecord(exception, string.Empty, ErrorCategory.NotSpecified, null));
}
}

CancellationToken cancellationToken = cancellationTokenSource.Token;

try
private async Task ProcessRecordAsync()
{
using (NoSynchronizationContext)
{
// Gets a static instance of IAuthenticationProvider when the client app hasn't changed.
IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(authContext);
IClientApplicationBase clientApplication = null;
if (ParameterSetName == Constants.UserParameterSet)
IAuthContext authContext = new AuthContext { TenantId = TenantId };
// Set selected environment to the session object.
GraphSession.Instance.Environment = environment;

switch (ParameterSetName)
{
clientApplication = (authProvider as DeviceCodeProvider).ClientApplication;
case Constants.UserParameterSet:
{
// 2 mins timeout. 1 min < HTTP timeout.
TimeSpan authTimeout = new TimeSpan(0, 0, Constants.MaxDeviceCodeTimeOut);
// To avoid re-initializing the tokenSource, use CancelAfter
_cancellationTokenSource.CancelAfter(authTimeout);
authContext.AuthType = AuthenticationType.Delegated;
string[] processedScopes = ProcessScopes(Scopes);
authContext.Scopes = processedScopes.Length == 0 ? new string[] { "User.Read" } : processedScopes;
// Default to CurrentUser but allow the customer to change this via `ContextScope` param.
authContext.ContextScope = this.IsParameterBound(nameof(ContextScope)) ? ContextScope : ContextScope.CurrentUser;
}
break;
case Constants.AppParameterSet:
{
authContext.AuthType = AuthenticationType.AppOnly;
authContext.ClientId = ClientId;
authContext.CertificateThumbprint = CertificateThumbprint;
authContext.CertificateName = CertificateName;
authContext.Certificate = Certificate;
// Default to Process but allow the customer to change this via `ContextScope` param.
authContext.ContextScope = this.IsParameterBound(nameof(ContextScope)) ? ContextScope : ContextScope.Process;
}
break;
case Constants.AccessTokenParameterSet:
{
authContext.AuthType = AuthenticationType.UserProvidedAccessToken;
authContext.ContextScope = ContextScope.Process;
// Store user provided access token to a session object.
GraphSession.Instance.UserProvidedToken = new NetworkCredential(string.Empty, AccessToken).SecurePassword;
}
break;
}
else if (ParameterSetName == Constants.AppParameterSet)

try
{
clientApplication = (authProvider as ClientCredentialProvider).ClientApplication;
}
// Gets a static instance of IAuthenticationProvider when the client app hasn't changed.
IAuthenticationProvider authProvider = AuthenticationHelpers.GetAuthProvider(authContext);
IClientApplicationBase clientApplication = null;
if (ParameterSetName == Constants.UserParameterSet)
{
clientApplication = (authProvider as DeviceCodeProvider).ClientApplication;
}
else if (ParameterSetName == Constants.AppParameterSet)
{
clientApplication = (authProvider as ClientCredentialProvider).ClientApplication;
}

// Incremental scope consent without re-instantiating the auth provider. We will use a static instance.
GraphRequestContext graphRequestContext = new GraphRequestContext();
graphRequestContext.CancellationToken = cancellationToken;
graphRequestContext.MiddlewareOptions = new Dictionary<string, IMiddlewareOption>
// Incremental scope consent without re-instantiating the auth provider. We will use a static instance.
GraphRequestContext graphRequestContext = new GraphRequestContext();
graphRequestContext.CancellationToken = _cancellationTokenSource.Token;
graphRequestContext.MiddlewareOptions = new Dictionary<string, IMiddlewareOption>
{
{
typeof(AuthenticationHandlerOption).ToString(),
Expand All @@ -178,49 +222,50 @@ protected override void ProcessRecord()
}
};

// Trigger consent.
HttpRequestMessage httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://graph.microsoft.com/v1.0/me");
httpRequestMessage.Properties.Add(typeof(GraphRequestContext).ToString(), graphRequestContext);
authProvider.AuthenticateRequestAsync(httpRequestMessage).GetAwaiter().GetResult();
// Trigger consent.
HttpRequestMessage httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://graph.microsoft.com/v1.0/me");
httpRequestMessage.Properties.Add(typeof(GraphRequestContext).ToString(), graphRequestContext);
await authProvider.AuthenticateRequestAsync(httpRequestMessage);

IAccount account = null;
if (clientApplication != null)
{
// Only get accounts when we are using MSAL to get an access token.
IEnumerable<IAccount> accounts = clientApplication.GetAccountsAsync().GetAwaiter().GetResult();
account = accounts.FirstOrDefault();
}
DecodeJWT(httpRequestMessage.Headers.Authorization?.Parameter, account, ref authContext);
IAccount account = null;
if (clientApplication != null)
{
// Only get accounts when we are using MSAL to get an access token.
IEnumerable<IAccount> accounts = clientApplication.GetAccountsAsync().GetAwaiter().GetResult();
account = accounts.FirstOrDefault();
}
DecodeJWT(httpRequestMessage.Headers.Authorization?.Parameter, account, ref authContext);

// Save auth context to session state.
GraphSession.Instance.AuthContext = authContext;
}
catch (AuthenticationException authEx)
{
if ((authEx.InnerException is TaskCanceledException) && cancellationToken.IsCancellationRequested)
// Save auth context to session state.
GraphSession.Instance.AuthContext = authContext;
}
catch (AuthenticationException authEx)
{
// DeviceCodeTimeout
throw new Exception(string.Format(
CultureInfo.CurrentCulture,
ErrorConstants.Message.DeviceCodeTimeout,
Constants.MaxDeviceCodeTimeOut));
if ((authEx.InnerException is TaskCanceledException) && _cancellationTokenSource.Token.IsCancellationRequested)
{
// DeviceCodeTimeout
throw new Exception(string.Format(
CultureInfo.CurrentCulture,
ErrorConstants.Message.DeviceCodeTimeout,
Constants.MaxDeviceCodeTimeOut));
}
else
{
throw authEx.InnerException ?? authEx;
}
}
else
catch (Exception ex)
{
throw authEx.InnerException ?? authEx;
throw ex.InnerException ?? ex;
}
}
catch (Exception ex)
{
throw ex.InnerException ?? ex;
}

WriteObject("Welcome To Microsoft Graph!");
WriteObject("Welcome To Microsoft Graph!");
}
}

protected override void StopProcessing()
{
cancellationTokenSource.Cancel();
_cancellationTokenSource.Cancel();
base.StopProcessing();
}

Expand Down
Loading