diff --git a/src/Security/Authorization/Core/src/AuthorizationPolicy.cs b/src/Security/Authorization/Core/src/AuthorizationPolicy.cs index afb62f810d16..1942f7fca2eb 100644 --- a/src/Security/Authorization/Core/src/AuthorizationPolicy.cs +++ b/src/Security/Authorization/Core/src/AuthorizationPolicy.cs @@ -108,7 +108,27 @@ public static AuthorizationPolicy Combine(IEnumerable polic /// A new which represents the combination of the /// authorization policies provided by the specified . /// - public static async Task CombineAsync(IAuthorizationPolicyProvider policyProvider, IEnumerable authorizeData) + public static Task CombineAsync(IAuthorizationPolicyProvider policyProvider, + IEnumerable authorizeData) => CombineAsync(policyProvider, authorizeData, + Enumerable.Empty(), + Enumerable.Empty()); + + /// + /// Combines the provided by the specified + /// . + /// + /// A which provides the policies to combine. + /// A collection of authorization data used to apply authorization to a resource. + /// A collection of policies to combine. + /// A collection of s to add to the auth policy. + /// + /// A new which represents the combination of the + /// authorization policies provided by the specified . + /// + public static async Task CombineAsync(IAuthorizationPolicyProvider policyProvider, + IEnumerable authorizeData, + IEnumerable policies, + IEnumerable requirements) { if (policyProvider == null) { @@ -120,6 +140,9 @@ public static AuthorizationPolicy Combine(IEnumerable polic throw new ArgumentNullException(nameof(authorizeData)); } + var anyPolicies = policies.Any(); + var anyRequirements = requirements.Any(); + // Avoid allocating enumerator if the data is known to be empty var skipEnumeratingData = false; if (authorizeData is IList dataList) @@ -137,7 +160,7 @@ public static AuthorizationPolicy Combine(IEnumerable polic policyBuilder = new AuthorizationPolicyBuilder(); } - var useDefaultPolicy = true; + var useDefaultPolicy = !(anyPolicies || anyRequirements); if (!string.IsNullOrWhiteSpace(authorizeDatum.Policy)) { var policy = await policyProvider.GetPolicyAsync(authorizeDatum.Policy); @@ -176,6 +199,26 @@ public static AuthorizationPolicy Combine(IEnumerable polic } } + if (anyPolicies) + { + policyBuilder ??= new(); + + foreach (var policy in policies) + { + policyBuilder.Combine(policy); + } + } + + if (anyRequirements) + { + policyBuilder ??= new(); + + foreach (var requirement in requirements) + { + policyBuilder.Requirements.Add(requirement); + } + } + // If we have no policy by now, use the fallback policy if we have one if (policyBuilder == null) { diff --git a/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt b/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..6fe5d6ca17f8 100644 --- a/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt +++ b/src/Security/Authorization/Core/src/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ #nullable enable +static Microsoft.AspNetCore.Authorization.AuthorizationPolicy.CombineAsync(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, System.Collections.Generic.IEnumerable! authorizeData, System.Collections.Generic.IEnumerable! policies, System.Collections.Generic.IEnumerable! requirements) -> System.Threading.Tasks.Task! diff --git a/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs b/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs index 516f7f5d0357..ebd57a4570e2 100644 --- a/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs +++ b/src/Security/Authorization/Policy/src/AuthorizationMiddleware.cs @@ -57,7 +57,12 @@ public async Task Invoke(HttpContext context) // IMPORTANT: Changes to authorization logic should be mirrored in MVC's AuthorizeFilter var authorizeData = endpoint?.Metadata.GetOrderedMetadata() ?? Array.Empty(); - var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData); + + var policies = endpoint?.Metadata.GetOrderedMetadata() ?? Array.Empty(); + var reqirements = endpoint?.Metadata.GetOrderedMetadata() ?? Array.Empty(); + + var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData, policies, reqirements); + if (policy == null) { await _next(context); diff --git a/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs b/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs index d6548f532191..10ae54725eaf 100644 --- a/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs +++ b/src/Security/Authorization/test/AuthorizationMiddlewareTests.cs @@ -207,6 +207,64 @@ public async Task OnAuthorizationAsync_WillCallPolicyProvider() Assert.Equal(3, next.CalledCount); } + [Fact] + public async Task CanApplyPolicyDirectlyToEndpoint() + { + // Arrange + var calledPolicy = false; + var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => + { + calledPolicy = true; + return true; + }).Build(); + + var policyProvider = new Mock(); + policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build()); + var next = new TestRequestDelegate(); + var middleware = CreateMiddleware(next.Invoke, policyProvider.Object); + var context = GetHttpContext(anonymous: false, endpoint: CreateEndpoint(new AuthorizeAttribute(), policy)); + + // Act & Assert + await middleware.Invoke(context); + Assert.True(calledPolicy); + } + + [Fact] + public async Task CanApplyAdditonalRequirementsToEndpoint() + { + // Arrange + var policyProvider = new Mock(); + policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build()); + var next = new TestRequestDelegate(); + var middleware = CreateMiddleware(next.Invoke, policyProvider.Object); + var context = GetHttpContext(anonymous: false, registerServices: services => + { + services.AddSingleton(); + }, + endpoint: CreateEndpoint(new AuthorizeAttribute(), new CustomRequirement("This"))); + + // Act & Assert + await middleware.Invoke(context); + } + + class CustomAuthHandler : AuthorizationHandler + { + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, CustomRequirement requirement) + { + return Task.CompletedTask; + } + } + + class CustomRequirement : IAuthorizationRequirement + { + public string Data { get; init; } + + public CustomRequirement(string data) + { + Data = data; + } + } + [Fact] public async Task Invoke_ValidClaimShouldNotFail() { diff --git a/src/Security/Authorization/test/AuthorizationPolicyFacts.cs b/src/Security/Authorization/test/AuthorizationPolicyFacts.cs index 58af80706b0a..3fb54f893f4d 100644 --- a/src/Security/Authorization/test/AuthorizationPolicyFacts.cs +++ b/src/Security/Authorization/test/AuthorizationPolicyFacts.cs @@ -43,6 +43,29 @@ public async Task CanCombineAuthorizeAttributes() Assert.Single(combined.Requirements.OfType()); } + [Fact] + public async Task CanReplaceDefaultPolicyDirectly() + { + // Arrange + var attributes = new AuthorizeAttribute[] { + new AuthorizeAttribute(), + new AuthorizeAttribute(), + }; + + var policies = new[] { new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build() }; + + var options = new AuthorizationOptions(); + + var provider = new DefaultAuthorizationPolicyProvider(Options.Create(options)); + + // Act + var combined = await AuthorizationPolicy.CombineAsync(provider, attributes, policies, Enumerable.Empty()); + + // Assert + Assert.Equal(1, combined.Requirements.Count); + Assert.Empty(combined.Requirements.OfType()); + } + [Fact] public async Task CanReplaceDefaultPolicy() {