Skip to content

Allow specifying authz policies on endpoints #41153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 19, 2022
Merged
33 changes: 31 additions & 2 deletions src/Security/Authorization/Core/src/AuthorizationPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,24 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
/// A new <see cref="AuthorizationPolicy"/> which represents the combination of the
/// authorization policies provided by the specified <paramref name="policyProvider"/>.
/// </returns>
public static async Task<AuthorizationPolicy?> CombineAsync(IAuthorizationPolicyProvider policyProvider, IEnumerable<IAuthorizeData> authorizeData)
public static Task<AuthorizationPolicy?> CombineAsync(IAuthorizationPolicyProvider policyProvider,
IEnumerable<IAuthorizeData> authorizeData) => CombineAsync(policyProvider, authorizeData,
Enumerable.Empty<AuthorizationPolicy>());

/// <summary>
/// Combines the <see cref="AuthorizationPolicy"/> provided by the specified
/// <paramref name="policyProvider"/>.
/// </summary>
/// <param name="policyProvider">A <see cref="IAuthorizationPolicyProvider"/> which provides the policies to combine.</param>
/// <param name="authorizeData">A collection of authorization data used to apply authorization to a resource.</param>
/// <param name="policies">A collection of <see cref="AuthorizationPolicy"/> policies to combine.</param>
/// <returns>
/// A new <see cref="AuthorizationPolicy"/> which represents the combination of the
/// authorization policies provided by the specified <paramref name="policyProvider"/>.
/// </returns>
public static async Task<AuthorizationPolicy?> CombineAsync(IAuthorizationPolicyProvider policyProvider,
IEnumerable<IAuthorizeData> authorizeData,
IEnumerable<AuthorizationPolicy> policies)
{
if (policyProvider == null)
{
Expand All @@ -120,6 +137,8 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
throw new ArgumentNullException(nameof(authorizeData));
}

var anyPolicies = policies.Any();

// Avoid allocating enumerator if the data is known to be empty
var skipEnumeratingData = false;
if (authorizeData is IList<IAuthorizeData> dataList)
Expand All @@ -137,7 +156,7 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
policyBuilder = new AuthorizationPolicyBuilder();
}

var useDefaultPolicy = true;
var useDefaultPolicy = !(anyPolicies);
if (!string.IsNullOrWhiteSpace(authorizeDatum.Policy))
{
var policy = await policyProvider.GetPolicyAsync(authorizeDatum.Policy).ConfigureAwait(false);
Expand Down Expand Up @@ -176,6 +195,16 @@ public static AuthorizationPolicy Combine(IEnumerable<AuthorizationPolicy> polic
}
}

if (anyPolicies)
{
policyBuilder ??= new();

foreach (var policy in policies)
{
policyBuilder.Combine(policy);
}
}

// If we have no policy by now, use the fallback policy if we have one
if (policyBuilder == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*REMOVED*~Microsoft.AspNetCore.Authorization.DefaultAuthorizationService.DefaultAuthorizationService(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerProvider! handlers, Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.Authorization.DefaultAuthorizationService!>! logger, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerContextFactory! contextFactory, Microsoft.AspNetCore.Authorization.IAuthorizationEvaluator! evaluator, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Authorization.AuthorizationOptions!>! options) -> void
Microsoft.AspNetCore.Authorization.DefaultAuthorizationPolicyProvider.DefaultAuthorizationPolicyProvider(Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Authorization.AuthorizationOptions!>! options) -> void
Microsoft.AspNetCore.Authorization.DefaultAuthorizationService.DefaultAuthorizationService(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerProvider! handlers, Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.Authorization.DefaultAuthorizationService!>! logger, Microsoft.AspNetCore.Authorization.IAuthorizationHandlerContextFactory! contextFactory, Microsoft.AspNetCore.Authorization.IAuthorizationEvaluator! evaluator, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.Authorization.AuthorizationOptions!>! options) -> void
static Microsoft.AspNetCore.Authorization.AuthorizationPolicy.CombineAsync(Microsoft.AspNetCore.Authorization.IAuthorizationPolicyProvider! policyProvider, System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.Authorization.IAuthorizeData!>! authorizeData, System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.Authorization.AuthorizationPolicy!>! policies) -> System.Threading.Tasks.Task<Microsoft.AspNetCore.Authorization.AuthorizationPolicy?>!
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,55 @@ public static TBuilder RequireAuthorization<TBuilder>(this TBuilder builder, par
return builder;
}

/// <summary>
/// Adds an authorization policy to the endpoint(s).
/// </summary>
/// <param name="builder">The endpoint convention builder.</param>
/// <param name="policy">The <see cref="AuthorizationPolicy"/> policy.</param>
/// <returns>The original convention builder parameter.</returns>
public static TBuilder RequireAuthorization<TBuilder>(this TBuilder builder, AuthorizationPolicy policy)
where TBuilder : IEndpointConventionBuilder
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}

