diff --git a/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs
new file mode 100644
index 000000000000..4d3e583eaa5d
--- /dev/null
+++ b/src/Http/Http.Abstractions/src/IRouteHandlerFilter.cs
@@ -0,0 +1,20 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.AspNetCore.Http;
+
+///
+/// Provides an interface for implementing a filter targetting a route handler.
+///
+public interface IRouteHandlerFilter
+{
+ ///
+ /// Implements the core logic associated with the filter given a
+ /// and the next filter to call in the pipeline.
+ ///
+ /// The associated with the current request/response.
+ /// The next filter in the pipeline.
+ /// An awaitable result of calling the handler and apply
+ /// any modifications made by filters in the pipeline.
+ ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next);
+}
diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
index b7318f416a16..bee1f68cba6d 100644
--- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
@@ -1,6 +1,8 @@
#nullable enable
*REMOVED*abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string!
Microsoft.AspNetCore.Http.EndpointMetadataCollection.GetRequiredMetadata() -> T!
+Microsoft.AspNetCore.Http.RouteHandlerFilterContext.RouteHandlerFilterContext(Microsoft.AspNetCore.Http.HttpContext! httpContext, params object![]! parameters) -> void
+Microsoft.AspNetCore.Http.IRouteHandlerFilter.InvokeAsync(Microsoft.AspNetCore.Http.RouteHandlerFilterContext! context, System.Func>! next) -> System.Threading.Tasks.ValueTask
Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata
Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata.Name.get -> string?
Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(Microsoft.AspNetCore.Routing.RouteValueDictionary? dictionary) -> void
@@ -8,3 +10,7 @@ Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Co
Microsoft.AspNetCore.Routing.RouteValueDictionary.RouteValueDictionary(System.Collections.Generic.IEnumerable>? values) -> void
abstract Microsoft.AspNetCore.Http.HttpResponse.ContentType.get -> string?
Microsoft.AspNetCore.Http.Metadata.ISkipStatusCodePagesMetadata
+Microsoft.AspNetCore.Http.RouteHandlerFilterContext
+Microsoft.AspNetCore.Http.RouteHandlerFilterContext.HttpContext.get -> Microsoft.AspNetCore.Http.HttpContext!
+Microsoft.AspNetCore.Http.RouteHandlerFilterContext.Parameters.get -> System.Collections.Generic.IList!
+Microsoft.AspNetCore.Http.IRouteHandlerFilter
diff --git a/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs b/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs
new file mode 100644
index 000000000000..558d97cbd06b
--- /dev/null
+++ b/src/Http/Http.Abstractions/src/RouteHandlerFilterContext.cs
@@ -0,0 +1,35 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.AspNetCore.Http;
+
+///
+/// Provides an abstraction for wrapping the and parameters
+/// provided to a route handler.
+///
+public class RouteHandlerFilterContext
+{
+ ///
+ /// Creates a new instance of the for a given request.
+ ///
+ /// The associated with the current request.
+ /// A list of parameters provided in the current request.
+ public RouteHandlerFilterContext(HttpContext httpContext, params object[] parameters)
+ {
+ HttpContext = httpContext;
+ Parameters = parameters;
+ }
+
+ ///
+ /// The associated with the current request being processed by the filter.
+ ///
+ public HttpContext HttpContext { get; }
+
+ ///
+ /// A list of parameters provided in the current request to the filter.
+ ///
+ /// This list is not read-only to premit modifying of existing parameters by filters.
+ ///
+ ///
+ public IList Parameters { get; }
+}
diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt
index 1030d0f0793e..1d4c624f9113 100644
--- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt
@@ -1,3 +1,5 @@
#nullable enable
Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions
static Microsoft.Extensions.DependencyInjection.RouteHandlerJsonServiceExtensions.ConfigureRouteHandlerJsonOptions(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action! configureOptions) -> Microsoft.Extensions.DependencyInjection.IServiceCollection!
+Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.get -> System.Collections.Generic.IReadOnlyList?
+Microsoft.AspNetCore.Http.RequestDelegateFactoryOptions.RouteHandlerFilters.init -> void
diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs
index 954e8c94a082..5b79735f356c 100644
--- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs
+++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs
@@ -39,6 +39,7 @@ public static partial class RequestDelegateFactory
private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!;
+ private static readonly MethodInfo WrapObjectAsValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(WrapObjectAsValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
// Call WriteAsJsonAsync() to serialize the runtime return type rather than the declared return type.
// https://docs.microsoft.com/en-us/dotnet/standard/serialization/system-text-json-polymorphism
@@ -71,12 +72,21 @@ public static partial class RequestDelegateFactory
private static readonly MemberExpression FormFilesExpr = Expression.Property(FormExpr, typeof(IFormCollection).GetProperty(nameof(IFormCollection.Files))!);
private static readonly MemberExpression StatusCodeExpr = Expression.Property(HttpResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!);
private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo>(() => Task.CompletedTask));
+ private static readonly NewExpression CompletedValueTaskExpr = Expression.New(typeof(ValueTask).GetConstructor(new[] { typeof(Task) })!, CompletedTaskExpr);
private static readonly ParameterExpression TempSourceStringExpr = ParameterBindingMethodCache.TempSourceStringExpr;
private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null));
private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null));
private static readonly UnaryExpression TempSourceStringIsNotNullOrEmptyExpr = Expression.Not(Expression.Call(StringIsNullOrEmptyMethod, TempSourceStringExpr));
+ private static readonly ConstructorInfo RouteHandlerFilterContextConstructor = typeof(RouteHandlerFilterContext).GetConstructor(new[] { typeof(HttpContext), typeof(object[]) })!;
+ private static readonly ParameterExpression FilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "context");
+ private static readonly MemberExpression FilterContextParametersExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.Parameters))!);
+ private static readonly MemberExpression FilterContextHttpContextExpr = Expression.Property(FilterContextExpr, typeof(RouteHandlerFilterContext).GetProperty(nameof(RouteHandlerFilterContext.HttpContext))!);
+ private static readonly MemberExpression FilterContextHttpContextResponseExpr = Expression.Property(FilterContextHttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Response))!);
+ private static readonly MemberExpression FilterContextHttpContextStatusCodeExpr = Expression.Property(FilterContextHttpContextResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!);
+ private static readonly ParameterExpression InvokedFilterContextExpr = Expression.Parameter(typeof(RouteHandlerFilterContext), "filterContext");
+
private static readonly string[] DefaultAcceptsContentType = new[] { "application/json" };
private static readonly string[] FormFileContentType = new[] { "multipart/form-data" };
@@ -102,6 +112,7 @@ public static RequestDelegateResult Create(Delegate handler, RequestDelegateFact
};
var factoryContext = CreateFactoryContext(options);
+
var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext);
return new RequestDelegateResult(httpContext => targetableRequestDelegate(handler.Target, httpContext), factoryContext.Metadata);
@@ -155,6 +166,7 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions
RouteParameters = options?.RouteParameterNames?.ToList(),
ThrowOnBadRequest = options?.ThrowOnBadRequest ?? false,
DisableInferredFromBody = options?.DisableInferBodyFromParameters ?? false,
+ Filters = options?.RouteHandlerFilters?.ToList()
};
private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext)
@@ -176,10 +188,31 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions
// }
var arguments = CreateArguments(methodInfo.GetParameters(), factoryContext);
+ var returnType = methodInfo.ReturnType;
+ factoryContext.MethodCall = CreateMethodCall(methodInfo, targetExpression, arguments);
+
+ // If there are filters registered on the route handler, then we update the method call and
+ // return type associated with the request to allow for the filter invocation pipeline.
+ if (factoryContext.Filters is { Count: > 0 })
+ {
+ var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext);
+ Expression>> invokePipeline = (context) => filterPipeline(context);
+ returnType = typeof(ValueTask);
+ // var filterContext = new RouteHandlerFilterContext(httpContext, new[] { (object)name_local, (object)int_local });
+ // invokePipeline.Invoke(filterContext);
+ factoryContext.MethodCall = Expression.Block(
+ new[] { InvokedFilterContextExpr },
+ Expression.Assign(
+ InvokedFilterContextExpr,
+ Expression.New(RouteHandlerFilterContextConstructor,
+ new Expression[] { HttpContextExpr, Expression.NewArrayInit(typeof(object), factoryContext.BoxedArgs) })),
+ Expression.Invoke(invokePipeline, InvokedFilterContextExpr)
+ );
+ }
var responseWritingMethodCall = factoryContext.ParamCheckExpressions.Count > 0 ?
- CreateParamCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) :
- CreateResponseWritingMethodCall(methodInfo, targetExpression, arguments);
+ CreateParamCheckingResponseWritingMethodCall(returnType, factoryContext) :
+ AddResponseWritingToMethodCall(factoryContext.MethodCall, returnType);
if (factoryContext.UsingTempSourceString)
{
@@ -189,6 +222,35 @@ private static FactoryContext CreateFactoryContext(RequestDelegateFactoryOptions
return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext);
}
+ private static Func> CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext)
+ {
+ Debug.Assert(factoryContext.Filters is not null);
+ // httpContext.Response.StatusCode >= 400
+ // ? Task.CompletedTask
+ // : handler((string)context.Parameters[0], (int)context.Parameters[1])
+ var filteredInvocation = Expression.Lambda>>(
+ Expression.Condition(
+ Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)),
+ CompletedValueTaskExpr,
+ Expression.Block(
+ new[] { TargetExpr },
+ Expression.Call(WrapObjectAsValueTaskMethod,
+ target is null
+ ? Expression.Call(methodInfo, factoryContext.ContextArgAccess)
+ : Expression.Call(target, methodInfo, factoryContext.ContextArgAccess))
+ )),
+ FilterContextExpr).Compile();
+
+ for (var i = factoryContext.Filters.Count - 1; i >= 0; i--)
+ {
+ var currentFilter = factoryContext.Filters![i];
+ var nextFilter = filteredInvocation;
+ filteredInvocation = (RouteHandlerFilterContext context) => currentFilter.InvokeAsync(context, nextFilter);
+
+ }
+ return filteredInvocation;
+ }
+
private static Expression[] CreateArguments(ParameterInfo[]? parameters, FactoryContext factoryContext)
{
if (parameters is null || parameters.Length == 0)
@@ -201,6 +263,16 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory
for (var i = 0; i < parameters.Length; i++)
{
args[i] = CreateArgument(parameters[i], factoryContext);
+ // Register expressions containing the boxed and unboxed variants
+ // of the route handler's arguments for use in RouteHandlerFilterContext
+ // construction and route handler invocation.
+ // (string)context.Parameters[0];
+ factoryContext.ContextArgAccess.Add(
+ Expression.Convert(
+ Expression.Property(FilterContextParametersExpr, "Item", Expression.Constant(i)),
+ parameters[i].ParameterType));
+ // (object)name_local
+ factoryContext.BoxedArgs.Add(Expression.Convert(args[i], typeof(object)));
}
if (factoryContext.HasInferredBody && factoryContext.DisableInferredFromBody)
@@ -381,16 +453,14 @@ target is null ?
Expression.Call(methodInfo, arguments) :
Expression.Call(target, methodInfo, arguments);
- private static Expression CreateResponseWritingMethodCall(MethodInfo methodInfo, Expression? target, Expression[] arguments)
+ private static ValueTask WrapObjectAsValueTask(object? obj)
{
- var callMethod = CreateMethodCall(methodInfo, target, arguments);
- return AddResponseWritingToMethodCall(callMethod, methodInfo.ReturnType);
+ return ValueTask.FromResult(obj);
}
// If we're calling TryParse or validating parameter optionality and
// wasParamCheckFailure indicates it failed, set a 400 StatusCode instead of calling the method.
- private static Expression CreateParamCheckingResponseWritingMethodCall(
- MethodInfo methodInfo, Expression? target, Expression[] arguments, FactoryContext factoryContext)
+ private static Expression CreateParamCheckingResponseWritingMethodCall(Type returnType, FactoryContext factoryContext)
{
// {
// string tempSourceString;
@@ -440,17 +510,40 @@ private static Expression CreateParamCheckingResponseWritingMethodCall(
localVariables[factoryContext.ExtraLocals.Count] = WasParamCheckFailureExpr;
- var set400StatusAndReturnCompletedTask = Expression.Block(
- Expression.Assign(StatusCodeExpr, Expression.Constant(400)),
- CompletedTaskExpr);
-
- var methodCall = CreateMethodCall(methodInfo, target, arguments);
-
- var checkWasParamCheckFailure = Expression.Condition(WasParamCheckFailureExpr,
- set400StatusAndReturnCompletedTask,
- AddResponseWritingToMethodCall(methodCall, methodInfo.ReturnType));
+ // If filters have been registered, we set the `wasParamCheckFailure` property
+ // but do not return from the invocation to allow the filters to run.
+ if (factoryContext.Filters is { Count: > 0 })
+ {
+ // if (wasParamCheckFailure)
+ // {
+ // httpContext.Response.StatusCode = 400;
+ // }
+ // return RequestDelegateFactory.ExecuteObjectReturn(invocationPipeline.Invoke(context) as object);
+ var checkWasParamCheckFailureWithFilters = Expression.Block(
+ Expression.IfThen(
+ WasParamCheckFailureExpr,
+ Expression.Assign(StatusCodeExpr, Expression.Constant(400))),
+ AddResponseWritingToMethodCall(factoryContext.MethodCall!, returnType)
+ );
- checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailure;
+ checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailureWithFilters;
+ }
+ else
+ {
+ // wasParamCheckFailure ? {
+ // httpContext.Response.StatusCode = 400;
+ // return Task.CompletedTask;
+ // } : {
+ // return RequestDelegateFactory.ExecuteObjectReturn(invocationPipeline.Invoke(context) as object);
+ // }
+ var checkWasParamCheckFailure = Expression.Condition(
+ WasParamCheckFailureExpr,
+ Expression.Block(
+ Expression.Assign(StatusCodeExpr, Expression.Constant(400)),
+ CompletedTaskExpr),
+ AddResponseWritingToMethodCall(factoryContext.MethodCall!, returnType));
+ checkParamAndCallMethod[factoryContext.ParamCheckExpressions.Count] = checkWasParamCheckFailure;
+ }
return Expression.Block(localVariables, checkParamAndCallMethod);
}
@@ -1596,6 +1689,11 @@ private class FactoryContext
public bool ReadForm { get; set; }
public ParameterInfo? FirstFormRequestBodyParameter { get; set; }
+ // Properties for constructing and managing filters
+ public List ContextArgAccess { get; } = new();
+ public Expression? MethodCall { get; set; }
+ public List BoxedArgs { get; } = new();
+ public List? Filters { get; init; }
}
private static class RequestDelegateFactoryConstants
diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs
index 892cbd2c7efe..870c2a06158e 100644
--- a/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs
+++ b/src/Http/Http.Extensions/src/RequestDelegateFactoryOptions.cs
@@ -31,4 +31,9 @@ public sealed class RequestDelegateFactoryOptions
/// Prevent the from inferring a parameter should be bound from the request body without an attribute that implements .
///
public bool DisableInferBodyFromParameters { get; init; }
+
+ ///
+ /// The list of filters that must run in the pipeline for a given route handler.
+ ///
+ public IReadOnlyList? RouteHandlerFilters { get; init; }
}
diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
index 44f88b26e236..1af56e02660e 100644
--- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
+++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs
@@ -4200,6 +4200,196 @@ void TestAction(IFormFile file)
Assert.Equal(400, badHttpRequestException.StatusCode);
}
+ [Fact]
+ public async Task RequestDelegateFactory_InvokesFiltersButNotHandler_OnArgumentError()
+ {
+ var invoked = false;
+ // Arrange
+ string HelloName(string name)
+ {
+ invoked = true;
+ return $"Hello, {name}!";
+ };
+
+ var httpContext = CreateHttpContext();
+
+ // Act
+ var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions()
+ {
+ RouteHandlerFilters = new List() { new ModifyStringArgumentFilter() }
+ });
+ var requestDelegate = factoryResult.RequestDelegate;
+ await requestDelegate(httpContext);
+
+ // Assert
+ Assert.False(invoked);
+ Assert.Equal(400, httpContext.Response.StatusCode);
+ }
+
+ [Fact]
+ public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatProvidesCustomErrorMessage()
+ {
+ // Arrange
+ string HelloName(string name)
+ {
+ return $"Hello, {name}!";
+ };
+
+ var httpContext = CreateHttpContext();
+
+ var responseBodyStream = new MemoryStream();
+ httpContext.Response.Body = responseBodyStream;
+
+ // Act
+ var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions()
+ {
+ RouteHandlerFilters = new List() { new ProvideCustomErrorMessageFilter() }
+ });
+ var requestDelegate = factoryResult.RequestDelegate;
+ await requestDelegate(httpContext);
+
+ // Assert
+ var decodedResponseBody = JsonSerializer.Deserialize(responseBodyStream.ToArray());
+ Assert.Equal(400, httpContext.Response.StatusCode);
+ Assert.Equal("New response", decodedResponseBody!.Detail);
+ }
+
+ [Fact]
+ public async Task RequestDelegateFactory_CanInvokeMultipleEndpointFilters_ThatTouchArguments()
+ {
+ // Arrange
+ string HelloName(string name, int age)
+ {
+ return $"Hello, {name}! You are {age} years old.";
+ };
+
+ var loggerInvoked = 0;
+ void Log(string arg) => loggerInvoked++;
+
+ var httpContext = CreateHttpContext();
+
+ var responseBodyStream = new MemoryStream();
+ httpContext.Response.Body = responseBodyStream;
+
+ httpContext.Request.Query = new QueryCollection(new Dictionary
+ {
+ ["name"] = "TestName",
+ ["age"] = "25"
+ });
+
+ // Act
+ var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions()
+ {
+ RouteHandlerFilters = new List() { new ModifyIntArgumentFilter(), new LogArgumentsFilter(Log) }
+ });
+ var requestDelegate = factoryResult.RequestDelegate;
+ await requestDelegate(httpContext);
+
+ // Assert
+ var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());
+ Assert.Equal("Hello, TestName! You are 27 years old.", responseBody);
+ Assert.Equal(2, loggerInvoked);
+ }
+
+ [Fact]
+ public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesBodyParameter()
+ {
+ // Arrange
+ Todo todo = new Todo() { Name = "Write tests", IsComplete = true };
+ string PrintTodo(Todo todo)
+ {
+ return $"{todo.Name} is {(todo.IsComplete ? "done" : "not done")}.";
+ };
+
+ var httpContext = CreateHttpContext();
+
+ var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(todo);
+ var stream = new MemoryStream(requestBodyBytes);
+ httpContext.Request.Body = stream;
+ httpContext.Request.Headers["Content-Type"] = "application/json";
+ httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(CultureInfo.InvariantCulture);
+ httpContext.Features.Set(new RequestBodyDetectionFeature(true));
+
+ var responseBodyStream = new MemoryStream();
+ httpContext.Response.Body = responseBodyStream;
+
+ // Act
+ var factoryResult = RequestDelegateFactory.Create(PrintTodo, new RequestDelegateFactoryOptions()
+ {
+ RouteHandlerFilters = new List() { new ModifyTodoArgumentFilter() }
+ });
+ var requestDelegate = factoryResult.RequestDelegate;
+ await requestDelegate(httpContext);
+
+ // Assert
+ var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());
+ Assert.Equal("Write tests is not done.", responseBody);
+ }
+
+ [Fact]
+ public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatModifiesResult()
+ {
+ // Arrange
+ string HelloName(string name)
+ {
+ return $"Hello, {name}!";
+ };
+
+ var httpContext = CreateHttpContext();
+
+ var responseBodyStream = new MemoryStream();
+ httpContext.Response.Body = responseBodyStream;
+
+ httpContext.Request.Query = new QueryCollection(new Dictionary
+ {
+ ["name"] = "TestName"
+ });
+
+ // Act
+ var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions()
+ {
+ RouteHandlerFilters = new List() { new ModifyStringResultFilter() }
+ });
+ var requestDelegate = factoryResult.RequestDelegate;
+ await requestDelegate(httpContext);
+
+ // Assert
+ var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());
+ Assert.Equal("HELLO, TESTNAME!", responseBody);
+ }
+
+ [Fact]
+ public async Task RequestDelegateFactory_CanInvokeMultipleEndpointFilters_ThatModifyArgumentsAndResult()
+ {
+ // Arrange
+ string HelloName(string name)
+ {
+ return $"Hello, {name}!";
+ };
+
+ var httpContext = CreateHttpContext();
+
+ var responseBodyStream = new MemoryStream();
+ httpContext.Response.Body = responseBodyStream;
+
+ httpContext.Request.Query = new QueryCollection(new Dictionary
+ {
+ ["name"] = "TestName"
+ });
+
+ // Act
+ var factoryResult = RequestDelegateFactory.Create(HelloName, new RequestDelegateFactoryOptions()
+ {
+ RouteHandlerFilters = new List() { new ModifyStringResultFilter(), new ModifyStringArgumentFilter() }
+ });
+ var requestDelegate = factoryResult.RequestDelegate;
+ await requestDelegate(httpContext);
+
+ // Assert
+ var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());
+ Assert.Equal("HELLO, TESTNAMEPREFIX!", responseBody);
+ }
+
private DefaultHttpContext CreateHttpContext()
{
var responseFeature = new TestHttpResponseFeature();
@@ -4559,6 +4749,78 @@ public TlsConnectionFeature(X509Certificate2 clientCertificate)
throw new NotImplementedException();
}
}
+
+ private class ModifyStringArgumentFilter : IRouteHandlerFilter
+ {
+ public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ context.Parameters[0] = context.Parameters[0] != null ? $"{((string)context.Parameters[0]!)}Prefix" : "NULL";
+ return await next(context);
+ }
+ }
+
+ private class ModifyIntArgumentFilter : IRouteHandlerFilter
+ {
+ public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ context.Parameters[1] = ((int)context.Parameters[1]!) + 2;
+ return await next(context);
+ }
+ }
+
+ private class ModifyTodoArgumentFilter : IRouteHandlerFilter
+ {
+ public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ Todo originalTodo = (Todo)context.Parameters[0]!;
+ originalTodo!.IsComplete = !originalTodo.IsComplete;
+ context.Parameters[0] = originalTodo;
+ return await next(context);
+ }
+ }
+
+ private class ProvideCustomErrorMessageFilter : IRouteHandlerFilter
+ {
+ public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ if (context.HttpContext.Response.StatusCode == 400)
+ {
+ return Results.Problem("New response", statusCode: 400);
+ }
+ return await next(context);
+ }
+ }
+
+ private class LogArgumentsFilter : IRouteHandlerFilter
+ {
+ private Action _logger;
+
+ public LogArgumentsFilter(Action logger)
+ {
+ _logger = logger;
+ }
+ public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ foreach (var parameter in context.Parameters)
+ {
+ _logger(parameter!.ToString() ?? "no arg");
+ }
+ return await next(context);
+ }
+ }
+
+ private class ModifyStringResultFilter : IRouteHandlerFilter
+ {
+ public async ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ var previousResult = await next(context);
+ if (previousResult is string stringResult)
+ {
+ return stringResult.ToUpperInvariant();
+ }
+ return previousResult;
+ }
+ }
}
internal static class TestExtensionResults
diff --git a/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs
new file mode 100644
index 000000000000..9872915904a8
--- /dev/null
+++ b/src/Http/Routing/src/Builder/DelegateRouteHandlerFilter.cs
@@ -0,0 +1,19 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.AspNetCore.Http;
+
+internal class DelegateRouteHandlerFilter : IRouteHandlerFilter
+{
+ private readonly Func>, ValueTask> _routeHandlerFilter;
+
+ internal DelegateRouteHandlerFilter(Func>, ValueTask> routeHandlerFilter)
+ {
+ _routeHandlerFilter = routeHandlerFilter;
+ }
+
+ public ValueTask InvokeAsync(RouteHandlerFilterContext context, Func> next)
+ {
+ return _routeHandlerFilter(context, next);
+ }
+}
diff --git a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs
index 4c1cb9b09904..6ce2d6c2c7ea 100644
--- a/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs
+++ b/src/Http/Routing/src/Builder/EndpointRouteBuilderExtensions.cs
@@ -484,18 +484,7 @@ private static RouteHandlerBuilder Map(
var routeHandlerOptions = endpoints.ServiceProvider?.GetService>();
- var options = new RequestDelegateFactoryOptions
- {
- ServiceProvider = endpoints.ServiceProvider,
- RouteParameterNames = routeParams,
- ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false,
- DisableInferBodyFromParameters = disableInferBodyFromParameters,
- };
-
- var requestDelegateResult = RequestDelegateFactory.Create(handler, options);
-
var builder = new RouteEndpointBuilder(
- requestDelegateResult.RequestDelegate,
pattern,
defaultOrder)
{
@@ -518,31 +507,46 @@ private static RouteHandlerBuilder Map(
builder.DisplayName = $"{builder.DisplayName} => {endpointName}";
}
- // Add delegate attributes as metadata
- var attributes = handler.Method.GetCustomAttributes();
-
- // Add add request delegate metadata
- foreach (var metadata in requestDelegateResult.EndpointMetadata)
+ var dataSource = endpoints.DataSources.OfType().FirstOrDefault();
+ if (dataSource is null)
{
- builder.Metadata.Add(metadata);
+ dataSource = new ModelEndpointDataSource();
+ endpoints.DataSources.Add(dataSource);
}
- // This can be null if the delegate is a dynamic method or compiled from an expression tree
- if (attributes is not null)
+ var routeHandlerBuilder = new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder));
+ routeHandlerBuilder.Add(endpointBuilder =>
{
- foreach (var attribute in attributes)
+ var options = new RequestDelegateFactoryOptions
{
- builder.Metadata.Add(attribute);
+ ServiceProvider = endpoints.ServiceProvider,
+ RouteParameterNames = routeParams,
+ ThrowOnBadRequest = routeHandlerOptions?.Value.ThrowOnBadRequest ?? false,
+ DisableInferBodyFromParameters = disableInferBodyFromParameters,
+ RouteHandlerFilters = routeHandlerBuilder.RouteHandlerFilters
+ };
+ var filteredRequestDelegateResult = RequestDelegateFactory.Create(handler, options);
+ // Add request delegate metadata
+ foreach (var metadata in filteredRequestDelegateResult.EndpointMetadata)
+ {
+ endpointBuilder.Metadata.Add(metadata);
}
- }
- var dataSource = endpoints.DataSources.OfType().FirstOrDefault();
- if (dataSource is null)
- {
- dataSource = new ModelEndpointDataSource();
- endpoints.DataSources.Add(dataSource);
- }
+ // We add attributes on the handler after those automatically generated by the
+ // RDF since they have a higher specificity.
+ var attributes = handler.Method.GetCustomAttributes();
+
+ // This can be null if the delegate is a dynamic method or compiled from an expression tree
+ if (attributes is not null)
+ {
+ foreach (var attribute in attributes)
+ {
+ endpointBuilder.Metadata.Add(attribute);
+ }
+ }
+ endpointBuilder.RequestDelegate = filteredRequestDelegateResult.RequestDelegate;
+ });
- return new RouteHandlerBuilder(dataSource.AddEndpointBuilder(builder));
+ return routeHandlerBuilder;
}
}
diff --git a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs
index 40b896f2db24..b42e22cc3d8d 100644
--- a/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs
+++ b/src/Http/Routing/src/Builder/RouteHandlerBuilder.cs
@@ -1,6 +1,8 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using Microsoft.AspNetCore.Http;
+
namespace Microsoft.AspNetCore.Builder;
///
@@ -11,6 +13,8 @@ public sealed class RouteHandlerBuilder : IEndpointConventionBuilder
private readonly IEnumerable? _endpointConventionBuilders;
private readonly IEndpointConventionBuilder? _endpointConventionBuilder;
+ internal List RouteHandlerFilters { get; } = new();
+
///
/// Instantiates a new given a single
/// .
diff --git a/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs
new file mode 100644
index 000000000000..ffec088f3e73
--- /dev/null
+++ b/src/Http/Routing/src/Builder/RouteHandlerFilterExtensions.cs
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Diagnostics.CodeAnalysis;
+using Microsoft.AspNetCore.Builder;
+
+namespace Microsoft.AspNetCore.Http;
+
+///
+/// Extension methods for adding to a route handler.
+///
+public static class RouteHandlerFilterExtensions
+{
+ ///
+ /// Registers a filter onto the route handler.
+ ///
+ /// The .
+ /// The to register.
+ /// A that can be used to further customize the route handler.
+ public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, IRouteHandlerFilter filter)
+ {
+ builder.RouteHandlerFilters.Add(filter);
+ return builder;
+ }
+
+ ///
+ /// Registers a filter of type onto the route handler.
+ ///
+ /// The type of the to register.
+ /// The .
+ /// A that can be used to further customize the route handler.
+ public static RouteHandlerBuilder AddFilter<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TFilterType>(this RouteHandlerBuilder builder) where TFilterType : IRouteHandlerFilter, new()
+ {
+ builder.RouteHandlerFilters.Add(new TFilterType());
+ return builder;
+ }
+
+ ///
+ /// Registers a filter given a delegate onto the route handler.
+ ///
+ /// The .
+ /// A representing the core logic of the filter.
+ /// A that can be used to further customize the route handler.
+ public static RouteHandlerBuilder AddFilter(this RouteHandlerBuilder builder, Func>, ValueTask> routeHandlerFilter)
+ {
+ builder.RouteHandlerFilters.Add(new DelegateRouteHandlerFilter(routeHandlerFilter));
+ return builder;
+ }
+}
diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt
index 341f2446a1cf..4cf74e9056fd 100644
--- a/src/Http/Routing/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt
@@ -1,4 +1,5 @@
#nullable enable
+Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions
Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token, System.Type! type) -> void
Microsoft.AspNetCore.Routing.RouteOptions.SetParameterPolicy(string! token) -> void
static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! handler) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder!
@@ -6,3 +7,6 @@ static Microsoft.AspNetCore.Builder.EndpointRouteBuilderExtensions.MapPatch(this
override Microsoft.AspNetCore.Routing.RouteValuesAddress.ToString() -> string?
*REMOVED*~Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void
Microsoft.AspNetCore.Routing.DefaultInlineConstraintResolver.DefaultInlineConstraintResolver(Microsoft.Extensions.Options.IOptions! routeOptions, System.IServiceProvider! serviceProvider) -> void
+static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, Microsoft.AspNetCore.Http.IRouteHandlerFilter! filter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder!
+static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder, System.Func>!, System.Threading.Tasks.ValueTask>! routeHandlerFilter) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder!
+static Microsoft.AspNetCore.Http.RouteHandlerFilterExtensions.AddFilter(this Microsoft.AspNetCore.Builder.RouteHandlerBuilder! builder) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder!
diff --git a/src/Http/Routing/src/RouteEndpointBuilder.cs b/src/Http/Routing/src/RouteEndpointBuilder.cs
index bb27282fa97e..add1f849a4a4 100644
--- a/src/Http/Routing/src/RouteEndpointBuilder.cs
+++ b/src/Http/Routing/src/RouteEndpointBuilder.cs
@@ -38,6 +38,24 @@ public RouteEndpointBuilder(
Order = order;
}
+ ///
+ /// Constructs a new instance.
+ ///
+ /// The to use in URL matching.
+ /// The order assigned to the endpoint.
+ ///
+ /// This constructor allows the to be added to the
+ /// after construction but before
+ /// is invoked.
+ ///
+ internal RouteEndpointBuilder(
+ RoutePattern routePattern,
+ int order)
+ {
+ RoutePattern = routePattern;
+ Order = order;
+ }
+
///
public override Endpoint Build()
{
diff --git a/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs
index a76b96a19be5..df21fbb1d7f1 100644
--- a/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs
+++ b/src/Http/Routing/test/UnitTests/Builder/RequestDelegateEndpointRouteBuilderExtensionsTest.cs
@@ -89,8 +89,9 @@ public async Task MapEndpoint_ReturnGenericTypeTask_GeneratedDelegate()
var endpointBuilder = builder.MapGet("/", GenericTypeTaskDelegate);
// Assert
- var endpointBuilder1 = GetRouteEndpointBuilder(builder);
- var requestDelegate = endpointBuilder1.RequestDelegate;
+ var dataSource = GetBuilderEndpointDataSource(builder);
+ var endpoint = Assert.Single(dataSource.Endpoints); // Triggers build and construction of delegate
+ var requestDelegate = endpoint.RequestDelegate;
await requestDelegate(httpContext);
var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray());
diff --git a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs
index 5636c10cd940..1c3450159801 100644
--- a/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs
+++ b/src/Http/Routing/test/UnitTests/Builder/RouteHandlerEndpointRouteBuilderExtensionsTest.cs
@@ -208,7 +208,9 @@ public async Task MapGetWithoutRouteParameter_BuildsEndpointWithQuerySpecificBin
public void MapGet_ThrowsWithImplicitFromBody()
{
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider()));
- var ex = Assert.Throws(() => builder.MapGet("/", (Todo todo) => { }));
+ _ = builder.MapGet("/", (Todo todo) => { });
+ var dataSource = GetBuilderEndpointDataSource(builder);
+ var ex = Assert.Throws(() => dataSource.Endpoints);
Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message);
Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message);
}
@@ -217,7 +219,9 @@ public void MapGet_ThrowsWithImplicitFromBody()
public void MapDelete_ThrowsWithImplicitFromBody()
{
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider()));
- var ex = Assert.Throws(() => builder.MapDelete("/", (Todo todo) => { }));
+ _ = builder.MapDelete("/", (Todo todo) => { });
+ var dataSource = GetBuilderEndpointDataSource(builder);
+ var ex = Assert.Throws(() => dataSource.Endpoints);
Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message);
Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message);
}
@@ -243,7 +247,9 @@ public static object[][] NonImplicitFromBodyMethods
public void MapVerb_ThrowsWithImplicitFromBody(string method)
{
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider()));
- var ex = Assert.Throws(() => builder.MapMethods("/", new[] { method }, (Todo todo) => { }));
+ _ = builder.MapMethods("/", new[] { method }, (Todo todo) => { });
+ var dataSource = GetBuilderEndpointDataSource(builder);
+ var ex = Assert.Throws(() => dataSource.Endpoints);
Assert.Contains("Body was inferred but the method does not allow inferred body parameters.", ex.Message);
Assert.Contains("Did you mean to register the \"Body (Inferred)\" parameter(s) as a Service or apply the [FromServices] or [FromBody] attribute?", ex.Message);
}
@@ -581,7 +587,9 @@ public async Task MapVerbWithRouteParameterDoesNotFallbackToQuery(Func(() => builder.MapGet("/", ([FromRoute] int id) => { }));
+ _ = builder.MapGet("/", ([FromRoute] int id) => { });
+ var dataSource = GetBuilderEndpointDataSource(builder);
+ var ex = Assert.Throws(() => dataSource.Endpoints);
Assert.Equal("'id' is not a route parameter.", ex.Message);
}
@@ -637,7 +645,9 @@ public async Task MapGetWithNamedFromRouteParameter_FailsForParameterName()
public void MapGetWithNamedFromRouteParameter_ThrowsForMismatchedPattern()
{
var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvider()));
- var ex = Assert.Throws(() => builder.MapGet("/{id}", ([FromRoute(Name = "value")] int id, HttpContext httpContext) => { }));
+ _ = builder.MapGet("/{id}", ([FromRoute(Name = "value")] int id, HttpContext httpContext) => { });
+ var dataSource = GetBuilderEndpointDataSource(builder);
+ var ex = Assert.Throws(() => dataSource.Endpoints);
Assert.Equal("'value' is not a route parameter.", ex.Message);
}
@@ -677,7 +687,6 @@ public void MapPost_BuildsEndpointWithCorrectEndpointMetadata()
Assert.False(endpointMetadata!.IsOptional);
Assert.Equal(typeof(Todo), endpointMetadata.RequestType);
Assert.Equal(new[] { "application/xml" }, endpointMetadata.ContentTypes);
-
}
[Fact]