diff --git a/src/Http/Http.Abstractions/src/Extensions/EndpointBuilder.cs b/src/Http/Http.Abstractions/src/Extensions/EndpointBuilder.cs index 005ecc6ea9a8..3afd75bb4091 100644 --- a/src/Http/Http.Abstractions/src/Extensions/EndpointBuilder.cs +++ b/src/Http/Http.Abstractions/src/Extensions/EndpointBuilder.cs @@ -10,6 +10,11 @@ namespace Microsoft.AspNetCore.Builder; /// public abstract class EndpointBuilder { + /// + /// Gets the list of filters that apply to this endpoint. + /// + public IList> FilterFactories { get; } = new List>(); + /// /// Gets or sets the delegate used to process requests for the endpoint. /// diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index 43a78d85ac62..cadc3cab2da6 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -6,6 +6,7 @@ abstract Microsoft.AspNetCore.Http.EndpointFilterInvocationContext.GetArgument Microsoft.AspNetCore.Http.HttpContext! Microsoft.AspNetCore.Builder.EndpointBuilder.ApplicationServices.get -> System.IServiceProvider! Microsoft.AspNetCore.Builder.EndpointBuilder.ApplicationServices.set -> void +Microsoft.AspNetCore.Builder.EndpointBuilder.FilterFactories.get -> System.Collections.Generic.IList!>! Microsoft.AspNetCore.Http.AsParametersAttribute Microsoft.AspNetCore.Http.AsParametersAttribute.AsParametersAttribute() -> void Microsoft.AspNetCore.Http.CookieBuilder.Extensions.get -> System.Collections.Generic.IList! diff --git a/src/Http/Routing/src/Builder/EndpointFilterExtensions.cs b/src/Http/Routing/src/Builder/EndpointFilterExtensions.cs index 6205a129a329..91ed5e9c7f4c 100644 --- a/src/Http/Routing/src/Builder/EndpointFilterExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointFilterExtensions.cs @@ -112,13 +112,7 @@ public static TBuilder AddEndpointFilterFactory(this TBuilder builder, { builder.Add(endpointBuilder => { - if (endpointBuilder is not RouteEndpointBuilder routeEndpointBuilder) - { - return; - } - - routeEndpointBuilder.EndpointFilterFactories ??= new(); - routeEndpointBuilder.EndpointFilterFactories.Add(filterFactory); + endpointBuilder.FilterFactories.Add(filterFactory); }); return builder; diff --git a/src/Http/Routing/src/Patterns/RoutePatternFactory.cs b/src/Http/Routing/src/Patterns/RoutePatternFactory.cs index 441545f5eafd..edfd3a004fc3 100644 --- a/src/Http/Routing/src/Patterns/RoutePatternFactory.cs +++ b/src/Http/Routing/src/Patterns/RoutePatternFactory.cs @@ -1084,7 +1084,15 @@ public static RoutePatternParameterPolicyReference ParameterPolicy(string parame return ParameterPolicyCore(parameterPolicy); } - internal static RoutePattern Combine(RoutePattern? left, RoutePattern right) + /// + /// Creates a that combines the specified patterns. + /// + /// A string representing the first part of the route. + /// A stirng representing the second part of the route. + /// The combined . + /// + /// + public static RoutePattern Combine(RoutePattern? left, RoutePattern right) { static IReadOnlyDictionary CombineDictionaries( IReadOnlyDictionary leftDictionary, diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 3b87380bf45b..621294b4f017 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -33,6 +33,7 @@ static Microsoft.AspNetCore.Routing.LinkGeneratorRouteValuesAddressExtensions.Ge static Microsoft.AspNetCore.Routing.LinkGeneratorRouteValuesAddressExtensions.GetPathByRouteValues(this Microsoft.AspNetCore.Routing.LinkGenerator! generator, string? routeName, Microsoft.AspNetCore.Routing.RouteValueDictionary? values = null, Microsoft.AspNetCore.Http.PathString pathBase = default(Microsoft.AspNetCore.Http.PathString), Microsoft.AspNetCore.Http.FragmentString fragment = default(Microsoft.AspNetCore.Http.FragmentString), Microsoft.AspNetCore.Routing.LinkOptions? options = null) -> string? static Microsoft.AspNetCore.Routing.LinkGeneratorRouteValuesAddressExtensions.GetUriByRouteValues(this Microsoft.AspNetCore.Routing.LinkGenerator! generator, Microsoft.AspNetCore.Http.HttpContext! httpContext, string? routeName, Microsoft.AspNetCore.Routing.RouteValueDictionary? values = null, string? scheme = null, Microsoft.AspNetCore.Http.HostString? host = null, Microsoft.AspNetCore.Http.PathString? pathBase = null, Microsoft.AspNetCore.Http.FragmentString fragment = default(Microsoft.AspNetCore.Http.FragmentString), Microsoft.AspNetCore.Routing.LinkOptions? options = null) -> string? static Microsoft.AspNetCore.Routing.LinkGeneratorRouteValuesAddressExtensions.GetUriByRouteValues(this Microsoft.AspNetCore.Routing.LinkGenerator! generator, string? routeName, Microsoft.AspNetCore.Routing.RouteValueDictionary! values, string! scheme, Microsoft.AspNetCore.Http.HostString host, Microsoft.AspNetCore.Http.PathString pathBase = default(Microsoft.AspNetCore.Http.PathString), Microsoft.AspNetCore.Http.FragmentString fragment = default(Microsoft.AspNetCore.Http.FragmentString), Microsoft.AspNetCore.Routing.LinkOptions? options = null) -> string? +static Microsoft.AspNetCore.Routing.Patterns.RoutePatternFactory.Combine(Microsoft.AspNetCore.Routing.Patterns.RoutePattern? left, Microsoft.AspNetCore.Routing.Patterns.RoutePattern! right) -> Microsoft.AspNetCore.Routing.Patterns.RoutePattern! static Microsoft.AspNetCore.Routing.Patterns.RoutePatternFactory.Parse(string! pattern, Microsoft.AspNetCore.Routing.RouteValueDictionary? defaults, Microsoft.AspNetCore.Routing.RouteValueDictionary? parameterPolicies) -> Microsoft.AspNetCore.Routing.Patterns.RoutePattern! static Microsoft.AspNetCore.Routing.Patterns.RoutePatternFactory.Parse(string! pattern, Microsoft.AspNetCore.Routing.RouteValueDictionary? defaults, Microsoft.AspNetCore.Routing.RouteValueDictionary? parameterPolicies, Microsoft.AspNetCore.Routing.RouteValueDictionary? requiredValues) -> Microsoft.AspNetCore.Routing.Patterns.RoutePattern! static Microsoft.AspNetCore.Routing.Patterns.RoutePatternFactory.Pattern(Microsoft.AspNetCore.Routing.RouteValueDictionary? defaults, Microsoft.AspNetCore.Routing.RouteValueDictionary? parameterPolicies, System.Collections.Generic.IEnumerable! segments) -> Microsoft.AspNetCore.Routing.Patterns.RoutePattern! diff --git a/src/Http/Routing/src/RouteEndpointBuilder.cs b/src/Http/Routing/src/RouteEndpointBuilder.cs index 2660c24bb261..3b96d85ff466 100644 --- a/src/Http/Routing/src/RouteEndpointBuilder.cs +++ b/src/Http/Routing/src/RouteEndpointBuilder.cs @@ -12,10 +12,6 @@ namespace Microsoft.AspNetCore.Routing; /// public sealed class RouteEndpointBuilder : EndpointBuilder { - // TODO: Make this public as a gettable IReadOnlyList>. - // AddEndpointFilter will still be the only way to mutate this list. - internal List>? EndpointFilterFactories { get; set; } - /// /// Gets or sets the associated with this endpoint. /// diff --git a/src/Http/Routing/src/RouteEndpointDataSource.cs b/src/Http/Routing/src/RouteEndpointDataSource.cs index 476b032fbf2d..64c288819716 100644 --- a/src/Http/Routing/src/RouteEndpointDataSource.cs +++ b/src/Http/Routing/src/RouteEndpointDataSource.cs @@ -56,7 +56,7 @@ public RouteHandlerBuilder AddRouteHandler( if (isFallback) { routeAttributes |= RouteAttributes.Fallback; - } + } _routeEntries.Add(new() { @@ -196,7 +196,7 @@ private RouteEndpointBuilder CreateRouteEndpointBuilder( entrySpecificConvention(builder); } - if (isRouteHandler || builder.EndpointFilterFactories is { Count: > 0}) + if (isRouteHandler || builder.FilterFactories.Count > 0) { var routeParamNames = new List(pattern.Parameters.Count); foreach (var parameter in pattern.Parameters) @@ -211,7 +211,7 @@ private RouteEndpointBuilder CreateRouteEndpointBuilder( ThrowOnBadRequest = _throwOnBadRequest, DisableInferBodyFromParameters = ShouldDisableInferredBodyParameters(entry.HttpMethods), EndpointMetadata = builder.Metadata, - EndpointFilterFactories = builder.EndpointFilterFactories, + EndpointFilterFactories = builder.FilterFactories.AsReadOnly(), }; // We ignore the returned EndpointMetadata has been already populated since we passed in non-null EndpointMetadata. diff --git a/src/Http/Routing/test/UnitTests/RouteEndpointBuilderTest.cs b/src/Http/Routing/test/UnitTests/RouteEndpointBuilderTest.cs index 1bfc5ecc724c..508871016482 100644 --- a/src/Http/Routing/test/UnitTests/RouteEndpointBuilderTest.cs +++ b/src/Http/Routing/test/UnitTests/RouteEndpointBuilderTest.cs @@ -45,8 +45,7 @@ public async void Build_DoesNot_RunFilters() var builder = new RouteEndpointBuilder(requestDelegate, RoutePatternFactory.Parse("/"), defaultOrder); - builder.EndpointFilterFactories = new List>(); - builder.EndpointFilterFactories.Add((endopintContext, next) => + builder.FilterFactories.Add((endopintContext, next) => { endpointFilterCallCount++; diff --git a/src/Mvc/Mvc.Core/src/Controllers/ControllerActionDescriptor.cs b/src/Mvc/Mvc.Core/src/Controllers/ControllerActionDescriptor.cs index ba84f8389eda..d8a9ee154aa3 100644 --- a/src/Mvc/Mvc.Core/src/Controllers/ControllerActionDescriptor.cs +++ b/src/Mvc/Mvc.Core/src/Controllers/ControllerActionDescriptor.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Globalization; using System.Reflection; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.Extensions.Internal; @@ -36,6 +37,8 @@ public class ControllerActionDescriptor : ActionDescriptor /// public TypeInfo ControllerTypeInfo { get; set; } = default!; + internal EndpointFilterDelegate? FilterDelegate { get; set; } + // Cache entry so we can avoid an external cache internal ControllerActionInvokerCacheEntry? CacheEntry { get; set; } diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ActionMethodExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ActionMethodExecutor.cs index 037d045a986f..57c63674fa39 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ActionMethodExecutor.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ActionMethodExecutor.cs @@ -4,8 +4,9 @@ #nullable enable using System.Diagnostics; -using Microsoft.AspNetCore.Mvc.Core; +using Microsoft.AspNetCore.Mvc.Controllers; using Microsoft.Extensions.Internal; +using Resources = Microsoft.AspNetCore.Mvc.Core.Resources; namespace Microsoft.AspNetCore.Mvc.Infrastructure; @@ -26,7 +27,10 @@ internal abstract class ActionMethodExecutor new AwaitableObjectResultExecutor(), }; + public static EmptyResult EmptyResultInstance { get; } = new(); + public abstract ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, @@ -34,6 +38,8 @@ public abstract ValueTask Execute( protected abstract bool CanExecute(ObjectMethodExecutor executor); + public abstract ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext); + public static ActionMethodExecutor GetExecutor(ObjectMethodExecutor executor) { for (var i = 0; i < Executors.Length; i++) @@ -48,17 +54,65 @@ public static ActionMethodExecutor GetExecutor(ObjectMethodExecutor executor) throw new Exception(); } + public static ActionMethodExecutor GetFilterExecutor(ControllerActionDescriptor actionDescriptor) => + new FilterActionMethodExecutor(actionDescriptor); + + private sealed class FilterActionMethodExecutor : ActionMethodExecutor + { + private readonly ControllerActionDescriptor _controllerActionDescriptor; + + public FilterActionMethodExecutor(ControllerActionDescriptor controllerActionDescriptor) + { + _controllerActionDescriptor = controllerActionDescriptor; + } + + public override async ValueTask Execute( + ActionContext actionContext, + IActionResultTypeMapper mapper, + ObjectMethodExecutor executor, + object controller, + object?[]? arguments) + { + var context = new ControllerEndpointFilterInvocationContext(_controllerActionDescriptor, actionContext, executor, mapper, controller, arguments); + var result = await _controllerActionDescriptor.FilterDelegate!(context); + return ConvertToActionResult(mapper, result, executor.IsMethodAsync ? executor.AsyncResultType! : executor.MethodReturnType); + } + + public override ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + // This is never called + throw new NotSupportedException(); + } + + protected override bool CanExecute(ObjectMethodExecutor executor) + { + // This is never called + throw new NotSupportedException(); + } + } + // void LogMessage(..) private sealed class VoidResultExecutor : ActionMethodExecutor { public override ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, object?[]? arguments) { executor.Execute(controller, arguments); - return new ValueTask(new EmptyResult()); + return new(EmptyResultInstance); + } + + public override ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + + executor.Execute(controller, arguments); + return new(EmptyResultInstance); } protected override bool CanExecute(ObjectMethodExecutor executor) @@ -70,6 +124,7 @@ protected override bool CanExecute(ObjectMethodExecutor executor) private sealed class SyncActionResultExecutor : ActionMethodExecutor { public override ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, @@ -78,7 +133,19 @@ public override ValueTask Execute( var actionResult = (IActionResult)executor.Execute(controller, arguments)!; EnsureActionResultNotNull(executor, actionResult); - return new ValueTask(actionResult); + return new(actionResult); + } + + public override ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + + var actionResult = (IActionResult)executor.Execute(controller, arguments)!; + EnsureActionResultNotNull(executor, actionResult); + + return new(actionResult); } protected override bool CanExecute(ObjectMethodExecutor executor) @@ -90,6 +157,7 @@ protected override bool CanExecute(ObjectMethodExecutor executor) private sealed class SyncObjectResultExecutor : ActionMethodExecutor { public override ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, @@ -98,7 +166,20 @@ public override ValueTask Execute( // Sync method returning arbitrary object var returnValue = executor.Execute(controller, arguments); var actionResult = ConvertToActionResult(mapper, returnValue, executor.MethodReturnType); - return new ValueTask(actionResult); + return new(actionResult); + } + + public override ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + var mapper = invocationContext.Mapper; + + // Sync method returning arbitrary object + var returnValue = executor.Execute(controller, arguments); + var actionResult = ConvertToActionResult(mapper, returnValue, executor.MethodReturnType); + return new(actionResult); } // Catch-all for sync methods @@ -109,13 +190,24 @@ public override ValueTask Execute( private sealed class TaskResultExecutor : ActionMethodExecutor { public override async ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, object?[]? arguments) { await (Task)executor.Execute(controller, arguments)!; - return new EmptyResult(); + return EmptyResultInstance; + } + + public override async ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + + await (Task)executor.Execute(controller, arguments)!; + return EmptyResultInstance; } protected override bool CanExecute(ObjectMethodExecutor executor) => executor.MethodReturnType == typeof(Task); @@ -126,13 +218,24 @@ public override async ValueTask Execute( private sealed class AwaitableResultExecutor : ActionMethodExecutor { public override async ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, object?[]? arguments) { await executor.ExecuteAsync(controller, arguments); - return new EmptyResult(); + return EmptyResultInstance; + } + + public override async ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + + await executor.ExecuteAsync(controller, arguments); + return EmptyResultInstance; } protected override bool CanExecute(ObjectMethodExecutor executor) @@ -146,6 +249,7 @@ protected override bool CanExecute(ObjectMethodExecutor executor) private sealed class TaskOfIActionResultExecutor : ActionMethodExecutor { public override async ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, @@ -160,6 +264,21 @@ public override async ValueTask Execute( return actionResult; } + public override async ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + + // Async method returning Task + // Avoid extra allocations by calling Execute rather than ExecuteAsync and casting to Task. + var returnValue = executor.Execute(controller, arguments); + var actionResult = await (Task)returnValue!; + EnsureActionResultNotNull(executor, actionResult); + + return actionResult; + } + protected override bool CanExecute(ObjectMethodExecutor executor) => typeof(Task).IsAssignableFrom(executor.MethodReturnType); } @@ -169,6 +288,7 @@ protected override bool CanExecute(ObjectMethodExecutor executor) private sealed class TaskOfActionResultExecutor : ActionMethodExecutor { public override async ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, @@ -181,6 +301,19 @@ public override async ValueTask Execute( return actionResult; } + public override async ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + + // Async method returning awaitable-of-IActionResult (e.g., Task) + // We have to use ExecuteAsync because we don't know the awaitable's type at compile time. + var actionResult = (IActionResult)await executor.ExecuteAsync(controller, arguments); + EnsureActionResultNotNull(executor, actionResult); + return actionResult; + } + protected override bool CanExecute(ObjectMethodExecutor executor) { // Async method returning awaitable-of - IActionResult(e.g., Task) @@ -193,6 +326,7 @@ protected override bool CanExecute(ObjectMethodExecutor executor) private sealed class AwaitableObjectResultExecutor : ActionMethodExecutor { public override async ValueTask Execute( + ActionContext actionContext, IActionResultTypeMapper mapper, ObjectMethodExecutor executor, object controller, @@ -204,6 +338,18 @@ public override async ValueTask Execute( return actionResult; } + public override async ValueTask Execute(ControllerEndpointFilterInvocationContext invocationContext) + { + var executor = invocationContext.Executor; + var controller = invocationContext.Controller; + var arguments = (object[])invocationContext.Arguments; + var mapper = invocationContext.Mapper; + + var returnValue = await executor.ExecuteAsync(controller, arguments); + var actionResult = ConvertToActionResult(mapper, returnValue, executor.AsyncResultType!); + return actionResult; + } + protected override bool CanExecute(ObjectMethodExecutor executor) => true; } diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs index 0f1be9b5c6d6..85926ea101e7 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs @@ -390,7 +390,7 @@ private Task InvokeActionMethodAsync() var actionMethodExecutor = _cacheEntry.ActionMethodExecutor; var orderedArguments = PrepareArguments(_arguments, objectMethodExecutor); - var actionResultValueTask = actionMethodExecutor.Execute(_mapper, objectMethodExecutor, _instance!, orderedArguments); + var actionResultValueTask = actionMethodExecutor.Execute(ControllerContext, _mapper, objectMethodExecutor, _instance!, orderedArguments); if (actionResultValueTask.IsCompletedSuccessfully) { _result = actionResultValueTask.Result; @@ -428,7 +428,7 @@ static async Task Logged(ControllerActionInvoker invoker) controller); Log.ActionMethodExecuting(logger, controllerContext, orderedArguments); var stopwatch = ValueStopwatch.StartNew(); - var actionResultValueTask = actionMethodExecutor.Execute(invoker._mapper, objectMethodExecutor, controller!, orderedArguments); + var actionResultValueTask = actionMethodExecutor.Execute(controllerContext, invoker._mapper, objectMethodExecutor, controller!, orderedArguments); if (actionResultValueTask.IsCompletedSuccessfully) { result = actionResultValueTask.Result; diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCache.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCache.cs index d9668df8b6bd..da2830c746f8 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCache.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCache.cs @@ -69,6 +69,9 @@ public ControllerActionInvokerCache( _mvcOptions); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterExecutor = actionDescriptor.FilterDelegate is not null + ? ActionMethodExecutor.GetFilterExecutor(actionDescriptor) + : null; cacheEntry = new ControllerActionInvokerCacheEntry( filterFactoryResult.CacheableFilters, @@ -76,6 +79,7 @@ public ControllerActionInvokerCache( controllerReleaser, propertyBinderFactory, objectMethodExecutor, + filterExecutor ?? actionMethodExecutor, actionMethodExecutor); actionDescriptor.CacheEntry = cacheEntry; diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCacheEntry.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCacheEntry.cs index 802d498cd9e5..43490a965def 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCacheEntry.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvokerCacheEntry.cs @@ -15,7 +15,8 @@ internal ControllerActionInvokerCacheEntry( Func? controllerReleaser, ControllerBinderDelegate? controllerBinderDelegate, ObjectMethodExecutor objectMethodExecutor, - ActionMethodExecutor actionMethodExecutor) + ActionMethodExecutor actionMethodExecutor, + ActionMethodExecutor innerActionMethodExecutor) { ControllerFactory = controllerFactory; ControllerReleaser = controllerReleaser; @@ -23,6 +24,7 @@ internal ControllerActionInvokerCacheEntry( CachedFilters = cachedFilters; ObjectMethodExecutor = objectMethodExecutor; ActionMethodExecutor = actionMethodExecutor; + InnerActionMethodExecutor = innerActionMethodExecutor; } public FilterItem[] CachedFilters { get; } @@ -35,5 +37,9 @@ internal ControllerActionInvokerCacheEntry( internal ObjectMethodExecutor ObjectMethodExecutor { get; } + // This includes the execution of the filter delegate (if there's a filter) internal ActionMethodExecutor ActionMethodExecutor { get; } + + // This is called inside of the filter delegate + internal ActionMethodExecutor InnerActionMethodExecutor { get; } } diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerEndpointFilterInvocationContext.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerEndpointFilterInvocationContext.cs new file mode 100644 index 000000000000..c93c24ef44ae --- /dev/null +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerEndpointFilterInvocationContext.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#nullable enable + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.Extensions.Internal; + +namespace Microsoft.AspNetCore.Mvc.Infrastructure; + +internal class ControllerEndpointFilterInvocationContext : EndpointFilterInvocationContext +{ + public ControllerEndpointFilterInvocationContext( + ControllerActionDescriptor actionDescriptor, + ActionContext actionContext, + ObjectMethodExecutor executor, + IActionResultTypeMapper mapper, + object controller, + object?[]? arguments) + { + ActionDescriptor = actionDescriptor; + ActionContext = actionContext; + Mapper = mapper; + Executor = executor; + Controller = controller; + Arguments = arguments ?? Array.Empty(); + } + + public object Controller { get; } + + internal IActionResultTypeMapper Mapper { get; } + + internal ActionContext ActionContext { get; } + + internal ObjectMethodExecutor Executor { get; } + + internal ControllerActionDescriptor ActionDescriptor { get; } + + public override HttpContext HttpContext => ActionContext.HttpContext; + + public override IList Arguments { get; } + + public override T GetArgument(int index) + { + return (T)Arguments[index]!; + } +} diff --git a/src/Mvc/Mvc.Core/src/Routing/ActionEndpointDataSourceBase.cs b/src/Mvc/Mvc.Core/src/Routing/ActionEndpointDataSourceBase.cs index 9d04bc583a34..18f7077342c3 100644 --- a/src/Mvc/Mvc.Core/src/Routing/ActionEndpointDataSourceBase.cs +++ b/src/Mvc/Mvc.Core/src/Routing/ActionEndpointDataSourceBase.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Mvc.Routing; @@ -46,8 +47,21 @@ public override IReadOnlyList Endpoints } } + public override IReadOnlyList GetGroupedEndpoints(RouteGroupContext context) + { + return CreateEndpoints( + context.Prefix, + _actions.ActionDescriptors.Items, + Conventions, + context.Conventions); + } + // Will be called with the lock. - protected abstract List CreateEndpoints(IReadOnlyList actions, IReadOnlyList> conventions); + protected abstract List CreateEndpoints( + RoutePattern? groupPrefix, + IReadOnlyList actions, + IReadOnlyList> conventions, + IReadOnlyList> groupConventions); protected void Subscribe() { @@ -97,7 +111,7 @@ private void UpdateEndpoints() { lock (Lock) { - var endpoints = CreateEndpoints(_actions.ActionDescriptors.Items, Conventions); + var endpoints = CreateEndpoints(groupPrefix: null, _actions.ActionDescriptors.Items, Conventions, Array.Empty>()); // See comments in DefaultActionDescriptorCollectionProvider. These steps are done // in a specific order to ensure callers always see a consistent state. diff --git a/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs b/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs index 2de07901be75..50d2650f5d1d 100644 --- a/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs +++ b/src/Mvc/Mvc.Core/src/Routing/ActionEndpointFactory.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Http.Metadata; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.ActionConstraints; +using Microsoft.AspNetCore.Mvc.Controllers; using Microsoft.AspNetCore.Mvc.Filters; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Routing; @@ -20,8 +21,11 @@ internal sealed class ActionEndpointFactory private readonly RoutePatternTransformer _routePatternTransformer; private readonly RequestDelegate _requestDelegate; private readonly IRequestDelegateFactory[] _requestDelegateFactories; + private readonly IServiceProvider _serviceProvider; - public ActionEndpointFactory(RoutePatternTransformer routePatternTransformer, IEnumerable requestDelegateFactories) + public ActionEndpointFactory(RoutePatternTransformer routePatternTransformer, + IEnumerable requestDelegateFactories, + IServiceProvider serviceProvider) { if (routePatternTransformer == null) { @@ -31,6 +35,7 @@ public ActionEndpointFactory(RoutePatternTransformer routePatternTransformer, IE _routePatternTransformer = routePatternTransformer; _requestDelegate = CreateRequestDelegate(); _requestDelegateFactories = requestDelegateFactories.ToArray(); + _serviceProvider = serviceProvider; } public void AddEndpoints( @@ -38,8 +43,10 @@ public void AddEndpoints( HashSet routeNames, ActionDescriptor action, IReadOnlyList routes, + IReadOnlyList> groupConventions, IReadOnlyList> conventions, - bool createInertEndpoints) + bool createInertEndpoints, + RoutePattern? groupPrefix = null) { if (endpoints == null) { @@ -81,8 +88,9 @@ public void AddEndpoints( dataTokens: null, suppressLinkGeneration: false, suppressPathMatching: false, - conventions, - Array.Empty>()); + groupConventions: groupConventions, + conventions: conventions, + perRouteConventions: Array.Empty>()); endpoints.Add(builder.Build()); } @@ -102,6 +110,8 @@ public void AddEndpoints( continue; } + updatedRoutePattern = RoutePatternFactory.Combine(groupPrefix, updatedRoutePattern); + var requestDelegate = CreateRequestDelegate(action, route.DataTokens) ?? _requestDelegate; // We suppress link generation for each conventionally routed endpoint. We generate a single endpoint per-route @@ -118,8 +128,9 @@ public void AddEndpoints( route.DataTokens, suppressLinkGeneration: true, suppressPathMatching: false, - conventions, - route.Conventions); + groupConventions: groupConventions, + conventions: conventions, + perRouteConventions: route.Conventions); endpoints.Add(builder.Build()); } } @@ -145,6 +156,8 @@ public void AddEndpoints( "To fix this error, choose a different parameter name."); } + updatedRoutePattern = RoutePatternFactory.Combine(groupPrefix, updatedRoutePattern); + var builder = new RouteEndpointBuilder(requestDelegate, updatedRoutePattern, action.AttributeRouteInfo.Order) { DisplayName = action.DisplayName, @@ -157,7 +170,8 @@ public void AddEndpoints( dataTokens: null, action.AttributeRouteInfo.SuppressLinkGeneration, action.AttributeRouteInfo.SuppressPathMatching, - conventions, + groupConventions: groupConventions, + conventions: conventions, perRouteConventions: Array.Empty>()); endpoints.Add(builder.Build()); } @@ -168,7 +182,9 @@ public void AddConventionalLinkGenerationRoute( HashSet routeNames, HashSet keys, ConventionalRouteEntry route, - IReadOnlyList> conventions) + IReadOnlyList> groupConventions, + IReadOnlyList> conventions, + RoutePattern? groupPrefix = null) { if (endpoints == null) { @@ -212,13 +228,15 @@ public void AddConventionalLinkGenerationRoute( throw new InvalidOperationException("Failed to create a conventional route for pattern: " + route.Pattern); } + pattern = RoutePatternFactory.Combine(groupPrefix, pattern); + var builder = new RouteEndpointBuilder(context => Task.CompletedTask, pattern, route.Order) { DisplayName = "Route: " + route.Pattern.RawText, Metadata = - { - new SuppressMatchingMetadata(), - }, + { + new SuppressMatchingMetadata(), + }, }; if (route.RouteName != null) @@ -236,6 +254,11 @@ public void AddConventionalLinkGenerationRoute( builder.Metadata.Add(new EndpointNameMetadata(route.RouteName)); } + for (var i = 0; i < groupConventions.Count; i++) + { + groupConventions[i](builder); + } + for (var i = 0; i < conventions.Count; i++) { conventions[i](builder); @@ -301,7 +324,7 @@ private static (RoutePattern resolvedRoutePattern, IDictionary return (attributeRoutePattern, resolvedRequiredValues ?? action.RouteValues); } - private static void AddActionDataToBuilder( + private void AddActionDataToBuilder( EndpointBuilder builder, HashSet routeNames, ActionDescriptor action, @@ -309,9 +332,19 @@ private static void AddActionDataToBuilder( RouteValueDictionary? dataTokens, bool suppressLinkGeneration, bool suppressPathMatching, + IReadOnlyList> groupConventions, IReadOnlyList> conventions, IReadOnlyList> perRouteConventions) { + // REVIEW: The RouteEndpointDataSource adds HttpMethodMetadata before running group conventions + // do we need to do the same here? + + // Group metadata has the lowest precedence. + for (var i = 0; i < groupConventions.Count; i++) + { + groupConventions[i](builder); + } + // Add action metadata first so it has a low precedence if (action.EndpointMetadata != null) { @@ -406,6 +439,30 @@ private static void AddActionDataToBuilder( { perRouteConventions[i](builder); } + + if (builder.FilterFactories.Count > 0 && action is ControllerActionDescriptor cad) + { + var routeHandlerFilters = builder.FilterFactories; + + EndpointFilterDelegate del = static invocationContext => + { + // By the time this is called, we have the cache entry + var controllerInvocationContext = (ControllerEndpointFilterInvocationContext)invocationContext; + return controllerInvocationContext.ActionDescriptor.CacheEntry!.InnerActionMethodExecutor.Execute(controllerInvocationContext); + }; + + var context = new EndpointFilterFactoryContext(cad.MethodInfo, builder.Metadata, _serviceProvider); + + var initialFilteredInvocation = del; + + for (var i = routeHandlerFilters.Count - 1; i >= 0; i--) + { + var filterFactory = routeHandlerFilters[i]; + del = filterFactory(context, del); + } + + cad.FilterDelegate = ReferenceEquals(del, initialFilteredInvocation) ? null : del; + } } private RequestDelegate? CreateRequestDelegate(ActionDescriptor action, RouteValueDictionary? dataTokens = null) diff --git a/src/Mvc/Mvc.Core/src/Routing/ControllerActionEndpointDataSource.cs b/src/Mvc/Mvc.Core/src/Routing/ControllerActionEndpointDataSource.cs index 48e9767e1571..583c2a7fa5ac 100644 --- a/src/Mvc/Mvc.Core/src/Routing/ControllerActionEndpointDataSource.cs +++ b/src/Mvc/Mvc.Core/src/Routing/ControllerActionEndpointDataSource.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Mvc.Controllers; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Patterns; namespace Microsoft.AspNetCore.Mvc.Routing; @@ -60,7 +61,11 @@ public ControllerActionEndpointConventionBuilder AddRoute( } } - protected override List CreateEndpoints(IReadOnlyList actions, IReadOnlyList> conventions) + protected override List CreateEndpoints( + RoutePattern? groupPrefix, + IReadOnlyList actions, + IReadOnlyList> conventions, + IReadOnlyList> groupConventions) { var endpoints = new List(); var keys = new HashSet(StringComparer.OrdinalIgnoreCase); @@ -80,7 +85,14 @@ protected override List CreateEndpoints(IReadOnlyList 0) { @@ -99,7 +111,14 @@ protected override List CreateEndpoints(IReadOnlyList()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("VoidResultExecutor", actionMethodExecutor.GetType().Name); @@ -27,17 +35,25 @@ public void ActionMethodExecutor_ExecutesVoidActions() Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturningIActionResult() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturningIActionResult(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnIActionResult)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("SyncActionResultExecutor", actionMethodExecutor.GetType().Name); @@ -45,34 +61,50 @@ public void ActionMethodExecutor_ExecutesActionsReturningIActionResult() Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturningSubTypeOfActionResult() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturningSubTypeOfActionResult(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsIActionResultSubType)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("SyncActionResultExecutor", actionMethodExecutor.GetType().Name); Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturningActionResultOfT() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturningActionResultOfT(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsActionResultOfT)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert var result = Assert.IsType(valueTask.Result); @@ -83,17 +115,25 @@ public void ActionMethodExecutor_ExecutesActionsReturningActionResultOfT() Assert.Equal(typeof(TestModel), result.DeclaredType); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturningModelAsModel() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturningModelAsModel(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsModelAsModel)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert var result = Assert.IsType(valueTask.Result); @@ -104,17 +144,25 @@ public void ActionMethodExecutor_ExecutesActionsReturningModelAsModel() Assert.Equal(typeof(TestModel), result.DeclaredType); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturningModelAsObject() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturningModelAsObject(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnModelAsObject)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert var result = Assert.IsType(valueTask.Result); @@ -125,87 +173,127 @@ public void ActionMethodExecutor_ExecutesActionsReturningModelAsObject() Assert.Equal(typeof(object), result.DeclaredType); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturningActionResultAsObject() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturningActionResultAsObject(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsIActionResultSubType)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("SyncActionResultExecutor", actionMethodExecutor.GetType().Name); Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturnTask() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturnTask(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsTask)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert - Assert.Equal("TaskResultExecutor", actionMethodExecutor.GetType().Name); + Assert.Equal("TaskResultExecutor", actionMethodExecutor.GetType().Name); Assert.True(controller.Executed); Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsReturnAwaitable() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsReturnAwaitable(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsAwaitable)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var awaitableResult = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("AwaitableResultExecutor", actionMethodExecutor.GetType().Name); Assert.True(controller.Executed); - Assert.IsType(awaitableResult.Result); + Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutorExecutesActionsAsynchronouslyReturningIActionResult() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutorExecutesActionsAsynchronouslyReturningIActionResult(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnIActionResultAsync)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("TaskOfIActionResultExecutor", actionMethodExecutor.GetType().Name); Assert.IsType(valueTask.Result); } - [Fact] - public async Task ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningActionResultSubType() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningActionResultSubType(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnActionResultAsync)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert await valueTask; @@ -213,17 +301,25 @@ public async Task ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningAct Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningModel() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningModel(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsModelAsModelAsync)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert var result = Assert.IsType(valueTask.Result); @@ -234,17 +330,25 @@ public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningModel() Assert.Equal(typeof(TestModel), result.DeclaredType); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningModelAsObject() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningModelAsObject(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsModelAsObjectAsync)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert var result = Assert.IsType(valueTask.Result); @@ -255,34 +359,50 @@ public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningModelAsOb Assert.Equal(typeof(object), result.DeclaredType); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningIActionResultAsObject() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningIActionResultAsObject(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnIActionResultAsObjectAsync)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert Assert.Equal("AwaitableObjectResultExecutor", actionMethodExecutor.GetType().Name); Assert.IsType(valueTask.Result); } - [Fact] - public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningActionResultOfT() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningActionResultOfT(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnActionResultOFTAsync)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act - var valueTask = actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty()); + var valueTask = Execute(actionMethodExecutor, filterContext, withFilter); // Assert var result = Assert.IsType(valueTask.Result); @@ -293,22 +413,40 @@ public void ActionMethodExecutor_ExecutesActionsAsynchronouslyReturningActionRes Assert.Equal(typeof(TestModel), result.DeclaredType); } - [Fact] - public void ActionMethodExecutor_ThrowsIfIConvertFromIActionResult_ReturnsNull() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ActionMethodExecutor_ThrowsIfIConvertFromIActionResult_ReturnsNull(bool withFilter) { // Arrange var mapper = new ActionResultTypeMapper(); var controller = new TestController(); var objectMethodExecutor = GetExecutor(nameof(TestController.ReturnsCustomConvertibleFromIActionResult)); var actionMethodExecutor = ActionMethodExecutor.GetExecutor(objectMethodExecutor); + var filterContext = new ControllerEndpointFilterInvocationContext(new Controllers.ControllerActionDescriptor(), + new ActionContext(), + objectMethodExecutor, + mapper, + controller, + Array.Empty()); // Act & Assert - var ex = Assert.Throws( - () => actionMethodExecutor.Execute(mapper, objectMethodExecutor, controller, Array.Empty())); + var ex = await Assert.ThrowsAsync(() => Execute(actionMethodExecutor, filterContext, withFilter).AsTask()); Assert.Equal($"Cannot return null from an action method with a return type of '{typeof(CustomConvertibleFromAction)}'.", ex.Message); } + private async ValueTask Execute(ActionMethodExecutor actionMethodExecutor, + ControllerEndpointFilterInvocationContext context, + bool withFilter) + { + if (withFilter) + { + return (IActionResult)await actionMethodExecutor.Execute(context); + } + return await actionMethodExecutor.Execute(context.ActionContext, context.Mapper, context.Executor, context.Controller, (object[])context.Arguments); + } + private static ObjectMethodExecutor GetExecutor(string methodName) { var type = typeof(TestController); @@ -351,7 +489,7 @@ public YieldAwaitable ReturnsAwaitable() public Task ReturnIActionResultAsync() => Task.FromResult((IActionResult)new StatusCodeResult(201)); - public Task ReturnActionResultAsync() => Task.FromResult(new ViewResult { StatusCode = 200}); + public Task ReturnActionResultAsync() => Task.FromResult(new ViewResult { StatusCode = 200 }); public Task ReturnsIActionResultSubTypeAsync() => Task.FromResult(new StatusCodeResult(200)); diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerTest.cs index 1f28d1ed98e5..dc1972fc5827 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/ControllerActionInvokerTest.cs @@ -1333,6 +1333,7 @@ public async Task Invoke_UsesDefaultValuesIfNotBound() (_, __) => default, (_, __, ___) => Task.CompletedTask, objectMethodExecutor, + controllerMethodExecutor, controllerMethodExecutor); var invoker = new ControllerActionInvoker( @@ -1663,6 +1664,7 @@ private ControllerActionInvoker CreateInvoker( return Task.CompletedTask; }, objectMethodExecutor, + actionMethodExecutor, actionMethodExecutor); var actionContext = new ActionContext(httpContext, routeData, actionDescriptor); diff --git a/src/Mvc/Mvc.Core/test/Routing/ActionEndpointDataSourceBaseTest.cs b/src/Mvc/Mvc.Core/test/Routing/ActionEndpointDataSourceBaseTest.cs index 8cfbca80b433..b8168b4941d7 100644 --- a/src/Mvc/Mvc.Core/test/Routing/ActionEndpointDataSourceBaseTest.cs +++ b/src/Mvc/Mvc.Core/test/Routing/ActionEndpointDataSourceBaseTest.cs @@ -146,7 +146,7 @@ private protected ActionEndpointDataSourceBase CreateDataSource(IActionDescripto var serviceProvider = services.BuildServiceProvider(); - var endpointFactory = new ActionEndpointFactory(serviceProvider.GetRequiredService(), Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(serviceProvider.GetRequiredService(), Enumerable.Empty(), serviceProvider); return CreateDataSource(actions, endpointFactory); } diff --git a/src/Mvc/Mvc.Core/test/Routing/ActionEndpointFactoryTest.cs b/src/Mvc/Mvc.Core/test/Routing/ActionEndpointFactoryTest.cs index 4d893e84eadd..e98110181c24 100644 --- a/src/Mvc/Mvc.Core/test/Routing/ActionEndpointFactoryTest.cs +++ b/src/Mvc/Mvc.Core/test/Routing/ActionEndpointFactoryTest.cs @@ -25,7 +25,7 @@ public ActionEndpointFactoryTest() }); Services = serviceCollection.BuildServiceProvider(); - Factory = new ActionEndpointFactory(Services.GetRequiredService(), Enumerable.Empty()); + Factory = new ActionEndpointFactory(Services.GetRequiredService(), Enumerable.Empty(), Services); } internal ActionEndpointFactory Factory { get; } @@ -266,10 +266,10 @@ public void RequestDelegateFactoryWorks() requestDelegateFactory.Setup(m => m.CreateRequestDelegate(action, It.IsAny())).Returns(del); // Act - var factory = new ActionEndpointFactory(Services.GetRequiredService(), new[] { requestDelegateFactory.Object }); + var factory = new ActionEndpointFactory(Services.GetRequiredService(), new[] { requestDelegateFactory.Object }, Services); var endpoints = new List(); - factory.AddEndpoints(endpoints, new HashSet(), action, Array.Empty(), Array.Empty>(), createInertEndpoints: false); + factory.AddEndpoints(endpoints, new HashSet(), action, Array.Empty(), Array.Empty>(), Array.Empty>(), createInertEndpoints: false); var endpoint = Assert.IsType(Assert.Single(endpoints)); @@ -366,7 +366,7 @@ public void AddEndpoints_CreatesInertEndpoint() private RouteEndpoint CreateAttributeRoutedEndpoint(ActionDescriptor action) { var endpoints = new List(); - Factory.AddEndpoints(endpoints, new HashSet(), action, Array.Empty(), Array.Empty>(), createInertEndpoints: false); + Factory.AddEndpoints(endpoints, new HashSet(), action, Array.Empty(), Array.Empty>(), Array.Empty>(), createInertEndpoints: false); return Assert.IsType(Assert.Single(endpoints)); } @@ -380,7 +380,7 @@ private RouteEndpoint CreateConventionalRoutedEndpoint(ActionDescriptor action, Assert.NotNull(action.RouteValues); var endpoints = new List(); - Factory.AddEndpoints(endpoints, new HashSet(), action, new[] { route, }, Array.Empty>(), createInertEndpoints: false); + Factory.AddEndpoints(endpoints, new HashSet(), action, new[] { route, }, Array.Empty>(), Array.Empty>(), createInertEndpoints: false); var endpoint = Assert.IsType(Assert.Single(endpoints)); // This should be true for all conventional-routed actions. @@ -397,7 +397,7 @@ private IReadOnlyList CreateConventionalRoutedEndpoints(ActionDescript private IReadOnlyList CreateConventionalRoutedEndpoints(ActionDescriptor action, IReadOnlyList routes, bool createInertEndpoints = false) { var endpoints = new List(); - Factory.AddEndpoints(endpoints, new HashSet(), action, routes, Array.Empty>(), createInertEndpoints); + Factory.AddEndpoints(endpoints, new HashSet(), action, routes, Array.Empty>(), Array.Empty>(), createInertEndpoints); return endpoints.ToList(); } diff --git a/src/Mvc/Mvc.Core/test/Routing/ControllerActionEndpointDataSourceTest.cs b/src/Mvc/Mvc.Core/test/Routing/ControllerActionEndpointDataSourceTest.cs index ccc3972fec1f..f6ea46192109 100644 --- a/src/Mvc/Mvc.Core/test/Routing/ControllerActionEndpointDataSourceTest.cs +++ b/src/Mvc/Mvc.Core/test/Routing/ControllerActionEndpointDataSourceTest.cs @@ -1,10 +1,12 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.Controllers; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Patterns; using Moq; namespace Microsoft.AspNetCore.Mvc.Routing; @@ -373,6 +375,106 @@ public void Endpoints_AppliesConventions_RouteSpecificMetadata() }); } + [Fact] + public void GroupedEndpoints_AppliesConventions_RouteSpecificMetadata() + { + // Arrange + var actions = new List + { + new ControllerActionDescriptor + { + AttributeRouteInfo = new AttributeRouteInfo() + { + Template = "/test", + }, + RouteValues = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "action", "Test" }, + { "controller", "Test" }, + }, + EndpointMetadata = new[] { "A" }, + }, + new ControllerActionDescriptor + { + RouteValues = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "action", "Index" }, + { "controller", "Home" }, + }, + } + }; + + var mockDescriptorProvider = new Mock(); + mockDescriptorProvider.Setup(m => m.ActionDescriptors).Returns(new ActionDescriptorCollection(actions, 0)); + + var dataSource = (ControllerActionEndpointDataSource)CreateDataSource(mockDescriptorProvider.Object); + dataSource.AddRoute("1", "/1/{controller}/{action}/{id?}", null, null, null).Add(b => b.Metadata.Add("A")); + dataSource.AddRoute("2", "/2/{controller}/{action}/{id?}", null, null, null).Add(b => b.Metadata.Add("B")); + + dataSource.DefaultBuilder.Add((b) => + { + b.Metadata.Add("Hi there"); + }); + + // Act + var groupConventions = new List>() + { + b => b.Metadata.Add(new GroupMetadata()), + b => b.Metadata.Add("group") + }; + var sp = Mock.Of(); + var groupPattern = RoutePatternFactory.Parse("/group1"); + var endpoints = dataSource.GetGroupedEndpoints(new RouteGroupContext(groupPattern, groupConventions, sp)); + + // Assert + Assert.Collection( + endpoints.OfType().Where(e => !SupportsLinkGeneration(e)).OrderBy(e => e.RoutePattern.RawText), + e => + { + Assert.Equal("/group1/1/{controller}/{action}/{id?}", e.RoutePattern.RawText); + Assert.Same(actions[1], e.Metadata.GetMetadata()); + Assert.Equal(new[] { "group", "Hi there", "A" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }, + e => + { + Assert.Equal("/group1/2/{controller}/{action}/{id?}", e.RoutePattern.RawText); + Assert.Same(actions[1], e.Metadata.GetMetadata()); + Assert.Equal(new[] { "group", "Hi there", "B" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }); + + Assert.Collection( + endpoints.OfType().Where(e => SupportsLinkGeneration(e)).OrderBy(e => e.RoutePattern.RawText), + e => + { + Assert.Equal("/group1/1/{controller}/{action}/{id?}", e.RoutePattern.RawText); + Assert.Null(e.Metadata.GetMetadata()); + // Group conventions are applied first, then endpoint specific metadata, then normal conventions, then per route conventions + Assert.Equal(new[] { "group", "Hi there", "A" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }, + e => + { + Assert.Equal("/group1/2/{controller}/{action}/{id?}", e.RoutePattern.RawText); + Assert.Null(e.Metadata.GetMetadata()); + Assert.Equal(new[] { "group", "Hi there", "B" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }, + e => + { + Assert.Equal("/group1/test", e.RoutePattern.RawText); + Assert.Same(actions[0], e.Metadata.GetMetadata()); + // Group conventions are applied first, then endpoint specific metadata, then normal conventions + Assert.Equal(new[] { "group", "A", "Hi there" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }); + } + + private class GroupMetadata + { + } + private static bool SupportsLinkGeneration(RouteEndpoint endpoint) { return !(endpoint.Metadata.GetMetadata()?.SuppressLinkGeneration == true); diff --git a/src/Mvc/Mvc.RazorPages/src/Infrastructure/DefaultPageLoader.cs b/src/Mvc/Mvc.RazorPages/src/Infrastructure/DefaultPageLoader.cs index 08907ef30c87..d273efbf4405 100644 --- a/src/Mvc/Mvc.RazorPages/src/Infrastructure/DefaultPageLoader.cs +++ b/src/Mvc/Mvc.RazorPages/src/Infrastructure/DefaultPageLoader.cs @@ -71,6 +71,7 @@ private async Task LoadAsyncCore(PageActionDescrip routeNames: new HashSet(StringComparer.OrdinalIgnoreCase), action: compiled, routes: Array.Empty(), + groupConventions: Array.Empty>(), conventions: new Action[] { b => diff --git a/src/Mvc/Mvc.RazorPages/src/Infrastructure/PageActionEndpointDataSource.cs b/src/Mvc/Mvc.RazorPages/src/Infrastructure/PageActionEndpointDataSource.cs index f5c2ed29af98..c6c96cfe2416 100644 --- a/src/Mvc/Mvc.RazorPages/src/Infrastructure/PageActionEndpointDataSource.cs +++ b/src/Mvc/Mvc.RazorPages/src/Infrastructure/PageActionEndpointDataSource.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Mvc.Routing; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Patterns; namespace Microsoft.AspNetCore.Mvc.RazorPages.Infrastructure; @@ -40,7 +41,11 @@ public PageActionEndpointDataSource( // selection. Set to true by builder methods that do dynamic/fallback selection. public bool CreateInertEndpoints { get; set; } - protected override List CreateEndpoints(IReadOnlyList actions, IReadOnlyList> conventions) + protected override List CreateEndpoints( + RoutePattern? groupPrefix, + IReadOnlyList actions, + IReadOnlyList> conventions, + IReadOnlyList> groupConventions) { var endpoints = new List(); var routeNames = new HashSet(StringComparer.OrdinalIgnoreCase); @@ -48,7 +53,14 @@ protected override List CreateEndpoints(IReadOnlyList(), conventions, CreateInertEndpoints); + _endpointFactory.AddEndpoints(endpoints, + routeNames, + action, + Array.Empty(), + groupConventions: groupConventions, + conventions: conventions, + CreateInertEndpoints, + groupPrefix); } } diff --git a/src/Mvc/Mvc.RazorPages/test/Infrastructure/DefaultPageLoaderTest.cs b/src/Mvc/Mvc.RazorPages/test/Infrastructure/DefaultPageLoaderTest.cs index 483cab8c52dd..166884b5eb58 100644 --- a/src/Mvc/Mvc.RazorPages/test/Infrastructure/DefaultPageLoaderTest.cs +++ b/src/Mvc/Mvc.RazorPages/test/Infrastructure/DefaultPageLoaderTest.cs @@ -35,7 +35,7 @@ public async Task LoadAsync_InvokesApplicationModelProviders() var compilerProvider = GetCompilerProvider(); var mvcOptions = Options.Create(new MvcOptions()); - var endpointFactory = new ActionEndpointFactory(Mock.Of(), Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(Mock.Of(), Enumerable.Empty(), Mock.Of()); var provider1 = new Mock(); var provider2 = new Mock(); @@ -118,7 +118,7 @@ public async Task LoadAsync_CreatesEndpoint_WithRoute() var compilerProvider = GetCompilerProvider(); var mvcOptions = Options.Create(new MvcOptions()); - var endpointFactory = new ActionEndpointFactory(transformer.Object, Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(transformer.Object, Enumerable.Empty(), Mock.Of()); var provider = new Mock(); @@ -158,7 +158,7 @@ public async Task LoadAsync_InvokesApplicationModelProviders_WithTheRightOrder() var descriptor = new PageActionDescriptor(); var compilerProvider = GetCompilerProvider(); var mvcOptions = Options.Create(new MvcOptions()); - var endpointFactory = new ActionEndpointFactory(Mock.Of(), Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(Mock.Of(), Enumerable.Empty(), Mock.Of()); var provider1 = new Mock(); provider1.SetupGet(p => p.Order).Returns(10); @@ -235,7 +235,7 @@ public async Task LoadAsync_CachesResults() var compilerProvider = GetCompilerProvider(); var mvcOptions = Options.Create(new MvcOptions()); - var endpointFactory = new ActionEndpointFactory(transformer.Object, Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(transformer.Object, Enumerable.Empty(), Mock.Of()); var provider = new Mock(); @@ -297,7 +297,7 @@ public async Task LoadAsync_IsUniquePerPageDescriptor() var compilerProvider = GetCompilerProvider(); var mvcOptions = Options.Create(new MvcOptions()); - var endpointFactory = new ActionEndpointFactory(transformer.Object, Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(transformer.Object, Enumerable.Empty(), Mock.Of()); var provider = new Mock(); @@ -334,7 +334,7 @@ public async Task LoadAsync_CompiledPageActionDescriptor_ReturnsSelf() { // Arrange var mvcOptions = Options.Create(new MvcOptions()); - var endpointFactory = new ActionEndpointFactory(Mock.Of(), Enumerable.Empty()); + var endpointFactory = new ActionEndpointFactory(Mock.Of(), Enumerable.Empty(), Mock.Of()); var loader = new DefaultPageLoader( new[] { Mock.Of() }, Mock.Of(), diff --git a/src/Mvc/Mvc.RazorPages/test/Infrastructure/PageActionEndpointDataSourceTest.cs b/src/Mvc/Mvc.RazorPages/test/Infrastructure/PageActionEndpointDataSourceTest.cs index 4bcf50dfda5c..9d5424e682c3 100644 --- a/src/Mvc/Mvc.RazorPages/test/Infrastructure/PageActionEndpointDataSourceTest.cs +++ b/src/Mvc/Mvc.RazorPages/test/Infrastructure/PageActionEndpointDataSourceTest.cs @@ -1,10 +1,13 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Controllers; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.Mvc.Routing; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Patterns; using Moq; namespace Microsoft.AspNetCore.Mvc.RazorPages.Infrastructure; @@ -86,6 +89,83 @@ public void Endpoints_AppliesConventions() }); } + [Fact] + public void GroupedEndpoints_AppliesConventions_RouteSpecificMetadata() + { + // Arrange + var actions = new List + { + new PageActionDescriptor + { + AttributeRouteInfo = new AttributeRouteInfo() + { + Template = "/test", + }, + RouteValues = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "action", "Test" }, + { "controller", "Test" }, + }, + EndpointMetadata = new List() { "A" } + }, + new PageActionDescriptor + { + AttributeRouteInfo = new AttributeRouteInfo() + { + Template = "/test2", + }, + RouteValues = new Dictionary(StringComparer.OrdinalIgnoreCase) + { + { "action", "Test" }, + { "controller", "Test" }, + }, + EndpointMetadata = new List() { "B" } + }, + }; + + var mockDescriptorProvider = new Mock(); + mockDescriptorProvider.Setup(m => m.ActionDescriptors).Returns(new ActionDescriptorCollection(actions, 0)); + + var dataSource = (PageActionEndpointDataSource)CreateDataSource(mockDescriptorProvider.Object); + + dataSource.DefaultBuilder.Add((b) => + { + b.Metadata.Add("Hi there"); + }); + + // Act + var groupConventions = new List>() + { + b => b.Metadata.Add(new GroupMetadata()), + b => b.Metadata.Add("group") + }; + var sp = Mock.Of(); + var groupPattern = RoutePatternFactory.Parse("/group1"); + var endpoints = dataSource.GetGroupedEndpoints(new RouteGroupContext(groupPattern, groupConventions, sp)); + + // Assert + Assert.Collection( + endpoints.OfType().OrderBy(e => e.RoutePattern.RawText), + e => + { + Assert.Equal("/group1/test", e.RoutePattern.RawText); + Assert.Same(actions[0], e.Metadata.GetMetadata()); + Assert.Equal(new[] { "group", "A", "Hi there" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }, + e => + { + Assert.Equal("/group1/test2", e.RoutePattern.RawText); + Assert.Same(actions[1], e.Metadata.GetMetadata()); + Assert.Equal(new[] { "group", "B", "Hi there" }, e.Metadata.GetOrderedMetadata()); + Assert.NotNull(e.Metadata.GetMetadata()); + }); + } + + private class GroupMetadata + { + } + private protected override ActionEndpointDataSourceBase CreateDataSource(IActionDescriptorCollectionProvider actions, ActionEndpointFactory endpointFactory) { return new PageActionEndpointDataSource(new PageActionEndpointDataSourceIdProvider(), actions, endpointFactory, new OrderedEndpointsSequenceProvider()); diff --git a/src/Mvc/perf/Microbenchmarks/Microsoft.AspNetCore.Mvc/ControllerActionEndpointDatasourceBenchmark.cs b/src/Mvc/perf/Microbenchmarks/Microsoft.AspNetCore.Mvc/ControllerActionEndpointDatasourceBenchmark.cs index 301d2e3eddc4..8161987ad554 100644 --- a/src/Mvc/perf/Microbenchmarks/Microsoft.AspNetCore.Mvc/ControllerActionEndpointDatasourceBenchmark.cs +++ b/src/Mvc/perf/Microbenchmarks/Microsoft.AspNetCore.Mvc/ControllerActionEndpointDatasourceBenchmark.cs @@ -107,12 +107,20 @@ private ControllerActionEndpointDataSource CreateDataSource(IActionDescriptorCol var dataSource = new ControllerActionEndpointDataSource( new ControllerActionEndpointDataSourceIdProvider(), actionDescriptorCollectionProvider, - new ActionEndpointFactory(new MockRoutePatternTransformer(), Enumerable.Empty()), + new ActionEndpointFactory(new MockRoutePatternTransformer(), Enumerable.Empty(), new MockServiceProvider()), new OrderedEndpointsSequenceProvider()); return dataSource; } + private sealed class MockServiceProvider : IServiceProvider + { + public object GetService(Type serviceType) + { + return null; + } + } + private sealed class MockRoutePatternTransformer : RoutePatternTransformer { public override RoutePattern SubstituteRequiredValues(RoutePattern original, object requiredValues) diff --git a/src/Mvc/test/Mvc.FunctionalTests/ControllerEndpointFiltersTest.cs b/src/Mvc/test/Mvc.FunctionalTests/ControllerEndpointFiltersTest.cs new file mode 100644 index 000000000000..fa8a03c76ece --- /dev/null +++ b/src/Mvc/test/Mvc.FunctionalTests/ControllerEndpointFiltersTest.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http; +using System.Net.Http.Json; +using System.Text.Json; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Testing; +using RoutingWebSite; + +namespace Microsoft.AspNetCore.Mvc.FunctionalTests; + +public class ControllerEndpointFiltersTest : IClassFixture> +{ + public ControllerEndpointFiltersTest(MvcTestFixture fixture) + { + Factory = fixture.Factories.FirstOrDefault() ?? fixture.WithWebHostBuilder(ConfigureWebHostBuilder); + } + + private static void ConfigureWebHostBuilder(IWebHostBuilder builder) => builder.UseStartup(); + + public WebApplicationFactory Factory { get; } + + [Fact] + public async Task CanApplyEndpointFilterToController() + { + using var client = Factory.CreateClient(); + + var response = await client.GetAsync("Items/Index"); + var content = await response.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.True(content.TryGetValue(nameof(IEndpointFilter), out var endpointFilterCalled)); + Assert.True(((JsonElement)endpointFilterCalled).GetBoolean()); + } + + [Fact] + public async Task CanCaptureMethodInfoFromControllerAction() + { + using var client = Factory.CreateClient(); + + var response = await client.GetAsync("Items/Index"); + var content = await response.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.True(content.TryGetValue(nameof(EndpointFilterFactoryContext.MethodInfo.Name), out var methodInfo)); + Assert.Equal("Index", ((JsonElement)methodInfo).GetString()); + } + + [Fact] + public async Task CanInterceptActionResultViaFilter() + { + using var client = Factory.CreateClient(); + + var response = await client.GetAsync("Items/IndexWithSelectiveFilter"); + var content = await response.Content.ReadAsStringAsync(); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal("Intercepted", content); + } + + [Fact] + public async Task CanAccessArgumentsFromAction() + { + using var client = Factory.CreateClient(); + + var response = await client.GetAsync("Items/IndexWithArgument/foobar"); + var content = await response.Content.ReadFromJsonAsync>(); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.True(content.TryGetValue(nameof(EndpointFilterInvocationContext.Arguments), out var argument)); + Assert.Equal("foobar", ((JsonElement)argument).GetString()); + } +} diff --git a/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsTest.cs b/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsTest.cs index 29bae6189d2c..f9bab6bd2ddd 100644 --- a/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsTest.cs +++ b/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsTest.cs @@ -4,7 +4,9 @@ using System.Net; using System.Net.Http; using System.Net.Http.Json; +using System.Text.Json; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc.Testing; using RoutingWebSite; diff --git a/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsWithMetadataTest.cs b/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsWithMetadataTest.cs new file mode 100644 index 000000000000..03b26ccee76e --- /dev/null +++ b/src/Mvc/test/Mvc.FunctionalTests/RoutingGroupsWithMetadataTest.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Net.Http.Json; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Testing; +using RoutingWebSite; + +namespace Microsoft.AspNetCore.Mvc.FunctionalTests; + +public class RoutingGroupsWithMetadataTests : IClassFixture> +{ + public RoutingGroupsWithMetadataTests(MvcTestFixture fixture) + { + Factory = fixture.Factories.FirstOrDefault() ?? fixture.WithWebHostBuilder(ConfigureWebHostBuilder); + } + + private static void ConfigureWebHostBuilder(IWebHostBuilder builder) => builder.UseStartup(); + + public WebApplicationFactory Factory { get; } + + [Fact] + public async Task OrderedGroupMetadataForControllers() + { + using var client = Factory.CreateClient(); + + var response = await client.GetAsync("group1/metadata"); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + var content = await response.Content.ReadFromJsonAsync(); + + Assert.Equal(new[] { "A", "C", "B" }, content); + } +} diff --git a/src/Mvc/test/WebSites/RoutingWebSite/Controllers/ConventionalControllerWithMetadata.cs b/src/Mvc/test/WebSites/RoutingWebSite/Controllers/ConventionalControllerWithMetadata.cs new file mode 100644 index 000000000000..5dcc521bbf44 --- /dev/null +++ b/src/Mvc/test/WebSites/RoutingWebSite/Controllers/ConventionalControllerWithMetadata.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Mvc; +using Mvc.RoutingWebSite.Infrastructure; + +namespace Mvc.RoutingWebSite.Controllers; + +public class ConventionalControllerWithMetadata : Controller +{ + [Metadata("C")] + public IActionResult GetMetadata() + { + return Ok(HttpContext.GetEndpoint().Metadata.GetOrderedMetadata().Select(m => m.Value)); + } +} diff --git a/src/Mvc/test/WebSites/RoutingWebSite/Controllers/ItemsController.cs b/src/Mvc/test/WebSites/RoutingWebSite/Controllers/ItemsController.cs new file mode 100644 index 000000000000..5156a01901f0 --- /dev/null +++ b/src/Mvc/test/WebSites/RoutingWebSite/Controllers/ItemsController.cs @@ -0,0 +1,26 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Mvc; + +namespace RoutingWebSite; + +[Route("Items/[action]")] +public class ItemsController : Controller +{ + public ActionResult> Index() + { + return Ok(HttpContext.Items); + } + + public string IndexWithSelectiveFilter() + { + return "Default response"; + } + + [Route("{arg}")] + public ActionResult> IndexWithArgument(string arg) + { + return Ok(HttpContext.Items); + } +} diff --git a/src/Mvc/test/WebSites/RoutingWebSite/Infrastructure/ManualControllerFeatureProvider.cs b/src/Mvc/test/WebSites/RoutingWebSite/Infrastructure/ManualControllerFeatureProvider.cs new file mode 100644 index 000000000000..7d21b35053a4 --- /dev/null +++ b/src/Mvc/test/WebSites/RoutingWebSite/Infrastructure/ManualControllerFeatureProvider.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; +using Microsoft.AspNetCore.Mvc.ApplicationParts; +using Microsoft.AspNetCore.Mvc.Controllers; + +namespace Mvc.RoutingWebSite.Infrastructure; + +internal class ManualControllerFeatureProvider : IApplicationFeatureProvider +{ + private readonly Action _action; + public ManualControllerFeatureProvider(Action action) + { + _action = action; + } + + public void PopulateFeature(IEnumerable parts, ControllerFeature feature) + { + _action(feature); + } +} + diff --git a/src/Mvc/test/WebSites/RoutingWebSite/Infrastructure/MetadataAttribute.cs b/src/Mvc/test/WebSites/RoutingWebSite/Infrastructure/MetadataAttribute.cs new file mode 100644 index 000000000000..7fd9caef586d --- /dev/null +++ b/src/Mvc/test/WebSites/RoutingWebSite/Infrastructure/MetadataAttribute.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Mvc.RoutingWebSite.Infrastructure; + +internal class MetadataAttribute : Attribute +{ + public string Value { get; set; } + + public MetadataAttribute(string value) + { + Value = value; + } +} + diff --git a/src/Mvc/test/WebSites/RoutingWebSite/StartupForEndpointFilters.cs b/src/Mvc/test/WebSites/RoutingWebSite/StartupForEndpointFilters.cs new file mode 100644 index 000000000000..268c19d75f7c --- /dev/null +++ b/src/Mvc/test/WebSites/RoutingWebSite/StartupForEndpointFilters.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Infrastructure; + +namespace RoutingWebSite; + +public class StartupForEndpointFilters +{ + // Set up application services + public void ConfigureServices(IServiceCollection services) + { + services.AddMvc().AddNewtonsoftJson(); + + // Used by some controllers defined in this project. + services.Configure(options => options.ConstraintMap["slugify"] = typeof(SlugifyParameterTransformer)); + services.AddScoped(); + // This is used by test response generator + services.AddSingleton(); + } + + public virtual void Configure(IApplicationBuilder app) + { + app.UseRouting(); + app.UseEndpoints(builder => + { + builder.MapControllers().AddEndpointFilterFactory((context, next) => + { + return async ic => + { + ic.HttpContext.Items[nameof(IEndpointFilter)] = true; + ic.HttpContext.Items[nameof(EndpointFilterFactoryContext.MethodInfo.Name)] = context.MethodInfo.Name; + var result = await next(ic); + if (context.MethodInfo.Name == "IndexWithSelectiveFilter") + { + return "Intercepted"; + } + return result; + }; + }).AddEndpointFilterFactory((context, next) => + { + if (context.MethodInfo.GetParameters().Length >= 1 && context.MethodInfo.GetParameters()[0].ParameterType == typeof(string)) + { + return ic => + { + var firstArg = ic.GetArgument(0); + ic.HttpContext.Items[nameof(EndpointFilterInvocationContext.Arguments)] = firstArg; + return next(ic); + }; + } + + return ic => next(ic); + }); + }); + } +} diff --git a/src/Mvc/test/WebSites/RoutingWebSite/StartupForRouteGroupsWithMetadata.cs b/src/Mvc/test/WebSites/RoutingWebSite/StartupForRouteGroupsWithMetadata.cs new file mode 100644 index 000000000000..336705f908f3 --- /dev/null +++ b/src/Mvc/test/WebSites/RoutingWebSite/StartupForRouteGroupsWithMetadata.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; +using Microsoft.AspNetCore.Mvc.ApplicationParts; +using Microsoft.AspNetCore.Mvc.Controllers; +using Mvc.RoutingWebSite.Controllers; +using Mvc.RoutingWebSite.Infrastructure; + +namespace RoutingWebSite; + +public class StartupForRouteGroupsWithMetadata +{ + // Set up application services + public void ConfigureServices(IServiceCollection services) + { + var builder = services.AddControllers(); + + // Remove the default controller feature provider so we don't find all of the controllers + // in this app, we do this because adding controllers to multple groups with the same name + // does not work. + var old = builder.PartManager.FeatureProviders.OfType>().FirstOrDefault(); + builder.PartManager.FeatureProviders.Remove(old); + builder.PartManager.FeatureProviders.Add( + new ManualControllerFeatureProvider(f => + { + f.Controllers.Add(typeof(ItemsController).GetTypeInfo()); + f.Controllers.Add(typeof(ConventionalControllerWithMetadata).GetTypeInfo()); + })); + } + + public virtual void Configure(IApplicationBuilder app) + { + app.UseRouting(); + app.UseEndpoints(builder => + { + // Map all controllers (defined in the + builder.MapControllers(); + + builder.MapGroup("/group1") + .WithMetadata(new MetadataAttribute("A")) + .MapControllerRoute("route1", "/metadata", new + { + controller = nameof(ConventionalControllerWithMetadata), + action = nameof(ConventionalControllerWithMetadata.GetMetadata) + }) + .WithMetadata(new MetadataAttribute("B")); + }); + } +} +