diff --git a/samples/SocialSample/Startup.cs b/samples/SocialSample/Startup.cs index a0b193b8e..35896e84b 100644 --- a/samples/SocialSample/Startup.cs +++ b/samples/SocialSample/Startup.cs @@ -67,6 +67,10 @@ public void ConfigureServices(IServiceCollection services) o.Fields.Add("name"); o.Fields.Add("email"); o.SaveTokens = true; + o.Events = new OAuthEvents() + { + OnRemoteFailure = HandleOnRemoteFailure + }; }) // You must first create an app with Google and add its ID and Secret to your user-secrets. // https://console.developers.google.com/project @@ -81,6 +85,10 @@ public void ConfigureServices(IServiceCollection services) o.Scope.Add("profile"); o.Scope.Add("email"); o.SaveTokens = true; + o.Events = new OAuthEvents() + { + OnRemoteFailure = HandleOnRemoteFailure + }; }) // You must first create an app with Google and add its ID and Secret to your user-secrets. // https://console.developers.google.com/project @@ -93,12 +101,7 @@ public void ConfigureServices(IServiceCollection services) o.SaveTokens = true; o.Events = new OAuthEvents() { - OnRemoteFailure = ctx => - { - ctx.Response.Redirect("/error?FailureMessage=" + UrlEncoder.Default.Encode(ctx.Failure.Message)); - ctx.HandleResponse(); - return Task.FromResult(0); - } + OnRemoteFailure = HandleOnRemoteFailure }; o.ClaimActions.MapJsonSubKey("urn:google:image", "image", "url"); o.ClaimActions.Remove(ClaimTypes.GivenName); @@ -116,12 +119,7 @@ public void ConfigureServices(IServiceCollection services) o.ClaimActions.MapJsonKey("urn:twitter:profilepicture", "profile_image_url", ClaimTypes.Uri); o.Events = new TwitterEvents() { - OnRemoteFailure = ctx => - { - ctx.Response.Redirect("/error?FailureMessage=" + UrlEncoder.Default.Encode(ctx.Failure.Message)); - ctx.HandleResponse(); - return Task.FromResult(0); - } + OnRemoteFailure = HandleOnRemoteFailure }; }) /* Azure AD app model v2 has restrictions that prevent the use of plain HTTP for redirect URLs. @@ -139,6 +137,10 @@ public void ConfigureServices(IServiceCollection services) o.TokenEndpoint = MicrosoftAccountDefaults.TokenEndpoint; o.Scope.Add("https://graph.microsoft.com/user.read"); o.SaveTokens = true; + o.Events = new OAuthEvents() + { + OnRemoteFailure = HandleOnRemoteFailure + }; }) // You must first create an app with Microsoft Account and add its ID and Secret to your user-secrets. // https://azure.microsoft.com/en-us/documentation/articles/active-directory-v2-app-registration/ @@ -148,6 +150,10 @@ public void ConfigureServices(IServiceCollection services) o.ClientSecret = Configuration["microsoftaccount:clientsecret"]; o.SaveTokens = true; o.Scope.Add("offline_access"); + o.Events = new OAuthEvents() + { + OnRemoteFailure = HandleOnRemoteFailure + }; }) // You must first create an app with GitHub and add its ID and Secret to your user-secrets. // https://github.com/settings/applications/ @@ -159,6 +165,10 @@ public void ConfigureServices(IServiceCollection services) o.AuthorizationEndpoint = "https://github.com/login/oauth/authorize"; o.TokenEndpoint = "https://github.com/login/oauth/access_token"; o.SaveTokens = true; + o.Events = new OAuthEvents() + { + OnRemoteFailure = HandleOnRemoteFailure + }; }) // You must first create an app with GitHub and add its ID and Secret to your user-secrets. // https://github.com/settings/applications/ @@ -180,6 +190,7 @@ public void ConfigureServices(IServiceCollection services) o.ClaimActions.MapJsonKey("urn:github:url", "url"); o.Events = new OAuthEvents { + OnRemoteFailure = HandleOnRemoteFailure, OnCreatingTicket = async context => { // Get the GitHub user @@ -198,6 +209,30 @@ public void ConfigureServices(IServiceCollection services) }); } + private async Task HandleOnRemoteFailure(RemoteFailureContext context) + { + context.Response.StatusCode = 500; + context.Response.ContentType = "text/html"; + await context.Response.WriteAsync(""); + await context.Response.WriteAsync("A remote failure has occurred: " + UrlEncoder.Default.Encode(context.Failure.Message) + "
"); + + if (context.Properties != null) + { + await context.Response.WriteAsync("Properties:
"); + foreach (var pair in context.Properties.Items) + { + await context.Response.WriteAsync($"-{ UrlEncoder.Default.Encode(pair.Key)}={ UrlEncoder.Default.Encode(pair.Value)}
"); + } + } + + await context.Response.WriteAsync("Home"); + await context.Response.WriteAsync(""); + + // context.Response.Redirect("/error?FailureMessage=" + UrlEncoder.Default.Encode(context.Failure.Message)); + + context.HandleResponse(); + } + public void Configure(IApplicationBuilder app) { app.UseDeveloperExceptionPage(); diff --git a/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs b/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs index 007d7dbef..80680a7cf 100644 --- a/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs @@ -44,9 +44,22 @@ public OAuthHandler(IOptionsMonitor options, ILoggerFactory logger, Ur protected override async Task HandleRemoteAuthenticateAsync() { - AuthenticationProperties properties = null; var query = Request.Query; + var state = query["state"]; + var properties = Options.StateDataFormat.Unprotect(state); + + if (properties == null) + { + return HandleRequestResult.Fail("The oauth state was missing or invalid."); + } + + // OAuth2 10.12 CSRF + if (!ValidateCorrelationId(properties)) + { + return HandleRequestResult.Fail("Correlation failed.", properties); + } + var error = query["error"]; if (!StringValues.IsNullOrEmpty(error)) { @@ -63,39 +76,26 @@ protected override async Task HandleRemoteAuthenticateAsync failureMessage.Append(";Uri=").Append(errorUri); } - return HandleRequestResult.Fail(failureMessage.ToString()); + return HandleRequestResult.Fail(failureMessage.ToString(), properties); } var code = query["code"]; - var state = query["state"]; - - properties = Options.StateDataFormat.Unprotect(state); - if (properties == null) - { - return HandleRequestResult.Fail("The oauth state was missing or invalid."); - } - - // OAuth2 10.12 CSRF - if (!ValidateCorrelationId(properties)) - { - return HandleRequestResult.Fail("Correlation failed."); - } if (StringValues.IsNullOrEmpty(code)) { - return HandleRequestResult.Fail("Code was not found."); + return HandleRequestResult.Fail("Code was not found.", properties); } var tokens = await ExchangeCodeAsync(code, BuildRedirectUri(Options.CallbackPath)); if (tokens.Error != null) { - return HandleRequestResult.Fail(tokens.Error); + return HandleRequestResult.Fail(tokens.Error, properties); } if (string.IsNullOrEmpty(tokens.AccessToken)) { - return HandleRequestResult.Fail("Failed to retrieve access token."); + return HandleRequestResult.Fail("Failed to retrieve access token.", properties); } var identity = new ClaimsIdentity(ClaimsIssuer); @@ -141,7 +141,7 @@ protected override async Task HandleRemoteAuthenticateAsync } else { - return HandleRequestResult.Fail("Failed to retrieve user information from remote server."); + return HandleRequestResult.Fail("Failed to retrieve user information from remote server.", properties); } } diff --git a/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs b/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs index bf365ceca..7f65afdce 100644 --- a/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs @@ -491,13 +491,10 @@ protected override async Task HandleRemoteAuthenticateAsync return HandleRequestResult.Fail("No message."); } + AuthenticationProperties properties = null; try { - AuthenticationProperties properties = null; - if (!string.IsNullOrEmpty(authorizationResponse.State)) - { - properties = Options.StateDataFormat.Unprotect(authorizationResponse.State); - } + properties = ReadPropertiesAndClearState(authorizationResponse); var messageReceivedContext = await RunMessageReceivedEventAsync(authorizationResponse, properties); if (messageReceivedContext.Result != null) @@ -521,8 +518,7 @@ protected override async Task HandleRemoteAuthenticateAsync return HandleRequestResult.Fail(Resources.MessageStateIsNullOrEmpty); } - // if state exists and we failed to 'unprotect' this is not a message we should process. - properties = Options.StateDataFormat.Unprotect(authorizationResponse.State); + properties = ReadPropertiesAndClearState(authorizationResponse); } if (properties == null) @@ -533,21 +529,20 @@ protected override async Task HandleRemoteAuthenticateAsync // Not for us? return HandleRequestResult.SkipHandler(); } + + // if state exists and we failed to 'unprotect' this is not a message we should process. return HandleRequestResult.Fail(Resources.MessageStateIsInvalid); } - properties.Items.TryGetValue(OpenIdConnectDefaults.UserstatePropertiesKey, out string userstate); - authorizationResponse.State = userstate; - if (!ValidateCorrelationId(properties)) { - return HandleRequestResult.Fail("Correlation failed."); + return HandleRequestResult.Fail("Correlation failed.", properties); } // if any of the error fields are set, throw error null if (!string.IsNullOrEmpty(authorizationResponse.Error)) { - return HandleRequestResult.Fail(CreateOpenIdConnectProtocolException(authorizationResponse, response: null)); + return HandleRequestResult.Fail(CreateOpenIdConnectProtocolException(authorizationResponse, response: null), properties); } if (_configuration == null && Options.ConfigurationManager != null) @@ -635,8 +630,7 @@ protected override async Task HandleRemoteAuthenticateAsync // At least a cursory validation is required on the new IdToken, even if we've already validated the one from the authorization response. // And we'll want to validate the new JWT in ValidateTokenResponse. - JwtSecurityToken tokenEndpointJwt; - var tokenEndpointUser = ValidateToken(tokenEndpointResponse.IdToken, properties, validationParameters, out tokenEndpointJwt); + var tokenEndpointUser = ValidateToken(tokenEndpointResponse.IdToken, properties, validationParameters, out var tokenEndpointJwt); // Avoid reading & deleting the nonce cookie, running the event, etc, if it was already done as part of the authorization response validation. if (user == null) @@ -722,8 +716,25 @@ protected override async Task HandleRemoteAuthenticateAsync return authenticationFailedContext.Result; } - return HandleRequestResult.Fail(exception); + return HandleRequestResult.Fail(exception, properties); + } + } + + private AuthenticationProperties ReadPropertiesAndClearState(OpenIdConnectMessage message) + { + AuthenticationProperties properties = null; + if (!string.IsNullOrEmpty(message.State)) + { + properties = Options.StateDataFormat.Unprotect(message.State); + + if (properties != null) + { + // If properties can be decoded from state, clear the message state. + properties.Items.TryGetValue(OpenIdConnectDefaults.UserstatePropertiesKey, out var userstate); + message.State = userstate; + } } + return properties; } private void PopulateSessionProperties(OpenIdConnectMessage message, AuthenticationProperties properties) @@ -830,7 +841,7 @@ protected virtual async Task GetUserInformationAsync( } else { - return HandleRequestResult.Fail("Unknown response type: " + contentType.MediaType); + return HandleRequestResult.Fail("Unknown response type: " + contentType.MediaType, properties); } var userInformationReceivedContext = await RunUserInformationReceivedEventAsync(principal, properties, message, user); diff --git a/src/Microsoft.AspNetCore.Authentication.Twitter/TwitterHandler.cs b/src/Microsoft.AspNetCore.Authentication.Twitter/TwitterHandler.cs index e8a961df3..acfd765d9 100644 --- a/src/Microsoft.AspNetCore.Authentication.Twitter/TwitterHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.Twitter/TwitterHandler.cs @@ -46,7 +46,6 @@ public TwitterHandler(IOptionsMonitor options, ILoggerFactory lo protected override async Task HandleRemoteAuthenticateAsync() { - AuthenticationProperties properties = null; var query = Request.Query; var protectedRequestToken = Request.Cookies[Options.StateCookie.Name]; @@ -57,25 +56,25 @@ protected override async Task HandleRemoteAuthenticateAsync return HandleRequestResult.Fail("Invalid state cookie."); } - properties = requestToken.Properties; + var properties = requestToken.Properties; // REVIEW: see which of these are really errors var returnedToken = query["oauth_token"]; if (StringValues.IsNullOrEmpty(returnedToken)) { - return HandleRequestResult.Fail("Missing oauth_token"); + return HandleRequestResult.Fail("Missing oauth_token", properties); } if (!string.Equals(returnedToken, requestToken.Token, StringComparison.Ordinal)) { - return HandleRequestResult.Fail("Unmatched token"); + return HandleRequestResult.Fail("Unmatched token", properties); } var oauthVerifier = query["oauth_verifier"]; if (StringValues.IsNullOrEmpty(oauthVerifier)) { - return HandleRequestResult.Fail("Missing or blank oauth_verifier"); + return HandleRequestResult.Fail("Missing or blank oauth_verifier", properties); } var cookieOptions = Options.StateCookie.Build(Context, Clock.UtcNow); diff --git a/src/Microsoft.AspNetCore.Authentication/Events/RemoteFailureContext.cs b/src/Microsoft.AspNetCore.Authentication/Events/RemoteFailureContext.cs index becdfb543..6b3598f40 100644 --- a/src/Microsoft.AspNetCore.Authentication/Events/RemoteFailureContext.cs +++ b/src/Microsoft.AspNetCore.Authentication/Events/RemoteFailureContext.cs @@ -25,5 +25,10 @@ public RemoteFailureContext( /// User friendly error message for the error. /// public Exception Failure { get; set; } + + /// + /// Additional state values for the authentication session. + /// + public AuthenticationProperties Properties { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationResult.cs b/src/Microsoft.AspNetCore.Authentication/HandleRequestResult.cs similarity index 72% rename from src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationResult.cs rename to src/Microsoft.AspNetCore.Authentication/HandleRequestResult.cs index 8bcd2be01..3f6c2d917 100644 --- a/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationResult.cs +++ b/src/Microsoft.AspNetCore.Authentication/HandleRequestResult.cs @@ -49,13 +49,31 @@ public class HandleRequestResult : AuthenticateResult /// /// Indicates that there was a failure during authentication. /// - /// The failure message. + /// The failure exception. + /// Additional state values for the authentication session. /// The result. - public static new HandleRequestResult Fail(string failureMessage) + public static new HandleRequestResult Fail(Exception failure, AuthenticationProperties properties) { - return new HandleRequestResult() { Failure = new Exception(failureMessage) }; + return new HandleRequestResult() { Failure = failure, Properties = properties }; } + /// + /// Indicates that there was a failure during authentication. + /// + /// The failure message. + /// The result. + public static new HandleRequestResult Fail(string failureMessage) + => Fail(new Exception(failureMessage)); + + /// + /// Indicates that there was a failure during authentication. + /// + /// The failure message. + /// Additional state values for the authentication session. + /// The result. + public static new HandleRequestResult Fail(string failureMessage, AuthenticationProperties properties) + => Fail(new Exception(failureMessage), properties); + /// /// Discontinue all processing for this request and return to the client. /// The caller is responsible for generating the full response. diff --git a/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs index 4051ee666..bea4895d6 100644 --- a/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs @@ -49,6 +49,7 @@ public virtual async Task HandleRequestAsync() AuthenticationTicket ticket = null; Exception exception = null; + AuthenticationProperties properties = null; try { var authResult = await HandleRemoteAuthenticateAsync(); @@ -66,8 +67,8 @@ public virtual async Task HandleRequestAsync() } else if (!authResult.Succeeded) { - exception = authResult.Failure ?? - new InvalidOperationException("Invalid return state, unable to redirect."); + exception = authResult.Failure ?? new InvalidOperationException("Invalid return state, unable to redirect."); + properties = authResult.Properties; } ticket = authResult?.Ticket; @@ -80,7 +81,10 @@ public virtual async Task HandleRequestAsync() if (exception != null) { Logger.RemoteAuthenticationError(exception.Message); - var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception); + var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception) + { + Properties = properties + }; await Events.RemoteFailure(errorContext); if (errorContext.Result != null) @@ -95,11 +99,14 @@ public virtual async Task HandleRequestAsync() } else if (errorContext.Result.Failure != null) { - throw new InvalidOperationException("An error was returned from the RemoteFailure event.", errorContext.Result.Failure); + throw new Exception("An error was returned from the RemoteFailure event.", errorContext.Result.Failure); } } - throw exception; + if (errorContext.Failure != null) + { + throw new Exception("An error was encountered while handling the remote login.", errorContext.Failure); + } } // We have a ticket if we get here @@ -107,7 +114,7 @@ public virtual async Task HandleRequestAsync() { ReturnUri = ticket.Properties.RedirectUri }; - // REVIEW: is this safe or good? + ticket.Properties.RedirectUri = null; // Mark which provider produced this identity so we can cross-check later in HandleAuthenticateAsync diff --git a/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs index 8f2cc52f9..51bc67cc3 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs @@ -253,6 +253,7 @@ public async Task ReplyPathWithErrorFails(bool redirect) { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; + o.StateDataFormat = new TestStateDataFormat(); o.Events = redirect ? new OAuthEvents() { OnRemoteFailure = ctx => @@ -263,7 +264,8 @@ public async Task ReplyPathWithErrorFails(bool redirect) } } : new OAuthEvents(); }); - var sendTask = server.SendAsync("https://example.com/signin-google?error=OMG&error_description=SoBad&error_uri=foobar"); + var sendTask = server.SendAsync("https://example.com/signin-google?error=OMG&error_description=SoBad&error_uri=foobar&state=protected_state", + ".AspNetCore.Correlation.Google.corrilationId=N"); if (redirect) { var transaction = await sendTask; @@ -1075,5 +1077,37 @@ private static TestServer CreateServer(Action configureOptions, F }); return new TestServer(builder); } + + private class TestStateDataFormat : ISecureDataFormat + { + private AuthenticationProperties Data { get; set; } + + public string Protect(AuthenticationProperties data) + { + return "protected_state"; + } + + public string Protect(AuthenticationProperties data, string purpose) + { + throw new NotImplementedException(); + } + + public AuthenticationProperties Unprotect(string protectedText) + { + Assert.Equal("protected_state", protectedText); + var properties = new AuthenticationProperties(new Dictionary() + { + { ".xsrf", "corrilationId" }, + { "testkey", "testvalue" } + }); + properties.RedirectUri = "http://testhost/redirect"; + return properties; + } + + public AuthenticationProperties Unprotect(string protectedText, string purpose) + { + throw new NotImplementedException(); + } + } } } diff --git a/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs index 30c33eb1d..81d2360ec 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Net; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.Cookies; @@ -10,6 +11,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Net.Http.Headers; using Xunit; namespace Microsoft.AspNetCore.Authentication.OAuth @@ -20,20 +22,13 @@ public class OAuthTests public async Task VerifySignInSchemeCannotBeSetToSelf() { var server = CreateServer( - app => { }, services => services.AddAuthentication().AddOAuth("weeblie", o => { o.SignInScheme = "weeblie"; o.ClientId = "whatever"; o.ClientSecret = "whatever"; - }), - context => - { - // REVIEW: Gross. - context.ChallengeAsync("weeblie").GetAwaiter().GetResult(); - return true; - }); - var error = await Assert.ThrowsAsync(() => server.SendAsync("https://example.com/challenge")); + })); + var error = await Assert.ThrowsAsync(() => server.SendAsync("https://example.com/")); Assert.Contains("cannot be set to itself", error.Message); } @@ -54,7 +49,6 @@ public async Task VerifySchemeDefaults() public async Task ThrowsIfClientIdMissing() { var server = CreateServer( - app => { }, services => services.AddAuthentication().AddOAuth("weeblie", o => { o.SignInScheme = "whatever"; @@ -62,22 +56,14 @@ public async Task ThrowsIfClientIdMissing() o.ClientSecret = "whatever"; o.TokenEndpoint = "/"; o.AuthorizationEndpoint = "/"; - }), - context => - { - // REVIEW: Gross. - Assert.Throws("ClientId", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult()); - return true; - }); - var transaction = await server.SendAsync("http://example.com/challenge"); - Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); + })); + await Assert.ThrowsAsync("ClientId", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfClientSecretMissing() { var server = CreateServer( - app => { }, services => services.AddAuthentication().AddOAuth("weeblie", o => { o.SignInScheme = "whatever"; @@ -85,22 +71,14 @@ public async Task ThrowsIfClientSecretMissing() o.CallbackPath = "/"; o.TokenEndpoint = "/"; o.AuthorizationEndpoint = "/"; - }), - context => - { - // REVIEW: Gross. - Assert.Throws("ClientSecret", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult()); - return true; - }); - var transaction = await server.SendAsync("http://example.com/challenge"); - Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); + })); + await Assert.ThrowsAsync("ClientSecret", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfCallbackPathMissing() { var server = CreateServer( - app => { }, services => services.AddAuthentication().AddOAuth("weeblie", o => { o.ClientId = "Whatever;"; @@ -108,22 +86,14 @@ public async Task ThrowsIfCallbackPathMissing() o.TokenEndpoint = "/"; o.AuthorizationEndpoint = "/"; o.SignInScheme = "eh"; - }), - context => - { - // REVIEW: Gross. - Assert.Throws("CallbackPath", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult()); - return true; - }); - var transaction = await server.SendAsync("http://example.com/challenge"); - Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); + })); + await Assert.ThrowsAsync("CallbackPath", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfTokenEndpointMissing() { var server = CreateServer( - app => { }, services => services.AddAuthentication().AddOAuth("weeblie", o => { o.ClientId = "Whatever;"; @@ -131,22 +101,14 @@ public async Task ThrowsIfTokenEndpointMissing() o.CallbackPath = "/"; o.AuthorizationEndpoint = "/"; o.SignInScheme = "eh"; - }), - context => - { - // REVIEW: Gross. - Assert.Throws("TokenEndpoint", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult()); - return true; - }); - var transaction = await server.SendAsync("http://example.com/challenge"); - Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); + })); + await Assert.ThrowsAsync("TokenEndpoint", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfAuthorizationEndpointMissing() { var server = CreateServer( - app => { }, services => services.AddAuthentication().AddOAuth("weeblie", o => { o.ClientId = "Whatever;"; @@ -154,22 +116,14 @@ public async Task ThrowsIfAuthorizationEndpointMissing() o.CallbackPath = "/"; o.TokenEndpoint = "/"; o.SignInScheme = "eh"; - }), - context => - { - // REVIEW: Gross. - Assert.Throws("AuthorizationEndpoint", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult()); - return true; - }); - var transaction = await server.SendAsync("http://example.com/challenge"); - Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); + })); + await Assert.ThrowsAsync("AuthorizationEndpoint", () => server.SendAsync("http://example.com/")); } [Fact] public async Task RedirectToIdentityProvider_SetsCorrelationIdCookiePath_ToCallBackPath() { var server = CreateServer( - app => { }, s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -181,9 +135,9 @@ public async Task RedirectToIdentityProvider_SetsCorrelationIdCookiePath_ToCallB opt.TokenEndpoint = "https://example.com/provider/token"; opt.CallbackPath = "/oauth-callback"; }), - ctx => + async ctx => { - ctx.ChallengeAsync("Weblie").ConfigureAwait(false).GetAwaiter().GetResult(); + await ctx.ChallengeAsync("Weblie"); return true; }); @@ -201,7 +155,6 @@ public async Task RedirectToIdentityProvider_SetsCorrelationIdCookiePath_ToCallB public async Task RedirectToAuthorizeEndpoint_CorrelationIdCookieOptions_CanBeOverriden() { var server = CreateServer( - app => { }, s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -214,9 +167,9 @@ public async Task RedirectToAuthorizeEndpoint_CorrelationIdCookieOptions_CanBeOv opt.CallbackPath = "/oauth-callback"; opt.CorrelationCookie.Path = "/"; }), - ctx => + async ctx => { - ctx.ChallengeAsync("Weblie").ConfigureAwait(false).GetAwaiter().GetResult(); + await ctx.ChallengeAsync("Weblie"); return true; }); @@ -230,15 +183,50 @@ public async Task RedirectToAuthorizeEndpoint_CorrelationIdCookieOptions_CanBeOv Assert.Contains("path=/", correlation); } - private static TestServer CreateServer(Action configure, Action configureServices, Func handler) + [Fact] + public async Task RemoteAuthenticationFailed_OAuthError_IncludesProperties() + { + var server = CreateServer( + s => s.AddAuthentication().AddOAuth( + "Weblie", + opt => + { + opt.ClientId = "Test Id"; + opt.ClientSecret = "secret"; + opt.SignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; + opt.AuthorizationEndpoint = "https://example.com/provider/login"; + opt.TokenEndpoint = "https://example.com/provider/token"; + opt.CallbackPath = "/oauth-callback"; + opt.StateDataFormat = new TestStateDataFormat(); + opt.Events = new OAuthEvents() + { + OnRemoteFailure = context => + { + Assert.Contains("declined", context.Failure.Message); + Assert.Equal("testvalue", context.Properties.Items["testkey"]); + context.Response.StatusCode = StatusCodes.Status406NotAcceptable; + context.HandleResponse(); + return Task.CompletedTask; + } + }; + })); + + var transaction = await server.SendAsync("https://www.example.com/oauth-callback?error=declined&state=protected_state", + ".AspNetCore.Correlation.Weblie.corrilationId=N"); + + Assert.Equal(HttpStatusCode.NotAcceptable, transaction.Response.StatusCode); + Assert.Null(transaction.Response.Headers.Location); + } + + private static TestServer CreateServer(Action configureServices, Func> handler = null) { var builder = new WebHostBuilder() .Configure(app => { - configure?.Invoke(app); + app.UseAuthentication(); app.Use(async (context, next) => { - if (handler == null || !handler(context)) + if (handler == null || ! await handler(context)) { await next(); } @@ -247,5 +235,37 @@ private static TestServer CreateServer(Action configure, Ac .ConfigureServices(configureServices); return new TestServer(builder); } + + private class TestStateDataFormat : ISecureDataFormat + { + private AuthenticationProperties Data { get; set; } + + public string Protect(AuthenticationProperties data) + { + return "protected_state"; + } + + public string Protect(AuthenticationProperties data, string purpose) + { + throw new NotImplementedException(); + } + + public AuthenticationProperties Unprotect(string protectedText) + { + Assert.Equal("protected_state", protectedText); + var properties = new AuthenticationProperties(new Dictionary() + { + { ".xsrf", "corrilationId" }, + { "testkey", "testvalue" } + }); + properties.RedirectUri = "http://testhost/redirect"; + return properties; + } + + public AuthenticationProperties Unprotect(string protectedText, string purpose) + { + throw new NotImplementedException(); + } + } } } diff --git a/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectEventTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectEventTests.cs index f3fc26187..87bdc3f3c 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectEventTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectEventTests.cs @@ -95,7 +95,7 @@ public async Task OnMessageReceived_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", ""); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(remoteFailure); @@ -191,7 +191,7 @@ public async Task OnTokenValidated_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state"); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(tokenValidated); @@ -348,7 +348,7 @@ public async Task OnAuthorizationCodeReceived_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code"); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(tokenValidated); @@ -532,7 +532,7 @@ public async Task OnTokenResponseReceived_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code"); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(tokenValidated); @@ -731,7 +731,7 @@ public async Task OnTokenValidatedBackchannel_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", "state=protected_state&code=my_code"); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(codeReceived); @@ -943,7 +943,7 @@ public async Task OnUserInformationReceived_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code"); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(tokenValidated); @@ -1186,7 +1186,7 @@ public async Task OnAuthenticationFailed_Fail_NoMoreEventsRun() return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code"); }); - Assert.Equal("Authentication was aborted from user code.", exception.Message); + Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message); Assert.True(messageReceived); Assert.True(tokenValidated); @@ -1450,6 +1450,7 @@ public async Task OnRemoteFailure_Handled_NoMoreEventsRun() { remoteFailure = true; Assert.Equal("TestException", context.Failure.Message); + Assert.Equal("testvalue", context.Properties.Items["testkey"]); context.HandleResponse(); context.Response.StatusCode = StatusCodes.Status202Accepted; return Task.FromResult(0); @@ -1877,7 +1878,8 @@ public AuthenticationProperties Unprotect(string protectedText) var properties = new AuthenticationProperties(new Dictionary() { { ".xsrf", "corrilationId" }, - { OpenIdConnectDefaults.RedirectUriForCodePropertiesKey, "redirect_uri" } + { OpenIdConnectDefaults.RedirectUriForCodePropertiesKey, "redirect_uri" }, + { "testkey", "testvalue" } }); properties.RedirectUri = "http://testhost/redirect"; return properties; diff --git a/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs index 6c661af45..735cb3314 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. See License.txt in the project root for license information. using System; +using System.Linq; using System.Net; using System.Net.Http; using System.Security.Claims; @@ -11,6 +12,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Net.Http.Headers; using Xunit; namespace Microsoft.AspNetCore.Authentication.Twitter @@ -60,26 +62,12 @@ public async Task ChallengeWillTriggerApplyRedirectEvent() }; o.BackchannelHttpHandler = new TestHttpMessageHandler { - Sender = req => - { - if (req.RequestUri.AbsoluteUri == "https://api.twitter.com/oauth/request_token") - { - return new HttpResponseMessage(HttpStatusCode.OK) - { - Content = - new StringContent("oauth_callback_confirmed=true&oauth_token=test_oauth_token&oauth_token_secret=test_oauth_token_secret", - Encoding.UTF8, - "application/x-www-form-urlencoded") - }; - } - return null; - } + Sender = BackchannelRequestToken }; }, - context => + async context => { - // REVIEW: Gross - context.ChallengeAsync("Twitter").GetAwaiter().GetResult(); + await context.ChallengeAsync("Twitter"); return true; }); var transaction = await server.SendAsync("http://example.com/challenge"); @@ -168,7 +156,6 @@ public async Task ForbidThrows() Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } - [Fact] public async Task ChallengeWillTriggerRedirection() { @@ -178,35 +165,70 @@ public async Task ChallengeWillTriggerRedirection() o.ConsumerSecret = "Test Consumer Secret"; o.BackchannelHttpHandler = new TestHttpMessageHandler { - Sender = req => + Sender = BackchannelRequestToken + }; + }, + async context => + { + await context.ChallengeAsync("Twitter"); + return true; + }); + var transaction = await server.SendAsync("http://example.com/challenge"); + Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); + var location = transaction.Response.Headers.Location.AbsoluteUri; + Assert.Contains("https://api.twitter.com/oauth/authenticate?oauth_token=", location); + } + + [Fact] + public async Task BadCallbackCallsRemoteAuthFailedWithState() + { + var server = CreateServer(o => + { + o.ConsumerKey = "Test Consumer Key"; + o.ConsumerSecret = "Test Consumer Secret"; + o.BackchannelHttpHandler = new TestHttpMessageHandler + { + Sender = BackchannelRequestToken + }; + o.Events = new TwitterEvents() + { + OnRemoteFailure = context => { - if (req.RequestUri.AbsoluteUri == "https://api.twitter.com/oauth/request_token") - { - return new HttpResponseMessage(HttpStatusCode.OK) - { - Content = - new StringContent("oauth_callback_confirmed=true&oauth_token=test_oauth_token&oauth_token_secret=test_oauth_token_secret", - Encoding.UTF8, - "application/x-www-form-urlencoded") - }; - } - return null; + Assert.NotNull(context.Failure); + Assert.NotNull(context.Properties); + Assert.Equal("testvalue", context.Properties.Items["testkey"]); + context.Response.StatusCode = StatusCodes.Status406NotAcceptable; + context.HandleResponse(); + return Task.CompletedTask; } }; }, - context => - { - // REVIEW: gross - context.ChallengeAsync("Twitter").GetAwaiter().GetResult(); - return true; - }); + async context => + { + var properties = new AuthenticationProperties(); + properties.Items["testkey"] = "testvalue"; + await context.ChallengeAsync("Twitter", properties); + return true; + }); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; Assert.Contains("https://api.twitter.com/oauth/authenticate?oauth_token=", location); + Assert.True(transaction.Response.Headers.TryGetValues(HeaderNames.SetCookie, out var setCookie)); + Assert.True(SetCookieHeaderValue.TryParseList(setCookie.ToList(), out var setCookieValues)); + Assert.Single(setCookieValues); + var setCookieValue = setCookieValues.Single(); + var cookie = new CookieHeaderValue(setCookieValue.Name, setCookieValue.Value); + + var request = new HttpRequestMessage(HttpMethod.Get, "/signin-twitter"); + request.Headers.Add(HeaderNames.Cookie, cookie.ToString()); + var client = server.CreateClient(); + var response = await client.SendAsync(request); + + Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); } - private static TestServer CreateServer(Action options, Func handler = null) + private static TestServer CreateServer(Action options, Func> handler = null) { var builder = new WebHostBuilder() .Configure(app => @@ -228,7 +250,7 @@ private static TestServer CreateServer(Action options, Func(() => context.ForbidAsync("Twitter")); } - else if (handler == null || !handler(context)) + else if (handler == null || ! await handler(context)) { await next(); } @@ -247,5 +269,20 @@ private static TestServer CreateServer(Action options, Func