if (policy == null)
{
throw new ArgumentNullException(nameof(policy));
}

RequirePolicyCore(builder, policy);
return builder;
}

/// <summary>
/// Adds an new authorization policy configured by a callback to the endpoint(s).
/// </summary>
/// <typeparam name="TBuilder"></typeparam>
/// <param name="builder">The endpoint convention builder.</param>
/// <param name="configurePolicy">The callback used to configure the policy.</param>
/// <returns>The original convention builder parameter.</returns>
public static TBuilder RequireAuthorization<TBuilder>(this TBuilder builder, Action<AuthorizationPolicyBuilder> configurePolicy)
where TBuilder : IEndpointConventionBuilder
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}

if (configurePolicy == null)
{
throw new ArgumentNullException(nameof(configurePolicy));
}

var policyBuilder = new AuthorizationPolicyBuilder();
configurePolicy(policyBuilder);
RequirePolicyCore(builder, policyBuilder.Build());
return builder;
}

/// <summary>
/// Allows anonymous access to the endpoint by adding <see cref="AllowAnonymousAttribute" /> to the endpoint metadata. This will bypass
/// all authorization checks for the endpoint including the default authorization policy and fallback authorization policy.
Expand All @@ -94,6 +143,20 @@ public static TBuilder AllowAnonymous<TBuilder>(this TBuilder builder) where TBu
return builder;
}

private static void RequirePolicyCore<TBuilder>(TBuilder builder, AuthorizationPolicy policy)
where TBuilder : IEndpointConventionBuilder
{
builder.Add(endpointBuilder =>
{
// Only add an authorize attribute if there isn't one
if (!endpointBuilder.Metadata.Any(meta => meta is IAuthorizeData))
{
endpointBuilder.Metadata.Add(new AuthorizeAttribute());
}
endpointBuilder.Metadata.Add(policy);
});
}

private static void RequireAuthorizationCore<TBuilder>(TBuilder builder, IEnumerable<IAuthorizeData> authorizeData)
where TBuilder : IEndpointConventionBuilder
{
Expand All @@ -105,4 +168,5 @@ private static void RequireAuthorizationCore<TBuilder>(TBuilder builder, IEnumer
}
});
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ public async Task Invoke(HttpContext context)

// IMPORTANT: Changes to authorization logic should be mirrored in MVC's AuthorizeFilter
var authorizeData = endpoint?.Metadata.GetOrderedMetadata<IAuthorizeData>() ?? Array.Empty<IAuthorizeData>();
var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData);

var policies = endpoint?.Metadata.GetOrderedMetadata<AuthorizationPolicy>() ?? Array.Empty<AuthorizationPolicy>();

var policy = await AuthorizationPolicy.CombineAsync(_policyProvider, authorizeData, policies);

if (policy == null)
{
await _next(context);
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
static Microsoft.AspNetCore.Builder.AuthorizationEndpointConventionBuilderExtensions.RequireAuthorization<TBuilder>(this TBuilder builder, Microsoft.AspNetCore.Authorization.AuthorizationPolicy! policy) -> TBuilder
static Microsoft.AspNetCore.Builder.AuthorizationEndpointConventionBuilderExtensions.RequireAuthorization<TBuilder>(this TBuilder builder, System.Action<Microsoft.AspNetCore.Authorization.AuthorizationPolicyBuilder!>! configurePolicy) -> TBuilder
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,106 @@ public void RequireAuthorization_ChainedCall()
Assert.True(chainedBuilder.TestProperty);
}

[Fact]
public void RequireAuthorization_Policy()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build();

// Act
builder.RequireAuthorization(policy);

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
convention(endpointModel);

Assert.Equal(2, endpointModel.Metadata.Count);
var authMetadata = Assert.IsAssignableFrom<IAuthorizeData>(endpointModel.Metadata[0]);
Assert.Null(authMetadata.Policy);

Assert.Equal(policy, endpointModel.Metadata[1]);
}

