diff --git a/src/Security/Authorization/Core/src/AuthorizationPolicy.cs b/src/Security/Authorization/Core/src/AuthorizationPolicy.cs index 30959fe85c38..044cd96bbad3 100644 --- a/src/Security/Authorization/Core/src/AuthorizationPolicy.cs +++ b/src/Security/Authorization/Core/src/AuthorizationPolicy.cs @@ -108,7 +108,24 @@ public static AuthorizationPolicy Combine(IEnumerable polic /// A new which represents the combination of the /// authorization policies provided by the specified . /// - public static async Task CombineAsync(IAuthorizationPolicyProvider policyProvider, IEnumerable authorizeData) + public static Task CombineAsync(IAuthorizationPolicyProvider policyProvider, + IEnumerable authorizeData) => CombineAsync(policyProvider, authorizeData, + Enumerable.Empty()); + + /// + /// Combines the provided by the specified + /// . + /// + /// A which provides the policies to combine. + /// A collection of authorization data used to apply authorization to a resource. + /// A collection of policies to combine. + /// + /// A new which represents the combination of the + /// authorization policies provided by the specified . + /// + public static async Task CombineAsync(IAuthorizationPolicyProvider policyProvider, + IEnumerable authorizeData, + IEnumerable policies) { if (policyProvider == null) { @@ -120,6 +137,8 @@ public static AuthorizationPolicy Combine(IEnumerable polic throw new ArgumentNullException(nameof(authorizeData)); } + var anyPolicies = policies.Any(); + // Avoid allocating enumerator if the data is known to be empty var skipEnumeratingData = false; if (authorizeData is IList dataList) @@ -137,7 +156,7 @@ public static AuthorizationPolicy Combine(IEnumerable polic policyBuilder = new AuthorizationPolicyBuilder(); } - var useDefaultPolicy = true; + var useDefaultPolicy = !(anyPolicies); if (!string.IsNullOrWhiteSpace(authorizeDatum.Policy)) { var policy = await policyProvider.GetPolicyAsync(authorizeDatum.Policy).ConfigureAwait(false); @@ -176,6 +195,16 @@ public static AuthorizationPolicy Combine(IEnumerable polic } } + if (anyPolicies) + { + policyBuilder ??= new(); + + foreach (var policy in policies) + { + policyBuilder.Combine(policy); + } + } + // If we have no policy by now, use the fallback policy if we have one if (policyBuilder == null) { diff --git a/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt b/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt index ae01fa8428ad..3ae82ce3cead 100644 --- a/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt +++ b/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt @@ -3,3 +3,4 @@ *REMOVED*~Microsoft.AspNetCore.Authorization.DefaultAuthorizationService.DefaultAuthorizationService(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerProvider! handlers, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerContextFactory! contextFactory, Microsoft.AspNetCore.Authorization.IAuthorizationEvaluator! evaluator, Microsoft.Extensions.Options.IOptions! options) -> void Microsoft.AspNetCore.Authorization.DefaultAuthorizationPolicyProvider.DefaultAuthorizationPolicyProvider(Microsoft.Extensions.Options.IOptions! options) -> void Microsoft.AspNetCore.Authorization.DefaultAuthorizationService.DefaultAuthorizationService(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerProvider! handlers, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerContextFactory! contextFactory, Microsoft.AspNetCore.Authorization.IAuthorizationEvaluator! evaluator, Microsoft.Extensions.Options.IOptions! options) -> void +static Microsoft.AspNetCore.Authorization.AuthorizationPolicy.CombineAsync(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, System.Collections.Generic.IEnumerable! authorizeData, System.Collections.Generic.IEnumerable! policies) -> System.Threading.Tasks.Task! diff --git a/src/Security/Authorization/Policy/src/AuthorizationEndpointConventionBuilderExtensions.cs b/src/Security/Authorization/Policy/src/AuthorizationEndpointConventionBuilderExtensions.cs index c3a957a78349..6551f048fbba 100644 --- a/src/Security/Authorization/Policy/src/AuthorizationEndpointConventionBuilderExtensions.cs +++ b/src/Security/Authorization/Policy/src/AuthorizationEndpointConventionBuilderExtensions.cs @@ -79,6 +79,55 @@ public static TBuilder RequireAuthorization(this TBuilder builder, par return builder; } + /// + /// Adds an authorization policy to the endpoint(s). + /// + /// The endpoint convention builder. + /// The policy. + /// The original convention builder parameter. + public static TBuilder RequireAuthorization(this TBuilder builder, AuthorizationPolicy policy) + where TBuilder : IEndpointConventionBuilder + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (policy == null) + { + throw new ArgumentNullException(nameof(policy)); + } + + RequirePolicyCore(builder, policy); + return builder; + } + + /// + /// Adds an new authorization policy configured by a callback to the endpoint(s). + /// + /// + /// The endpoint convention builder. + /// The callback used to configure the policy. + /// The original convention builder parameter. + public static TBuilder RequireAuthorization(this TBuilder builder, Action configurePolicy) + where TBuilder : IEndpointConventionBuilder + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (configurePolicy == null) + { + throw new ArgumentNullException(nameof(configurePolicy)); + } + + var policyBuilder = new AuthorizationPolicyBuilder(); + configurePolicy(policyBuilder); + RequirePolicyCore(builder, policyBuilder.Build()); + return builder; + } + /// /// Allows anonymous access to the endpoint by adding to the endpoint metadata. This will bypass /// all authorization checks for the endpoint including the default authorization policy and fallback authorization policy. @@ -94,6 +143,20 @@ public static TBuilder AllowAnonymous(this TBuilder builder) where TBu return builder; } + private static void RequirePolicyCore(TBuilder builder, AuthorizationPolicy policy) + where TBuilder : IEndpointConventionBuilder + { + builder.Add(endpointBuilder => + { + // Only add an authorize attribute if there isn't one + if (!endpointBuilder.Metadata.Any(meta => meta is IAuthorizeData)) + { + endpointBuilder.Metadata.Add(new AuthorizeAttribute()); + } + endpointBuilder.Metadata.Add(policy); + }); + } + private static void RequireAuthorizationCore(TBuilder builder, IEnumerable authorizeData) where TBuilder : IEndpointConventionBuilder { @@ -105,4 +168,5 @@ private static void RequireAuthorizationCore(TBuilder builder, IEnumer } }); } + } diff --git a/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs b/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs index 516f7f5d0357..0a58a00a37e0 100644 --- a/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs +++ b/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs @@ -57,7 +57,11 @@ public async Task Invoke(HttpContext context) // IMPORTANT: Changes to authorization logic should be mirrored in MVC's AuthorizeFilter var authorizeData = endpoint?.Metadata.GetOrderedMetadata() ?? Array.Empty(); - var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData); + + var policies = endpoint?.Metadata.GetOrderedMetadata() ?? Array.Empty(); + + var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData, policies); + if (policy == null) { await _next(context); diff --git a/src/Security/Authorization/Policy/src/PublicAPI.Unshipped.txt b/src/Security/Authorization/Policy/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..9c11d1e7b2d6 100644 --- a/src/Security/Authorization/Policy/src/PublicAPI.Unshipped.txt +++ b/src/Security/Authorization/Policy/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +static Microsoft.AspNetCore.Builder.AuthorizationEndpointConventionBuilderExtensions.RequireAuthorization(this TBuilder builder, Microsoft.AspNetCore.Authorization.AuthorizationPolicy! policy) -> TBuilder +static Microsoft.AspNetCore.Builder.AuthorizationEndpointConventionBuilderExtensions.RequireAuthorization(this TBuilder builder, System.Action! configurePolicy) -> TBuilder diff --git a/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs b/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs index c2590efe913e..8cc43a7d2f7f 100644 --- a/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs +++ b/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs @@ -117,6 +117,106 @@ public void RequireAuthorization_ChainedCall() Assert.True(chainedBuilder.TestProperty); } + [Fact] + public void RequireAuthorization_Policy() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build(); + + // Act + builder.RequireAuthorization(policy); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + convention(endpointModel); + + Assert.Equal(2, endpointModel.Metadata.Count); + var authMetadata = Assert.IsAssignableFrom(endpointModel.Metadata[0]); + Assert.Null(authMetadata.Policy); + + Assert.Equal(policy, endpointModel.Metadata[1]); + } + + [Fact] + public void RequireAuthorization_PolicyCallback() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + var requirement = new TestRequirement(); + + // Act + builder.RequireAuthorization(policyBuilder => policyBuilder.Requirements.Add(requirement)); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + convention(endpointModel); + + Assert.Equal(2, endpointModel.Metadata.Count); + var authMetadata = Assert.IsAssignableFrom(endpointModel.Metadata[0]); + Assert.Null(authMetadata.Policy); + + var policy = Assert.IsAssignableFrom(endpointModel.Metadata[1]); + Assert.Equal(1, policy.Requirements.Count); + Assert.Equal(requirement, policy.Requirements[0]); + } + + [Fact] + public void RequireAuthorization_PolicyCallbackWithAuthorize() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + var authorize = new AuthorizeAttribute(); + var requirement = new TestRequirement(); + + // Act + builder.RequireAuthorization(policyBuilder => policyBuilder.Requirements.Add(requirement)); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + endpointModel.Metadata.Add(authorize); + convention(endpointModel); + + // Confirm that we don't add another authorize if one already exists + Assert.Equal(2, endpointModel.Metadata.Count); + Assert.Equal(authorize, endpointModel.Metadata[0]); + var policy = Assert.IsAssignableFrom(endpointModel.Metadata[1]); + Assert.Equal(1, policy.Requirements.Count); + Assert.Equal(requirement, policy.Requirements[0]); + } + + [Fact] + public void RequireAuthorization_PolicyWithAuthorize() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build(); + var authorize = new AuthorizeAttribute(); + + // Act + builder.RequireAuthorization(policy); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + endpointModel.Metadata.Add(authorize); + convention(endpointModel); + + // Confirm that we don't add another authorize if one already exists + Assert.Equal(2, endpointModel.Metadata.Count); + Assert.Equal(authorize, endpointModel.Metadata[0]); + Assert.Equal(policy, endpointModel.Metadata[1]); + } + + class TestRequirement : IAuthorizationRequirement { } + [Fact] public void AllowAnonymous_Default() { diff --git a/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs b/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs index d6548f532191..15ee3bc641c9 100644 --- a/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs +++ b/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs @@ -207,6 +207,28 @@ public async Task OnAuthorizationAsync_WillCallPolicyProvider() Assert.Equal(3, next.CalledCount); } + [Fact] + public async Task CanApplyPolicyDirectlyToEndpoint() + { + // Arrange + var calledPolicy = false; + var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => + { + calledPolicy = true; + return true; + }).Build(); + + var policyProvider = new Mock(); + policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build()); + var next = new TestRequestDelegate(); + var middleware = CreateMiddleware(next.Invoke, policyProvider.Object); + var context = GetHttpContext(anonymous: false, endpoint: CreateEndpoint(new AuthorizeAttribute(), policy)); + + // Act & Assert + await middleware.Invoke(context); + Assert.True(calledPolicy); + } + [Fact] public async Task Invoke_ValidClaimShouldNotFail() { diff --git a/src/Security/Authorization/test/AuthorizationPolicyFacts.cs b/src/Security/Authorization/test/AuthorizationPolicyFacts.cs index 58af80706b0a..9a3a4d78a89a 100644 --- a/src/Security/Authorization/test/AuthorizationPolicyFacts.cs +++ b/src/Security/Authorization/test/AuthorizationPolicyFacts.cs @@ -43,6 +43,29 @@ public async Task CanCombineAuthorizeAttributes() Assert.Single(combined.Requirements.OfType()); } + [Fact] + public async Task CanReplaceDefaultPolicyDirectly() + { + // Arrange + var attributes = new AuthorizeAttribute[] { + new AuthorizeAttribute(), + new AuthorizeAttribute(), + }; + + var policies = new[] { new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build() }; + + var options = new AuthorizationOptions(); + + var provider = new DefaultAuthorizationPolicyProvider(Options.Create(options)); + + // Act + var combined = await AuthorizationPolicy.CombineAsync(provider, attributes, policies); + + // Assert + Assert.Equal(1, combined.Requirements.Count); + Assert.Empty(combined.Requirements.OfType()); + } + [Fact] public async Task CanReplaceDefaultPolicy() { diff --git a/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs b/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs index 8c21a7e23313..9f6aead906a3 100644 --- a/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs +++ b/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs @@ -1,12 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; -using System.Linq; using System.Net.WebSockets; -using System.Reflection; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; @@ -21,7 +16,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; -using Xunit; using Xunit.Abstractions; namespace Microsoft.AspNetCore.Http.Connections.Tests; @@ -35,6 +29,46 @@ public MapConnectionHandlerTests(ITestOutputHelper output) _output = output; } + [Fact] + public void MapConnectionHandlerFindsMetadataPolicyOnEndPoint() + { + var authCount = 0; + var policy1 = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build(); + var req = new TestRequirement(); + using (var host = BuildWebHost("/auth", + options => authCount += options.AuthorizationData.Count, + endpoints => endpoints.RequireAuthorization(policy1).RequireAuthorization(pb => pb.Requirements.Add(req)))) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/auth/negotiate", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + var policies = endpoint.Metadata.GetOrderedMetadata(); + Assert.Equal(2, policies.Count); + Assert.Equal(policy1, policies[0]); + Assert.Equal(1, policies[1].Requirements.Count); + Assert.Equal(req, policies[1].Requirements.First()); + }, + endpoint => + { + Assert.Equal("/auth", endpoint.DisplayName); + Assert.Single(endpoint.Metadata.GetOrderedMetadata()); + var policies = endpoint.Metadata.GetOrderedMetadata(); + Assert.Equal(2, policies.Count); + Assert.Equal(policy1, policies[0]); + Assert.Equal(1, policies[1].Requirements.Count); + Assert.Equal(req, policies[1].Requirements.First()); + }); + } + + Assert.Equal(0, authCount); + } + [Fact] public void MapConnectionHandlerFindsAuthAttributeOnEndPoint() { @@ -421,6 +455,10 @@ public override Task OnConnectedAsync(ConnectionContext connection) } } + private class TestRequirement : IAuthorizationRequirement + { + } + private IHost BuildWebHost(Action configure) { return new HostBuilder() @@ -442,7 +480,7 @@ private IHost BuildWebHost(Action configure) .Build(); } - private IHost BuildWebHost(string path, Action configureOptions) where TConnectionHandler : ConnectionHandler + private IHost BuildWebHost(string path, Action configureOptions, Action configureEndpoints = null) where TConnectionHandler : ConnectionHandler { return new HostBuilder() .ConfigureWebHost(webHostBuilder => @@ -459,7 +497,11 @@ private IHost BuildWebHost(string path, Action { - routes.MapConnectionHandler(path, configureOptions); + var builder = routes.MapConnectionHandler(path, configureOptions); + if (configureEndpoints != null) + { + configureEndpoints(builder); + } }); }) .ConfigureLogging(factory => diff --git a/src/SignalR/server/SignalR/test/MapSignalRTests.cs b/src/SignalR/server/SignalR/test/MapSignalRTests.cs index 4eaaedc43f90..7eeadd7fe437 100644 --- a/src/SignalR/server/SignalR/test/MapSignalRTests.cs +++ b/src/SignalR/server/SignalR/test/MapSignalRTests.cs @@ -99,6 +99,49 @@ public void MapHubFindsAuthAttributeOnHub() Assert.Equal(0, authCount); } + [Fact] + public void MapHubFindsMetadataPolicyOnHub() + { + var authCount = 0; + var policy1 = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build(); + var req = new TestRequirement(); + using (var host = BuildWebHost(routes => routes.MapHub("/path", options => + { + authCount += options.AuthorizationData.Count; + }) + .RequireAuthorization(policy1) + .RequireAuthorization(policy => policy.AddRequirements(req)))) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Collection(dataSource.Endpoints, + endpoint => + { + Assert.Equal("/path/negotiate", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + var policies = endpoint.Metadata.GetOrderedMetadata(); + Assert.Equal(2, policies.Count); + Assert.Equal(policy1, policies[0]); + Assert.Equal(1, policies[1].Requirements.Count); + Assert.Equal(req, policies[1].Requirements.First()); + }, + endpoint => + { + Assert.Equal("/path", endpoint.DisplayName); + Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata().Count); + var policies = endpoint.Metadata.GetOrderedMetadata(); + Assert.Equal(2, policies.Count); + Assert.Equal(policy1, policies[0]); + Assert.Equal(1, policies[1].Requirements.Count); + Assert.Equal(req, policies[1].Requirements.First()); + }); + } + + Assert.Equal(0, authCount); + } + [Fact] public void MapHubFindsAuthAttributeOnInheritedHub() { @@ -345,6 +388,10 @@ private class AuthHub : Hub { } + private class TestRequirement : IAuthorizationRequirement + { + } + private IHost BuildWebHost(Action configure) { return new HostBuilder()