diff --git a/src/Security/Authentication/OAuth/src/Events/OAuthEvents.cs b/src/Security/Authentication/OAuth/src/Events/OAuthEvents.cs index 9e194491b98b..979ab578186c 100644 --- a/src/Security/Authentication/OAuth/src/Events/OAuthEvents.cs +++ b/src/Security/Authentication/OAuth/src/Events/OAuthEvents.cs @@ -25,6 +25,11 @@ public class OAuthEvents : RemoteAuthenticationEvents return Task.CompletedTask; }; + /// + /// Gets or sets the delegate that is invoked when the ExchangeCode method is invoked. + /// + public Func OnExchangeCode { get; set; } = context => Task.CompletedTask; + /// /// Invoked after the provider successfully authenticates a user. /// @@ -37,5 +42,13 @@ public class OAuthEvents : RemoteAuthenticationEvents /// /// Contains redirect URI and of the challenge. public virtual Task RedirectToAuthorizationEndpoint(RedirectContext context) => OnRedirectToAuthorizationEndpoint(context); + + /// + /// Invoked before the request to exchange the code for the access token. + /// + /// Contains the code returned, the redirect URI and the . + /// + public virtual Task ExchangeCode(OAuthExchangeCodeContext context) => OnExchangeCode(context); + } -} \ No newline at end of file +} diff --git a/src/Security/Authentication/OAuth/src/Events/OAuthExchangeCodeContext.cs b/src/Security/Authentication/OAuth/src/Events/OAuthExchangeCodeContext.cs new file mode 100644 index 000000000000..c4d00a8d5bb9 --- /dev/null +++ b/src/Security/Authentication/OAuth/src/Events/OAuthExchangeCodeContext.cs @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Authentication.OAuth +{ + /// + /// Contains information about the context of exchanging code for access token . + /// + public class OAuthExchangeCodeContext : PropertiesContext + { + /// + /// Initializes a new . + /// + /// The . + /// The HTTP environment. + /// The authentication scheme. + /// The options used by the authentication middleware. + /// The parameters that will be sent as query string for the token request + public OAuthExchangeCodeContext( + AuthenticationProperties properties, + HttpContext context, + AuthenticationScheme scheme, + OAuthOptions options, + IDictionary tokenRequestParameters) + : base(context, scheme, options, properties) + { + TokenRequestParameters = tokenRequestParameters; + } + + /// + /// Gets the request parameters for the token request + /// + public IDictionary TokenRequestParameters { get; } + } +} diff --git a/src/Security/Authentication/OAuth/src/OAuthHandler.cs b/src/Security/Authentication/OAuth/src/OAuthHandler.cs index 0d2cf140663a..4655e008e201 100644 --- a/src/Security/Authentication/OAuth/src/OAuthHandler.cs +++ b/src/Security/Authentication/OAuth/src/OAuthHandler.cs @@ -99,7 +99,7 @@ protected override async Task HandleRemoteAuthenticateAsync return HandleRequestResult.Fail("Code was not found.", properties); } - using (var tokens = await ExchangeCodeAsync(code, BuildRedirectUri(Options.CallbackPath))) + using (var tokens = await ExchangeCodeAsync(code, BuildRedirectUri(Options.CallbackPath), properties)) { if (tokens.Error != null) { @@ -159,7 +159,7 @@ protected override async Task HandleRemoteAuthenticateAsync } } - protected virtual async Task ExchangeCodeAsync(string code, string redirectUri) + protected virtual async Task ExchangeCodeAsync(string code, string redirectUri, AuthenticationProperties properties) { var tokenRequestParameters = new Dictionary() { @@ -170,6 +170,9 @@ protected virtual async Task ExchangeCodeAsync(string code, { "grant_type", "authorization_code" }, }; + var exchangeCodeContext = new OAuthExchangeCodeContext(properties, Context, Scheme, Options, tokenRequestParameters); + await Events.OnExchangeCode(exchangeCodeContext); + var requestContent = new FormUrlEncodedContent(tokenRequestParameters); var requestMessage = new HttpRequestMessage(HttpMethod.Post, Options.TokenEndpoint);