[Fact]
public void RequireAuthorization_PolicyCallback()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var requirement = new TestRequirement();

// Act
builder.RequireAuthorization(policyBuilder => policyBuilder.Requirements.Add(requirement));

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
convention(endpointModel);

Assert.Equal(2, endpointModel.Metadata.Count);
var authMetadata = Assert.IsAssignableFrom<IAuthorizeData>(endpointModel.Metadata[0]);
Assert.Null(authMetadata.Policy);

var policy = Assert.IsAssignableFrom<AuthorizationPolicy>(endpointModel.Metadata[1]);
Assert.Equal(1, policy.Requirements.Count);
Assert.Equal(requirement, policy.Requirements[0]);
}

[Fact]
public void RequireAuthorization_PolicyCallbackWithAuthorize()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var authorize = new AuthorizeAttribute();
var requirement = new TestRequirement();

// Act
builder.RequireAuthorization(policyBuilder => policyBuilder.Requirements.Add(requirement));

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
endpointModel.Metadata.Add(authorize);
convention(endpointModel);

// Confirm that we don't add another authorize if one already exists
Assert.Equal(2, endpointModel.Metadata.Count);
Assert.Equal(authorize, endpointModel.Metadata[0]);
var policy = Assert.IsAssignableFrom<AuthorizationPolicy>(endpointModel.Metadata[1]);
Assert.Equal(1, policy.Requirements.Count);
Assert.Equal(requirement, policy.Requirements[0]);
}

[Fact]
public void RequireAuthorization_PolicyWithAuthorize()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build();
var authorize = new AuthorizeAttribute();

// Act
builder.RequireAuthorization(policy);

// Assert
var convention = Assert.Single(builder.Conventions);

var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
endpointModel.Metadata.Add(authorize);
convention(endpointModel);

// Confirm that we don't add another authorize if one already exists
Assert.Equal(2, endpointModel.Metadata.Count);
Assert.Equal(authorize, endpointModel.Metadata[0]);
Assert.Equal(policy, endpointModel.Metadata[1]);
}

class TestRequirement : IAuthorizationRequirement { }

[Fact]
public void AllowAnonymous_Default()
{
Expand Down
22 changes: 22 additions & 0 deletions src/Security/Authorization/test/AuthorizationMiddlewareTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,28 @@ public async Task OnAuthorizationAsync_WillCallPolicyProvider()
Assert.Equal(3, next.CalledCount);
}

[Fact]
public async Task CanApplyPolicyDirectlyToEndpoint()
{
// Arrange
var calledPolicy = false;
var policy = new AuthorizationPolicyBuilder().RequireAssertion(_ =>
{
calledPolicy = true;
return true;
}).Build();

var policyProvider = new Mock<IAuthorizationPolicyProvider>();
policyProvider.Setup(p => p.GetDefaultPolicyAsync()).ReturnsAsync(new AuthorizationPolicyBuilder().RequireAuthenticatedUser().Build());
var next = new TestRequestDelegate();
var middleware = CreateMiddleware(next.Invoke, policyProvider.Object);
var context = GetHttpContext(anonymous: false, endpoint: CreateEndpoint(new AuthorizeAttribute(), policy));

// Act & Assert
await middleware.Invoke(context);
Assert.True(calledPolicy);
}

[Fact]
public async Task Invoke_ValidClaimShouldNotFail()
{
Expand Down
23 changes: 23 additions & 0 deletions src/Security/Authorization/test/AuthorizationPolicyFacts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,29 @@ public async Task CanCombineAuthorizeAttributes()
Assert.Single(combined.Requirements.OfType<RolesAuthorizationRequirement>());
}

[Fact]
public async Task CanReplaceDefaultPolicyDirectly()
{
// Arrange
var attributes = new AuthorizeAttribute[] {
new AuthorizeAttribute(),
new AuthorizeAttribute(),
};

var policies = new[] { new AuthorizationPolicyBuilder().RequireAssertion(_ => true).Build() };

var options = new AuthorizationOptions();

var provider = new DefaultAuthorizationPolicyProvider(Options.Create(options));

// Act
var combined = await AuthorizationPolicy.CombineAsync(provider, attributes, policies);

// Assert
Assert.Equal(1, combined.Requirements.Count);
Assert.Empty(combined.Requirements.OfType<DenyAnonymousAuthorizationRequirement>());
}

[Fact]
public async Task CanReplaceDefaultPolicy()
{
Expand Down
Loading