diff --git a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs new file mode 100644 index 000000000000..4d3e583eaa5d --- /dev/null +++ b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs @@ -0,0 +1,20 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +/// +/// Provides an interface for implementing a filter targetting a route handler. +/// +public interface IRouteHandlerFilter +{ + /// + /// Implements the core logic associated with the filter given a + /// and the next filter to call in the pipeline. + /// + /// The associated with the current request/response. + /// The next filter in the pipeline. + /// An awaitable result of calling the handler and apply + /// any modifications made by filters in the pipeline. + ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next); +} diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index b7318f416a16..bee1f68cba6d 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -1,6 +1,8 @@ #nullable enable *REMOVED*abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string! Microsoft.AspNetCore.Http.EndpointMetadataCollection.GetRequiredMetadata() -> T! +Microsoft.AspNetCore.Http.RouteHandlerFilterContext.RouteHandlerFilterContext(Microsoft.AspNetCore.Http.HttpContext! httpContext, params object![]! parameters) -> void +Microsoft.AspNetCore.Http.IRouteHandlerFilter.InvokeAsync(Microsoft.AspNetCore.Http.RouteHandlerFilterContext! context, System.Func>! next) -> System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata.Name.get -> string? Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(Microsoft.AspNetCore.Routing.RouteValueDictionary? dictionary) -> void @@ -8,3 +10,7 @@ Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Co Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Collections.Generic.IEnumerable>? values) -> void abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string? Microsoft.AspNetCore.Http.Metadata.ISkipStatusCodePagesMetadata +Microsoft.AspNetCore.Http.RouteHandlerFilterContext +Microsoft.AspNetCore.Http.RouteHandlerFilterContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext! +Microsoft.AspNetCore.Http.RouteHandlerFilterContext.Parameters.get -> System.Collections.Generic.IList! +Microsoft.AspNetCore.Http.IRouteHandlerFilter diff --git a/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs b/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs new file mode 100644 index 000000000000..558d97cbd06b --- /dev/null +++ b/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs @@ -0,0 +1,35 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +/// +/// Provides an abstraction for wrapping the and parameters +/// provided to a route handler. +/// +public class RouteHandlerFilterContext +{ + /// + /// Creates a new instance of the for a given request. + /// + /// The associated with the current request. + /// A list of parameters provided in the current request. + public RouteHandlerFilterContext(HttpContext httpContext, params object[] parameters) + { + HttpContext = httpContext; + Parameters = parameters; + } + + /// + /// The associated with the current request being processed by the filter. + /// + public HttpContext HttpContext { get; } + + /// + /// A list of parameters provided in the current request to the filter. + /// + /// This list is not read-only to premit modifying of existing parameters by filters. + /// + /// + public IList Parameters { get; } +} diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index 1030d0f0793e..1d4c624f9113 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -1,3 +1,5 @@ #nullable enable Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions static Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions.ConfigureRouteHandlerJsonOptions(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.get -> System.Collections.Generic.IReadOnlyList? +Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.init -> void diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 954e8c94a082..5b79735f356c 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -39,6 +39,7 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!; + private static readonly MethodInfo WrapObjectAsValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(WrapObjectAsValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; // Call WriteAsJsonAsync() to serialize the runtime return type rather than the declared return type. // https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-polymorphism @@ -71,12 +72,21 @@ public static partial class RequestDelegateFactory private static readonly MemberExpression FormFilesExpr = Expression.Property(FormExpr, typeof(IFormCollection).GetProperty(nameof(IFormCollection.Files))!); private static readonly MemberExpression StatusCodeExpr = Expression.Property(HttpResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!); private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo>(() => Task.CompletedTask)); + private static readonly NewExpression CompletedValueTaskExpr = Expression.New(typeof(ValueTask).GetConstructor(new[] { typeof(Task) })!, CompletedTaskExpr); private static readonly ParameterExpression TempSourceStringExpr = ParameterBindingMethodCache.TempSourceStringExpr; private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); private static readonly UnaryExpression TempSourceStringIsNotNullOrEmptyExpr = Expression.Not(Expression.Call(StringIsNullOrEmptyMethod, TempSourceStringExpr)); + private static readonly ConstructorInfo RouteHandlerFilterContextConstructor = typeof(RouteHandlerFilterContext).GetConstructor(new[] { typeof(HttpContext), typeof(object[]) })!; + private static readonly ParameterExpression FilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "context"); + private static readonly MemberExpression FilterContextParametersExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.Parameters))!); + private static readonly MemberExpression FilterContextHttpContextExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.HttpContext))!); + private static readonly MemberExpression FilterContextHttpContextResponseExpr = Expression.Property(FilterContextHttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Response))!); + private static readonly MemberExpression FilterContextHttpContextStatusCodeExpr = Expression.Property(FilterContextHttpContextResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!); + private static readonly ParameterExpression InvokedFilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "filterContext"); + private static readonly string[] DefaultAcceptsContentType = new[] { "application/json" }; private static readonly string[] FormFileContentType = new[] { "multipart/form-data" }; @@ -102,6 +112,7 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact }; var factoryContext = CreateFactoryContext(options); + var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext); return new RequestDelegateResult(httpContext => targetableRequestDelegate(handler.Target, httpContext), factoryContext.Metadata); @@ -155,6 +166,7 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions RouteParameters = options?.RouteParameterNames?.ToList(), ThrowOnBadRequest = options?.ThrowOnBadRequest ?? false, DisableInferredFromBody = options?.DisableInferBodyFromParameters ?? false, + Filters = options?.RouteHandlerFilters?.ToList() }; private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext) @@ -176,10 +188,31 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions // } var arguments = CreateArguments(methodInfo.GetParameters(), factoryContext); + var returnType = methodInfo.ReturnType; + factoryContext.MethodCall = CreateMethodCall(methodInfo, targetExpression, arguments); + + // If there are filters registered on the route handler, then we update the method call and + // return type associated with the request to allow for the filter invocation pipeline. + if (factoryContext.Filters is { Count: > 0 }) + { + var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext); + Expression>> invokePipeline = (context) => filterPipeline(context); + returnType = typeof(ValueTask); + // var filterContext = new RouteHandlerFilterContext(httpContext, new[] { (object)name_local, (object)int_local }); + // invokePipeline.Invoke(filterContext); + factoryContext.MethodCall = Expression.Block( + new[] { InvokedFilterContextExpr }, + Expression.Assign( + InvokedFilterContextExpr, + Expression.New(RouteHandlerFilterContextConstructor, + new Expression[] { HttpContextExpr, Expression.NewArrayInit(typeof(object), factoryContext.BoxedArgs) })), + Expression.Invoke(invokePipeline, InvokedFilterContextExpr) + ); + } var responseWritingMethodCall = factoryContext.ParamCheckExpressions.Count > 0 ? - CreateParamCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) : - CreateResponseWritingMethodCall(methodInfo, targetExpression, arguments); + CreateParamCheckingResponseWritingMethodCall(returnType, factoryContext) : + AddResponseWritingToMethodCall(factoryContext.MethodCall, returnType); if (factoryContext.UsingTempSourceString) { @@ -189,6 +222,35 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext); } + private static Func> CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) + { + Debug.Assert(factoryContext.Filters is not null); + // httpContext.Response.StatusCode >= 400 + // ? Task.CompletedTask + // : handler((string)context.Parameters[0], (int)context.Parameters[1]) + var filteredInvocation = Expression.Lambda>>( + Expression.Condition( + Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), + CompletedValueTaskExpr, + Expression.Block( + new[] { TargetExpr }, + Expression.Call(WrapObjectAsValueTaskMethod, + target is null + ? Expression.Call(methodInfo, factoryContext.ContextArgAccess) + : Expression.Call(target, methodInfo, factoryContext.ContextArgAccess)) + )), + FilterContextExpr).Compile(); + + for (var i = factoryContext.Filters.Count - 1; i >= 0; i--) + { + var currentFilter = factoryContext.Filters![i]; + var nextFilter = filteredInvocation; + filteredInvocation = (RouteHandlerFilterContext context) => currentFilter.InvokeAsync(context, nextFilter); + + } + return filteredInvocation; + } + private static Expression[] CreateArguments(ParameterInfo[]? parameters, FactoryContext factoryContext) { if (parameters is null || parameters.Length == 0) @@ -201,6 +263,16 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory for (var i = 0; i < parameters.Length; i++) { args[i] = CreateArgument(parameters[i], factoryContext); + // Register expressions containing the boxed and unboxed variants + // of the route handler's arguments for use in RouteHandlerFilterContext + // construction and route handler invocation. + // (string)context.Parameters[0]; + factoryContext.ContextArgAccess.Add( + Expression.Convert( + Expression.Property(FilterContextParametersExpr, "Item", Expression.Constant(i)), + parameters[i].ParameterType)); + // (object)name_local + factoryContext.BoxedArgs.Add(Expression.Convert(args[i], typeof(object))); } if (factoryContext.HasInferredBody && factoryContext.DisableInferredFromBody) @@ -381,16 +453,14 @@ target is null ? Expression.Call(methodInfo, arguments) : Expression.Call(target, methodInfo, arguments); - private static Expression CreateResponseWritingMethodCall(MethodInfo methodInfo, Expression? target, Expression[] arguments) + private static ValueTask WrapObjectAsValueTask(object? obj) { - var callMethod = CreateMethodCall(methodInfo, target, arguments); - return AddResponseWritingToMethodCall(callMethod, methodInfo.ReturnType); + return ValueTask.FromResult(obj); } // If we're calling TryParse or validating parameter optionality and // wasParamCheckFailure indicates it failed, set a 400 StatusCode instead of calling the method. - private static Expression CreateParamCheckingResponseWritingMethodCall( - MethodInfo methodInfo, Expression? target, Expression[] arguments, FactoryContext factoryContext) + private static Expression CreateParamCheckingResponseWritingMethodCall(Type returnType, FactoryContext factoryContext) { // { // string tempSourceString; @@ -440,17 +510,40 @@ private static Expression CreateParamCheckingResponseWritingMethodCall( localVariables[factoryContext.ExtraLocals.Count] = WasParamCheckFailureExpr; - var set400StatusAndReturnCompletedTask = Expression.Block( - Expression.Assign(StatusCodeExpr, Expression.Constant(400)), - CompletedTaskExpr); - - var methodCall = CreateMethodCall(methodInfo, target, arguments); - - var checkWasParamCheckFailure = Expression.Condition(WasParamCheckFailureExpr, - set400StatusAndReturnCompletedTask, - AddResponseWritingToMethodCall(methodCall, methodInfo.ReturnType)); + // If filters have been registered, we set the `wasParamCheckFailure` property + // but do not return from the invocation to allow the filters to run. + if (factoryContext.Filters is { Count: > 0 }) + { + // if (wasParamCheckFailure) + // { + // httpContext.Response.StatusCode = 400; + // } + // return RequestDelegateFactory.ExecuteObjectReturn(invocationPipeline.Invoke(context) as object); + var checkWasParamCheckFailureWithFilters = Expression.Block( + Expression.IfThen( + WasParamCheckFailureExpr, + Expression.Assign(StatusCodeExpr, Expression.Constant(400))), + AddResponseWritingToMethodCall(factoryContext.MethodCall!, returnType) + ); - checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailure; + checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailureWithFilters; + } + else + { + // wasParamCheckFailure ? { + // httpContext.Response.StatusCode = 400; + // return Task.CompletedTask; + // } : { + // return RequestDelegateFactory.ExecuteObjectReturn(invocationPipeline.Invoke(context) as object); + // } + var checkWasParamCheckFailure = Expression.Condition( + WasParamCheckFailureExpr, + Expression.Block( + Expression.Assign(StatusCodeExpr, Expression.Constant(400)), + CompletedTaskExpr), + AddResponseWritingToMethodCall(factoryContext.MethodCall!, returnType)); + checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailure; + } return Expression.Block(localVariables, checkParamAndCallMethod); } @@ -1596,6 +1689,11 @@ private class FactoryContext public bool ReadForm { get; set; } public ParameterInfo? FirstFormRequestBodyParameter { get; set; } + // Properties for constructing and managing filters + public List ContextArgAccess { get; } = new(); + public Expression? MethodCall { get; set; } + public List BoxedArgs { get; } = new(); + public List? Filters { get; init; } } private static class RequestDelegateFactoryConstants diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs index 892cbd2c7efe..870c2a06158e 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs @@ -31,4 +31,9 @@ public sealed class RequestDelegateFactoryOptions /// Prevent the from inferring a parameter should be bound from the request body without an attribute that implements . /// public bool DisableInferBodyFromParameters { get; init; } + + /// + /// The list of filters that must run in the pipeline for a given route handler. + /// + public IReadOnlyList? RouteHandlerFilters { get; init; } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 44f88b26e236..1af56e02660e 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -4200,6 +4200,196 @@ void TestAction(IFormFile file) Assert.Equal(400, badHttpRequestException.StatusCode); } + [Fact] + public async Task RequestDelegateFactory_InvokesFiltersButNotHandler_OnArgumentError() + { + var invoked = false; + // Arrange + string HelloName(string name) + { + invoked = true; + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyStringArgumentFilter() } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + Assert.False(invoked); + Assert.Equal(400, httpContext.Response.StatusCode); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatProvidesCustomErrorMessage() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ProvideCustomErrorMessageFilter() } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var decodedResponseBody = JsonSerializer.Deserialize(responseBodyStream.ToArray()); + Assert.Equal(400, httpContext.Response.StatusCode); + Assert.Equal("New response", decodedResponseBody!.Detail); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeMultipleEndpointFilters_ThatTouchArguments() + { + // Arrange + string HelloName(string name, int age) + { + return $"Hello, {name}! You are {age} years old."; + }; + + var loggerInvoked = 0; + void Log(string arg) => loggerInvoked++; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName", + ["age"] = "25" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyIntArgumentFilter(), new LogArgumentsFilter(Log) } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Hello, TestName! You are 27 years old.", responseBody); + Assert.Equal(2, loggerInvoked); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesBodyParameter() + { + // Arrange + Todo todo = new Todo() { Name = "Write tests", IsComplete = true }; + string PrintTodo(Todo todo) + { + return $"{todo.Name} is {(todo.IsComplete ? "done" : "not done")}."; + }; + + var httpContext = CreateHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo); + var stream = new MemoryStream(requestBodyBytes); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(CultureInfo.InvariantCulture); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(PrintTodo, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyTodoArgumentFilter() } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Write tests is not done.", responseBody); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesResult() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyStringResultFilter() } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("HELLO, TESTNAME!", responseBody); + } + + [Fact] + public async Task RequestDelegateFactory_CanInvokeMultipleEndpointFilters_ThatModifyArgumentsAndResult() + { + // Arrange + string HelloName(string name) + { + return $"Hello, {name}!"; + }; + + var httpContext = CreateHttpContext(); + + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions() + { + RouteHandlerFilters = new List() { new ModifyStringResultFilter(), new ModifyStringArgumentFilter() } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("HELLO, TESTNAMEPREFIX!", responseBody); + } + private DefaultHttpContext CreateHttpContext() { var responseFeature = new TestHttpResponseFeature(); @@ -4559,6 +4749,78 @@ public TlsConnectionFeature(X509Certificate2 clientCertificate) throw new NotImplementedException(); } } + + private class ModifyStringArgumentFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + context.Parameters[0] = context.Parameters[0] != null ? $"{((string)context.Parameters[0]!)}Prefix" : "NULL"; + return await next(context); + } + } + + private class ModifyIntArgumentFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + context.Parameters[1] = ((int)context.Parameters[1]!) + 2; + return await next(context); + } + } + + private class ModifyTodoArgumentFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + Todo originalTodo = (Todo)context.Parameters[0]!; + originalTodo!.IsComplete = !originalTodo.IsComplete; + context.Parameters[0] = originalTodo; + return await next(context); + } + } + + private class ProvideCustomErrorMessageFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + if (context.HttpContext.Response.StatusCode == 400) + { + return Results.Problem("New response", statusCode: 400); + } + return await next(context); + } + } + + private class LogArgumentsFilter : IRouteHandlerFilter + { + private Action _logger; + + public LogArgumentsFilter(Action logger) + { + _logger = logger; + } + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + foreach (var parameter in context.Parameters) + { + _logger(parameter!.ToString() ?? "no arg"); + } + return await next(context); + } + } + + private class ModifyStringResultFilter : IRouteHandlerFilter + { + public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + var previousResult = await next(context); + if (previousResult is string stringResult) + { + return stringResult.ToUpperInvariant(); + } + return previousResult; + } + } } internal static class TestExtensionResults diff --git a/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs new file mode 100644 index 000000000000..9872915904a8 --- /dev/null +++ b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http; + +internal class DelegateRouteHandlerFilter : IRouteHandlerFilter +{ + private readonly Func>, ValueTask> _routeHandlerFilter; + + internal DelegateRouteHandlerFilter(Func>, ValueTask> routeHandlerFilter) + { + _routeHandlerFilter = routeHandlerFilter; + } + + public ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next) + { + return _routeHandlerFilter(context, next); + } +} diff --git a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs index 4c1cb9b09904..6ce2d6c2c7ea 100644 --- a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs @@ -484,18 +484,7 @@ private static RouteHandlerBuilder Map( var routeHandlerOptions = endpoints.ServiceProvider?.GetService>(); - var options = new RequestDelegateFactoryOptions - { - ServiceProvider = endpoints.ServiceProvider, - RouteParameterNames = routeParams, - ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false, - DisableInferBodyFromParameters = disableInferBodyFromParameters, - }; - - var requestDelegateResult = RequestDelegateFactory.Create(handler, options); - var builder = new RouteEndpointBuilder( - requestDelegateResult.RequestDelegate, pattern, defaultOrder) { @@ -518,31 +507,46 @@ private static RouteHandlerBuilder Map( builder.DisplayName = $"{builder.DisplayName} => {endpointName}"; } - // Add delegate attributes as metadata - var attributes = handler.Method.GetCustomAttributes(); - - // Add add request delegate metadata - foreach (var metadata in requestDelegateResult.EndpointMetadata) + var dataSource = endpoints.DataSources.OfType().FirstOrDefault(); + if (dataSource is null) { - builder.Metadata.Add(metadata); + dataSource = new ModelEndpointDataSource(); + endpoints.DataSources.Add(dataSource); } - // This can be null if the delegate is a dynamic method or compiled from an expression tree - if (attributes is not null) + var routeHandlerBuilder = new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder)); + routeHandlerBuilder.Add(endpointBuilder => { - foreach (var attribute in attributes) + var options = new RequestDelegateFactoryOptions { - builder.Metadata.Add(attribute); + ServiceProvider = endpoints.ServiceProvider, + RouteParameterNames = routeParams, + ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false, + DisableInferBodyFromParameters = disableInferBodyFromParameters, + RouteHandlerFilters = routeHandlerBuilder.RouteHandlerFilters + }; + var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options); + // Add request delegate metadata + foreach (var metadata in filteredRequestDelegateResult.EndpointMetadata) + { + endpointBuilder.Metadata.Add(metadata); } - } - var dataSource = endpoints.DataSources.OfType().FirstOrDefault(); - if (dataSource is null) - { - dataSource = new ModelEndpointDataSource(); - endpoints.DataSources.Add(dataSource); - } + // We add attributes on the handler after those automatically generated by the + // RDF since they have a higher specificity. + var attributes = handler.Method.GetCustomAttributes(); + + // This can be null if the delegate is a dynamic method or compiled from an expression tree + if (attributes is not null) + { + foreach (var attribute in attributes) + { + endpointBuilder.Metadata.Add(attribute); + } + } + endpointBuilder.RequestDelegate = filteredRequestDelegateResult.RequestDelegate; + }); - return new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder)); + return routeHandlerBuilder; } } diff --git a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs index 40b896f2db24..b42e22cc3d8d 100644 --- a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs +++ b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Http; + namespace Microsoft.AspNetCore.Builder; /// @@ -11,6 +13,8 @@ public sealed class RouteHandlerBuilder : IEndpointConventionBuilder private readonly IEnumerable? _endpointConventionBuilders; private readonly IEndpointConventionBuilder? _endpointConventionBuilder; + internal List RouteHandlerFilters { get; } = new(); + /// /// Instantiates a new given a single /// . diff --git a/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs new file mode 100644 index 000000000000..ffec088f3e73 --- /dev/null +++ b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using Microsoft.AspNetCore.Builder; + +namespace Microsoft.AspNetCore.Http; + +/// +/// Extension methods for adding to a route handler. +/// +public static class RouteHandlerFilterExtensions +{ + /// + /// Registers a filter onto the route handler. + /// + /// The . + /// The to register. + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, IRouteHandlerFilter filter) + { + builder.RouteHandlerFilters.Add(filter); + return builder; + } + + /// + /// Registers a filter of type onto the route handler. + /// + /// The type of the to register. + /// The . + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TFilterType>(this RouteHandlerBuilder builder) where TFilterType : IRouteHandlerFilter, new() + { + builder.RouteHandlerFilters.Add(new TFilterType()); + return builder; + } + + /// + /// Registers a filter given a delegate onto the route handler. + /// + /// The . + /// A representing the core logic of the filter. + /// A that can be used to further customize the route handler. + public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, Func>, ValueTask> routeHandlerFilter) + { + builder.RouteHandlerFilters.Add(new DelegateRouteHandlerFilter(routeHandlerFilter)); + return builder; + } +} diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 341f2446a1cf..4cf74e9056fd 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1,4 +1,5 @@ #nullable enable +Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token, System.Type! type) -> void Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token) -> void static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! handler) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! @@ -6,3 +7,6 @@ static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this override Microsoft.AspNetCore.Routing.RouteValuesAddress.ToString() -> string? *REMOVED*~Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, Microsoft.AspNetCore.Http.IRouteHandlerFilter! filter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, System.Func>!, System.Threading.Tasks.ValueTask>! routeHandlerFilter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! diff --git a/src/Http/Routing/src/RouteEndpointBuilder.cs b/src/Http/Routing/src/RouteEndpointBuilder.cs index bb27282fa97e..add1f849a4a4 100644 --- a/src/Http/Routing/src/RouteEndpointBuilder.cs +++ b/src/Http/Routing/src/RouteEndpointBuilder.cs @@ -38,6 +38,24 @@ public RouteEndpointBuilder( Order = order; } + /// + /// Constructs a new instance. + /// + /// The to use in URL matching. + /// The order assigned to the endpoint. + /// + /// This constructor allows the to be added to the + /// after construction but before + /// is invoked. + /// + internal RouteEndpointBuilder( + RoutePattern routePattern, + int order) + { + RoutePattern = routePattern; + Order = order; + } + /// public override Endpoint Build() { diff --git a/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs index a76b96a19be5..df21fbb1d7f1 100644 --- a/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs @@ -89,8 +89,9 @@ public async Task MapEndpoint_ReturnGenericTypeTask_GeneratedDelegate() var endpointBuilder = builder.MapGet("/", GenericTypeTaskDelegate); // Assert - var endpointBuilder1 = GetRouteEndpointBuilder(builder); - var requestDelegate = endpointBuilder1.RequestDelegate; + var dataSource = GetBuilderEndpointDataSource(builder); + var endpoint = Assert.Single(dataSource.Endpoints); // Triggers build and construction of delegate + var requestDelegate = endpoint.RequestDelegate; await requestDelegate(httpContext); var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); diff --git a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs index 5636c10cd940..1c3450159801 100644 --- a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs @@ -208,7 +208,9 @@ public async Task MapGetWithoutRouteParameter_BuildsEndpointWithQuerySpecificBin public void MapGet_ThrowsWithImplicitFromBody() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapGet("/", (Todo todo) => { })); + _ = builder.MapGet("/", (Todo todo) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message); Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message); } @@ -217,7 +219,9 @@ public void MapGet_ThrowsWithImplicitFromBody() public void MapDelete_ThrowsWithImplicitFromBody() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapDelete("/", (Todo todo) => { })); + _ = builder.MapDelete("/", (Todo todo) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message); Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message); } @@ -243,7 +247,9 @@ public static object[][] NonImplicitFromBodyMethods public void MapVerb_ThrowsWithImplicitFromBody(string method) { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapMethods("/", new[] { method }, (Todo todo) => { })); + _ = builder.MapMethods("/", new[] { method }, (Todo todo) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message); Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message); } @@ -581,7 +587,9 @@ public async Task MapVerbWithRouteParameterDoesNotFallbackToQuery(Func(() => builder.MapGet("/", ([FromRoute] int id) => { })); + _ = builder.MapGet("/", ([FromRoute] int id) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Equal("'id' is not a route parameter.", ex.Message); } @@ -637,7 +645,9 @@ public async Task MapGetWithNamedFromRouteParameter_FailsForParameterName() public void MapGetWithNamedFromRouteParameter_ThrowsForMismatchedPattern() { var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider())); - var ex = Assert.Throws(() => builder.MapGet("/{id}", ([FromRoute(Name = "value")] int id, HttpContext httpContext) => { })); + _ = builder.MapGet("/{id}", ([FromRoute(Name = "value")] int id, HttpContext httpContext) => { }); + var dataSource = GetBuilderEndpointDataSource(builder); + var ex = Assert.Throws(() => dataSource.Endpoints); Assert.Equal("'value' is not a route parameter.", ex.Message); } @@ -677,7 +687,6 @@ public void MapPost_BuildsEndpointWithCorrectEndpointMetadata() Assert.False(endpointMetadata!.IsOptional); Assert.Equal(typeof(Todo), endpointMetadata.RequestType); Assert.Equal(new[] { "application/xml" }, endpointMetadata.ContentTypes); - } [Fact]