diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index edd832b14886..7b9eff08f65d 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -33,7 +33,9 @@ public static partial class RequestDelegateFactory { private static readonly ParameterBindingMethodCache ParameterBindingMethodCache = new(); - private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteValueTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -47,6 +49,8 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringIsNullOrEmptyMethod = typeof(string).GetMethod(nameof(string.IsNullOrEmpty), BindingFlags.Static | BindingFlags.Public)!; private static readonly MethodInfo WrapObjectAsValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(WrapObjectAsValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo TaskOfTToValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(TaskOfTToValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ValueTaskOfTToValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ValueTaskOfTToValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo PopulateMetadataForParameterMethod = typeof(RequestDelegateFactory).GetMethod(nameof(PopulateMetadataForParameter), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo PopulateMetadataForEndpointMethod = typeof(RequestDelegateFactory).GetMethod(nameof(PopulateMetadataForEndpoint), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -258,24 +262,40 @@ private static RouteHandlerFilterDelegate CreateFilterPipeline(MethodInfo method // httpContext.Response.StatusCode >= 400 // ? Task.CompletedTask // : { - // target = targetFactory(httpContext); - // handler is ((Type)target).MethodName(parameters); - // handler((string)context.Parameters[0], (int)context.Parameters[1]); + // handlerInvocation // } - var filteredInvocation = Expression.Lambda( - Expression.Condition( - Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), - CompletedValueTaskExpr, - Expression.Block( + // To generate the handler invocation, we first create the + // target of the handler provided to the route. + // target = targetFactory(httpContext); + // This target is then used to generate the handler invocation like so; + // ((Type)target).MethodName(parameters); + // When `handler` returns an object, we generate the following wrapper + // to convert it to `ValueTask` as expected in the filter + // pipeline. + // ValueTask.FromResult(handler((string)context.Parameters[0], (int)context.Parameters[1])); + // When the `handler` is a generic Task or ValueTask we await the task and + // create a `ValueTask from the resulting value. + // new ValueTask(await handler((string)context.Parameters[0], (int)context.Parameters[1])); + // When the `handler` returns a void or a void-returning Task, then we return an EmptyHttpResult + // to as a ValueTask + // } + var handlerReturnMapping = MapHandlerReturnTypeToValueTask( + targetExpression is null + ? Expression.Call(methodInfo, factoryContext.ContextArgAccess) + : Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess), + methodInfo.ReturnType); + var handlerInvocation = Expression.Block( new[] { TargetExpr }, targetFactory == null ? Expression.Empty() : Expression.Assign(TargetExpr, Expression.Invoke(targetFactory, FilterContextHttpContextExpr)), - Expression.Call(WrapObjectAsValueTaskMethod, - targetExpression is null - ? Expression.Call(methodInfo, factoryContext.ContextArgAccess) - : Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess)) - )), + handlerReturnMapping + ); + var filteredInvocation = Expression.Lambda( + Expression.Condition( + Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), + CompletedValueTaskExpr, + handlerInvocation), FilterContextExpr).Compile(); var routeHandlerContext = new RouteHandlerContext( methodInfo, @@ -292,6 +312,72 @@ targetExpression is null return filteredInvocation; } + private static Expression MapHandlerReturnTypeToValueTask(Expression methodCall, Type returnType) + { + if (returnType == typeof(void)) + { + return Expression.Block(methodCall, Expression.Constant(new ValueTask(EmptyHttpResult.Instance))); + } + else if (returnType == typeof(Task)) + { + return Expression.Call(ExecuteTaskWithEmptyResultMethod, methodCall); + } + else if (returnType == typeof(ValueTask)) + { + return Expression.Call(ExecuteValueTaskWithEmptyResultMethod, methodCall); + } + else if (returnType == typeof(ValueTask)) + { + return methodCall; + } + else if (returnType.IsGenericType && + returnType.GetGenericTypeDefinition() == typeof(ValueTask<>)) + { + var typeArg = returnType.GetGenericArguments()[0]; + return Expression.Call(ValueTaskOfTToValueTaskOfObjectMethod.MakeGenericMethod(typeArg), methodCall); + } + else if (returnType.IsGenericType && + returnType.GetGenericTypeDefinition() == typeof(Task<>)) + { + var typeArg = returnType.GetGenericArguments()[0]; + return Expression.Call(TaskOfTToValueTaskOfObjectMethod.MakeGenericMethod(typeArg), methodCall); + } + else + { + return Expression.Call(WrapObjectAsValueTaskMethod, methodCall); + } + } + + private static ValueTask ValueTaskOfTToValueTaskOfObject(ValueTask valueTask) + { + static async ValueTask ExecuteAwaited(ValueTask valueTask) + { + return await valueTask; + } + + if (valueTask.IsCompletedSuccessfully) + { + return new ValueTask(valueTask.Result); + } + + return ExecuteAwaited(valueTask); + } + + private static ValueTask TaskOfTToValueTaskOfObject(Task task) + { + static async ValueTask ExecuteAwaited(Task task) + { + return await task; + } + + if (task.IsCompletedSuccessfully) + { + return new ValueTask(task.Result); + } + + return ExecuteAwaited(task); + } + private static void AddTypeProvidedMetadata(MethodInfo methodInfo, List metadata, IServiceProvider? services) { object?[]? invokeArgs = null; @@ -1649,7 +1735,7 @@ private static async Task ExecuteObjectReturn(object? obj, HttpContext httpConte } } - private static Task ExecuteTask(Task task, HttpContext httpContext) + private static Task ExecuteTaskOfT(Task task, HttpContext httpContext) { EnsureRequestTaskNotNull(task); @@ -1707,6 +1793,39 @@ static async Task ExecuteAwaited(ValueTask task) return ExecuteAwaited(task); } + private static ValueTask ExecuteTaskWithEmptyResult(Task task) + { + static async ValueTask ExecuteAwaited(Task task) + { + await task; + return EmptyHttpResult.Instance; + } + + if (task.IsCompletedSuccessfully) + { + return new ValueTask(EmptyHttpResult.Instance); + } + + return ExecuteAwaited(task); + } + + private static ValueTask ExecuteValueTaskWithEmptyResult(ValueTask valueTask) + { + static async ValueTask ExecuteAwaited(ValueTask task) + { + await task; + return EmptyHttpResult.Instance; + } + + if (valueTask.IsCompletedSuccessfully) + { + valueTask.GetAwaiter().GetResult(); + return new ValueTask(EmptyHttpResult.Instance); + } + + return ExecuteAwaited(valueTask); + } + private static Task ExecuteValueTaskOfT(ValueTask task, HttpContext httpContext) { static async Task ExecuteAwaited(ValueTask task, HttpContext httpContext) @@ -2041,4 +2160,22 @@ private static void FormatTrackedParameters(FactoryContext factoryContext, Strin errorMessage.AppendLine(FormattableString.Invariant($"{kv.Key,-19} | {kv.Value,-15}")); } } + + // Due to cyclic references between Http.Extensions and + // Http.Results, we define our own instance of the `EmptyHttpResult` + // type here. + private sealed class EmptyHttpResult : IResult + { + private EmptyHttpResult() + { + } + + public static EmptyHttpResult Instance { get; } = new(); + + /// + public Task ExecuteAsync(HttpContext httpContext) + { + return Task.CompletedTask; + } + } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 727219b43b98..b1134ba7cfb2 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -4698,6 +4698,291 @@ string HelloName(string name) Assert.Equal("HELLO, TESTNAMEPREFIX!", responseBody); } + public static object[][] TaskOfTMethods + { + get + { + Task TaskOfTMethod() + { + return Task.FromResult("foo"); + } + + async Task TaskOfTWithYieldMethod() + { + await Task.Yield(); + return "foo"; + } + + async Task TaskOfObjectWithYieldMethod() + { + await Task.Yield(); + return "foo"; + } + + return new object[][] + { + new object[] { (Func>)TaskOfTMethod }, + new object[] { (Func>)TaskOfTWithYieldMethod }, + new object[] { (Func>)TaskOfObjectWithYieldMethod } + }; + } + } + + [Theory] + [MemberData(nameof(TaskOfTMethods))] + public async Task CanInvokeFilter_OnTaskOfTReturningHandler(Delegate @delegate) + { + // Arrange + var responseBodyStream = new MemoryStream(); + var httpContext = CreateHttpContext(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(@delegate, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("foo", decodedResponseBody); + } + + public static object[][] ValueTaskOfTMethods + { + get + { + ValueTask ValueTaskOfTMethod() + { + return ValueTask.FromResult("foo"); + } + + async ValueTask ValueTaskOfTWithYieldMethod() + { + await Task.Yield(); + return "foo"; + } + + async ValueTask ValueTaskOfObjectWithYield() + { + await Task.Yield(); + return "foo"; + } + + return new object[][] + { + new object[] { (Func>)ValueTaskOfTMethod }, + new object[] { (Func>)ValueTaskOfTWithYieldMethod }, + new object[] { (Func>)ValueTaskOfObjectWithYield } + }; + } + } + + [Theory] + [MemberData(nameof(ValueTaskOfTMethods))] + public async Task CanInvokeFilter_OnValueTaskOfTReturningHandler(Delegate @delegate) + { + // Arrange + var responseBodyStream = new MemoryStream(); + var httpContext = CreateHttpContext(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(@delegate, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("foo", decodedResponseBody); + } + + public static object[][] VoidReturningMethods + { + get + { + void VoidMethod() { } + + ValueTask ValueTaskMethod() + { + return ValueTask.CompletedTask; + } + + Task TaskMethod() + { + return Task.CompletedTask; + } + + async ValueTask ValueTaskWithYieldMethod() + { + await Task.Yield(); + } + + async Task TaskWithYieldMethod() + { + await Task.Yield(); + } + + return new object[][] + { + new object[] { (Action)VoidMethod }, + new object[] { (Func)ValueTaskMethod }, + new object[] { (Func)TaskMethod }, + new object[] { (Func)ValueTaskWithYieldMethod }, + new object[] { (Func)TaskWithYieldMethod} + }; + } + } + + [Theory] + [MemberData(nameof(VoidReturningMethods))] + public async Task CanInvokeFilter_OnVoidReturningHandler(Delegate @delegate) + { + // Arrange + var responseBodyStream = new MemoryStream(); + var httpContext = CreateHttpContext(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(@delegate, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(String.Empty, decodedResponseBody); + } + + [Fact] + public async Task CanInvokeFilter_OnTaskModifyingHttpContext() + { + // Arrange + var tcs = new TaskCompletionSource(); + async Task HandlerWithTaskAwait(HttpContext c) + { + await tcs.Task; + await Task.Yield(); + c.Response.StatusCode = 400; + }; + var responseBodyStream = new MemoryStream(); + var httpContext = CreateHttpContext(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(HandlerWithTaskAwait, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + + var requestDelegate = factoryResult.RequestDelegate; + var request = requestDelegate(httpContext); + tcs.TrySetResult(); + await request; + + Assert.Equal(400, httpContext.Response.StatusCode); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal(string.Empty, decodedResponseBody); + } + + public static object[][] TasksOfTypesMethods + { + get + { + ValueTask ValueTaskOfStructMethod() + { + return ValueTask.FromResult(new TodoStruct { Name = "Test todo"}); + } + + async ValueTask ValueTaskOfStructWithYieldMethod() + { + await Task.Yield(); + return new TodoStruct { Name = "Test todo" }; + } + + Task TaskOfStructMethod() + { + return Task.FromResult(new TodoStruct { Name = "Test todo" }); + } + + async Task TaskOfStructWithYieldMethod() + { + await Task.Yield(); + return new TodoStruct { Name = "Test todo" }; + } + + return new object[][] + { + new object[] { (Func>)ValueTaskOfStructMethod }, + new object[] { (Func>)ValueTaskOfStructWithYieldMethod }, + new object[] { (Func>)TaskOfStructMethod }, + new object[] { (Func>)TaskOfStructWithYieldMethod } + }; + } + } + + [Theory] + [MemberData(nameof(TasksOfTypesMethods))] + public async Task CanInvokeFilter_OnHandlerReturningTasksOfStruct(Delegate @delegate) + { + // Arrange + var responseBodyStream = new MemoryStream(); + var httpContext = CreateHttpContext(); + httpContext.Response.Body = responseBodyStream; + + // Act + var factoryResult = RequestDelegateFactory.Create(@delegate, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + Assert.Equal(200, httpContext.Response.StatusCode); + var deserializedResponseBody = JsonSerializer.Deserialize(responseBodyStream.ToArray(), new JsonSerializerOptions + { + PropertyNameCaseInsensitive = true + }); + Assert.Equal("Test todo", deserializedResponseBody.Name); + } + [Fact] public void Create_AddsDelegateMethodInfo_AsMetadata() {