From 73339fd0495fcad55a6db8112e0658e35494285e Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 19 Aug 2021 15:53:14 -0700 Subject: [PATCH 1/3] Added support for type based parameter binding (#35496) * Added support for type based parameter binding - Added a convention that allows custom async binding logic to run for parameters that have a static BindAsync method that takes an HttpContext and return a ValueTask of object. This allows customers to write custom binders based solely on type (it's an extension of the existing TryParse pattern). - There's allocation overhead per request once there's a parameter binder for a delegate. This is because we need to box all of the arguments since we're not using generated code to compute data from the list of binders. - Changed TryParse tests to BindAsync tests and added more tests. --- .../src/RequestDelegateFactory.cs | 161 +++++++-- .../test/RequestDelegateFactoryTests.cs | 317 +++++++++++++++++- .../test/TryParseMethodCacheTests.cs | 128 ++++++- .../EndpointMetadataApiDescriptionProvider.cs | 1 + ...pointMetadataApiDescriptionProviderTest.cs | 15 + src/Shared/TryParseMethodCache.cs | 29 +- 6 files changed, 616 insertions(+), 35 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index d6d80df8bae2..3f4cd4cc79ec 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -45,6 +45,7 @@ public static partial class RequestDelegateFactory private static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue"); private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure"); + private static readonly ParameterExpression BoundValuesArrayExpr = Expression.Parameter(typeof(object[]), "boundValues"); private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.RequestServices))!); private static readonly MemberExpression HttpRequestExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Request))!); @@ -194,7 +195,6 @@ private static Expression[] CreateArguments(ParameterInfo[]? parameters, Factory { var errorMessage = BuildErrorMessageForMultipleBodyParameters(factoryContext); throw new InvalidOperationException(errorMessage); - } return args; @@ -259,7 +259,11 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext { return RequestAbortedExpr; } - else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseMethod(parameter)) + else if (TryParseMethodCache.HasBindAsyncMethod(parameter)) + { + return BindParameterFromBindAsync(parameter, factoryContext); + } + else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseStringMethod(parameter)) { // 1. We bind from route values only, if route parameters are non-null and the parameter name is in that set. // 2. We bind from query only, if route parameters are non-null and the parameter name is NOT in that set. @@ -267,7 +271,6 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext // when RDF.Create is manually invoked. if (factoryContext.RouteParameters is { } routeParams) { - if (routeParams.Contains(parameter.Name, StringComparer.OrdinalIgnoreCase)) { // We're in the fallback case and we have a parameter and route parameter match so don't fallback @@ -353,7 +356,6 @@ private static Expression CreateParamCheckingResponseWritingMethodCall( var localVariables = new ParameterExpression[factoryContext.ExtraLocals.Count + 1]; var checkParamAndCallMethod = new Expression[factoryContext.ParamCheckExpressions.Count + 1]; - for (var i = 0; i < factoryContext.ExtraLocals.Count; i++) { localVariables[i] = factoryContext.ExtraLocals[i]; @@ -500,14 +502,33 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, { if (factoryContext.JsonRequestBodyType is null) { + if (factoryContext.ParameterBinders.Count > 0) + { + // We need to generate the code for reading from the custom binders calling into the delegate + var continuation = Expression.Lambda>( + responseWritingMethodCall, TargetExpr, HttpContextExpr, BoundValuesArrayExpr).Compile(); + + // Looping over arrays is faster + var binders = factoryContext.ParameterBinders.ToArray(); + var count = binders.Length; + + return async (target, httpContext) => + { + var boundValues = new object?[count]; + + for (var i = 0; i < count; i++) + { + boundValues[i] = await binders[i](httpContext); + } + + await continuation(target, httpContext, boundValues); + }; + } + return Expression.Lambda>( responseWritingMethodCall, TargetExpr, HttpContextExpr).Compile(); } - // We need to generate the code for reading from the body before calling into the delegate - var invoker = Expression.Lambda>( - responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile(); - var bodyType = factoryContext.JsonRequestBodyType; object? defaultBodyValue = null; @@ -516,31 +537,82 @@ private static Expression AddResponseWritingToMethodCall(Expression methodCall, defaultBodyValue = Activator.CreateInstance(bodyType); } - return async (target, httpContext) => + if (factoryContext.ParameterBinders.Count > 0) { - object? bodyValue = defaultBodyValue; - var feature = httpContext.Features.Get(); - if (feature?.CanHaveBody == true) + // We need to generate the code for reading from the body before calling into the delegate + var continuation = Expression.Lambda>( + responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr, BoundValuesArrayExpr).Compile(); + + // Looping over arrays is faster + var binders = factoryContext.ParameterBinders.ToArray(); + var count = binders.Length; + + return async (target, httpContext) => { - try + // Run these first so that they can potentially read and rewind the body + var boundValues = new object?[count]; + + for (var i = 0; i < count; i++) { - bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); + boundValues[i] = await binders[i](httpContext); } - catch (IOException ex) + + var bodyValue = defaultBodyValue; + var feature = httpContext.Features.Get(); + if (feature?.CanHaveBody == true) { - Log.RequestBodyIOException(httpContext, ex); - return; + try + { + bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); + } + catch (IOException ex) + { + Log.RequestBodyIOException(httpContext, ex); + return; + } + catch (InvalidDataException ex) + { + Log.RequestBodyInvalidDataException(httpContext, ex); + httpContext.Response.StatusCode = 400; + return; + } } - catch (InvalidDataException ex) - { - Log.RequestBodyInvalidDataException(httpContext, ex); - httpContext.Response.StatusCode = 400; - return; + await continuation(target, httpContext, bodyValue, boundValues); + }; + } + else + { + // We need to generate the code for reading from the body before calling into the delegate + var continuation = Expression.Lambda>( + responseWritingMethodCall, TargetExpr, HttpContextExpr, BodyValueExpr).Compile(); + + return async (target, httpContext) => + { + var bodyValue = defaultBodyValue; + var feature = httpContext.Features.Get(); + if (feature?.CanHaveBody == true) + { + try + { + bodyValue = await httpContext.Request.ReadFromJsonAsync(bodyType); + } + catch (IOException ex) + { + Log.RequestBodyIOException(httpContext, ex); + return; + } + catch (InvalidDataException ex) + { + + Log.RequestBodyInvalidDataException(httpContext, ex); + httpContext.Response.StatusCode = 400; + return; + } } - } - await invoker(target, httpContext, bodyValue); - }; + await continuation(target, httpContext, bodyValue); + }; + } } private static Expression GetValueFromProperty(Expression sourceExpression, string key) @@ -739,6 +811,42 @@ private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo return BindParameterFromValue(parameter, Expression.Coalesce(routeValue, queryValue), factoryContext); } + private static Expression BindParameterFromBindAsync(ParameterInfo parameter, FactoryContext factoryContext) + { + // We reference the boundValues array by parameter index here + var nullability = NullabilityContext.Create(parameter); + var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable; + + // Get the BindAsync method + var body = TryParseMethodCache.FindBindAsyncMethod(parameter.ParameterType)!; + + // Compile the delegate to the BindAsync method for this parameter index + var bindAsyncDelegate = Expression.Lambda>>(body, HttpContextExpr).Compile(); + factoryContext.ParameterBinders.Add(bindAsyncDelegate); + + // boundValues[index] + var boundValueExpr = Expression.ArrayIndex(BoundValuesArrayExpr, Expression.Constant(factoryContext.ParameterBinders.Count - 1)); + + if (!isOptional) + { + var checkRequiredBodyBlock = Expression.Block( + Expression.IfThen( + Expression.Equal(boundValueExpr, Expression.Constant(null)), + Expression.Block( + Expression.Assign(WasParamCheckFailureExpr, Expression.Constant(true)), + Expression.Call(LogRequiredParameterNotProvidedMethod, + HttpContextExpr, Expression.Constant(parameter.ParameterType.Name), Expression.Constant(parameter.Name)) + ) + ) + ); + + factoryContext.ParamCheckExpressions.Add(checkRequiredBodyBlock); + } + + // (ParamterType)boundValues[i] + return Expression.Convert(boundValueExpr, parameter.ParameterType); + } + private static Expression BindParameterFromBody(ParameterInfo parameter, bool allowEmpty, FactoryContext factoryContext) { if (factoryContext.JsonRequestBodyType is not null) @@ -749,7 +857,6 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al { factoryContext.TrackedParameters.Remove(parameterName); factoryContext.TrackedParameters.Add(parameterName, "UNKNOWN"); - } } @@ -881,7 +988,6 @@ static async Task ExecuteAwaited(Task task, HttpContext httpContext) private static Task ExecuteTaskOfString(Task task, HttpContext httpContext) { - SetPlaintextContentType(httpContext); EnsureRequestTaskNotNull(task); @@ -988,6 +1094,7 @@ private class FactoryContext public bool UsingTempSourceString { get; set; } public List ExtraLocals { get; } = new(); public List ParamCheckExpressions { get; } = new(); + public List>> ParameterBinders { get; } = new(); public Dictionary TrackedParameters { get; } = new(); public bool HasMultipleBodyParameters { get; set; } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index a8bd904420c5..25b5ffe6750e 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -499,6 +499,53 @@ public static bool TryParse(string? value, out MyTryParsableRecord? result) } } + private class MyBindAsyncTypeThatThrows + { + public static ValueTask BindAsync(HttpContext context) + { + throw new InvalidOperationException("BindAsync failed"); + } + } + + private record MyBindAsyncRecord(Uri Uri) + { + public static ValueTask BindAsync(HttpContext context) + { + if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri)) + { + return ValueTask.FromResult(null); + } + + return ValueTask.FromResult(new MyBindAsyncRecord(uri)); + } + + // TryParse(HttpContext, ...) should be preferred over TryParse(string, ...) if there's + // no [FromRoute] or [FromQuery] attributes. + public static bool TryParse(string? value, out MyBindAsyncRecord? result) + { + throw new NotImplementedException(); + } + } + + private record struct MyBindAsyncStruct(Uri Uri) + { + public static ValueTask BindAsync(HttpContext context) + { + if (!Uri.TryCreate(context.Request.Headers.Referer, UriKind.Absolute, out var uri)) + { + return ValueTask.FromResult(null); + } + + return ValueTask.FromResult(new MyBindAsyncStruct(uri)); + } + + // TryParse(HttpContext, ...) should be preferred over TryParse(string, ...) if there's + // no [FromRoute] or [FromQuery] attributes. + public static bool TryParse(string? value, out MyBindAsyncStruct result) => + throw new NotImplementedException(); + } + + [Theory] [MemberData(nameof(TryParsableParameters))] public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromRouteValue(Delegate action, string? routeValue, object? expectedParameterValue) @@ -560,6 +607,84 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR Assert.Equal(42, httpContext.Items["tryParsable"]); } + [Fact] + public async Task RequestDelegatePrefersBindAsyncOverTryParseString() + { + var httpContext = new DefaultHttpContext(); + + httpContext.Request.Headers.Referer = "https://example.org"; + + var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncRecord tryParsable) => + { + httpContext.Items["tryParsable"] = tryParsable; + }); + + await requestDelegate(httpContext); + + Assert.Equal(new MyBindAsyncRecord(new Uri("https://example.org")), httpContext.Items["tryParsable"]); + } + + [Fact] + public async Task RequestDelegatePrefersBindAsyncOverTryParseStringForNonNullableStruct() + { + var httpContext = new DefaultHttpContext(); + + httpContext.Request.Headers.Referer = "https://example.org"; + + var requestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct tryParsable) => + { + httpContext.Items["tryParsable"] = tryParsable; + }); + + await requestDelegate(httpContext); + + Assert.Equal(new MyBindAsyncStruct(new Uri("https://example.org")), httpContext.Items["tryParsable"]); + } + + [Fact] + public async Task RequestDelegateUsesTryParseStringoOverBindAsyncGivenExplicitAttribute() + { + var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromRoute] MyBindAsyncRecord tryParsable) => { }); + var fromQueryRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, [FromQuery] MyBindAsyncRecord tryParsable) => { }); + + var httpContext = new DefaultHttpContext + { + Request = + { + RouteValues = + { + ["tryParsable"] = "foo" + }, + Query = new QueryCollection(new Dictionary + { + ["tryParsable"] = "foo" + }), + }, + }; + + await Assert.ThrowsAsync(() => fromRouteRequestDelegate(httpContext)); + await Assert.ThrowsAsync(() => fromQueryRequestDelegate(httpContext)); + } + + [Fact] + public async Task RequestDelegateUsesTryParseStringOverBindAsyncGivenNullableStruct() + { + var fromRouteRequestDelegate = RequestDelegateFactory.Create((HttpContext httpContext, MyBindAsyncStruct? tryParsable) => { }); + + var httpContext = new DefaultHttpContext + { + Request = + { + RouteValues = + { + ["tryParsable"] = "foo" + }, + }, + }; + + await Assert.ThrowsAsync(() => fromRouteRequestDelegate(httpContext)); + } + public static object[][] DelegatesWithAttributesOnNotTryParsableParameters { get @@ -629,11 +754,169 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2) Assert.Equal(LogLevel.Debug, logs[0].LogLevel); Assert.Equal(@"Failed to bind parameter ""Int32 tryParsable"" from ""invalid!"".", logs[0].Message); - Assert.Equal(new EventId(3, "ParamaterBindingFailed"), logs[0].EventId); - Assert.Equal(LogLevel.Debug, logs[0].LogLevel); + Assert.Equal(new EventId(3, "ParamaterBindingFailed"), logs[1].EventId); + Assert.Equal(LogLevel.Debug, logs[1].LogLevel); Assert.Equal(@"Failed to bind parameter ""Int32 tryParsable2"" from ""invalid again!"".", logs[1].Message); } + [Fact] + public async Task RequestDelegateLogsBindAsyncFailuresAndSets400Response() + { + // Not supplying any headers will cause the HttpContext TryParse overload to fail. + var httpContext = new DefaultHttpContext() + { + RequestServices = new ServiceCollection().AddSingleton(LoggerFactory).BuildServiceProvider(), + }; + + var invoked = false; + + var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncRecord arg1, MyBindAsyncRecord arg2) => + { + invoked = true; + }); + + await requestDelegate(httpContext); + + Assert.False(invoked); + Assert.False(httpContext.RequestAborted.IsCancellationRequested); + Assert.Equal(400, httpContext.Response.StatusCode); + + var logs = TestSink.Writes.ToArray(); + + Assert.Equal(2, logs.Length); + + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[0].EventId); + Assert.Equal(LogLevel.Debug, logs[0].LogLevel); + Assert.Equal(@"Required parameter ""MyBindAsyncRecord arg1"" was not provided.", logs[0].Message); + + Assert.Equal(new EventId(4, "RequiredParameterNotProvided"), logs[1].EventId); + Assert.Equal(LogLevel.Debug, logs[1].LogLevel); + Assert.Equal(@"Required parameter ""MyBindAsyncRecord arg2"" was not provided.", logs[1].Message); + } + + [Fact] + public async Task BindAsyncExceptionsThrowException() + { + // Not supplying any headers will cause the HttpContext TryParse overload to fail. + var httpContext = new DefaultHttpContext() + { + RequestServices = new ServiceCollection().AddSingleton(LoggerFactory).BuildServiceProvider(), + }; + + var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncTypeThatThrows arg1) => { }); + + var ex = await Assert.ThrowsAsync(() => requestDelegate(httpContext)); + Assert.Equal("BindAsync failed", ex.Message); + } + + [Fact] + public async Task BindAsyncWithBodyArgument() + { + Todo originalTodo = new() + { + Name = "Write more tests!" + }; + + var httpContext = new DefaultHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); + var stream = new MemoryStream(requestBodyBytes); ; + httpContext.Request.Body = stream; + + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var mock = new Mock(); + mock.Setup(m => m.GetService(It.IsAny())).Returns(t => + { + if (t == typeof(IOptions)) + { + return Options.Create(jsonOptions); + } + return null; + }); + + httpContext.RequestServices = mock.Object; + httpContext.Request.Headers.Referer = "https://example.org"; + + var invoked = false; + + var requestDelegate = RequestDelegateFactory.Create((HttpContext context, MyBindAsyncRecord arg1, Todo todo) => + { + invoked = true; + context.Items[nameof(arg1)] = arg1; + context.Items[nameof(todo)] = todo; + }); + + await requestDelegate(httpContext); + + Assert.True(invoked); + var arg = httpContext.Items["arg1"] as MyBindAsyncRecord; + Assert.NotNull(arg); + Assert.Equal("https://example.org/", arg!.Uri.ToString()); + var todo = httpContext.Items["todo"] as Todo; + Assert.NotNull(todo); + Assert.Equal("Write more tests!", todo!.Name); + } + + [Fact] + public async Task BindAsyncRunsBeforeBodyBinding() + { + Todo originalTodo = new() + { + Name = "Write more tests!" + }; + + var httpContext = new DefaultHttpContext(); + + var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); + var stream = new MemoryStream(requestBodyBytes); ; + httpContext.Request.Body = stream; + + httpContext.Request.Headers["Content-Type"] = "application/json"; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + + var jsonOptions = new JsonOptions(); + jsonOptions.SerializerOptions.Converters.Add(new TodoJsonConverter()); + + var mock = new Mock(); + mock.Setup(m => m.GetService(It.IsAny())).Returns(t => + { + if (t == typeof(IOptions)) + { + return Options.Create(jsonOptions); + } + return null; + }); + + httpContext.RequestServices = mock.Object; + httpContext.Request.Headers.Referer = "https://example.org"; + + var invoked = false; + + var requestDelegate = RequestDelegateFactory.Create((HttpContext context, CustomTodo customTodo, Todo todo) => + { + invoked = true; + context.Items[nameof(customTodo)] = customTodo; + context.Items[nameof(todo)] = todo; + }); + + await requestDelegate(httpContext); + + Assert.True(invoked); + var todo0 = httpContext.Items["customTodo"] as Todo; + Assert.NotNull(todo0); + Assert.Equal("Write more tests!", todo0!.Name); + var todo1 = httpContext.Items["todo"] as Todo; + Assert.NotNull(todo1); + Assert.Equal("Write more tests!", todo1!.Name); + } + [Fact] public async Task RequestDelegatePopulatesFromQueryParameterBasedOnParameterName() { @@ -1669,6 +1952,26 @@ public async Task RequestDelegateHandlesBodyParamOptionality(Delegate @delegate, } } + [Fact] + public async Task RequestDelegateDoesSupportBindAsyncOptionality() + { + var httpContext = new DefaultHttpContext() + { + RequestServices = new ServiceCollection().AddSingleton(LoggerFactory).BuildServiceProvider(), + }; + + var invoked = false; + + var requestDelegate = RequestDelegateFactory.Create((MyBindAsyncRecord? arg1) => + { + invoked = true; + }); + + await requestDelegate(httpContext); + + Assert.True(invoked); + } + public static IEnumerable ServiceParamOptionalityData { get @@ -1843,6 +2146,16 @@ private class Todo : ITodo public bool IsComplete { get; set; } } + private class CustomTodo : Todo + { + public static async ValueTask BindAsync(HttpContext context) + { + var body = await context.Request.ReadFromJsonAsync(); + context.Request.Body.Position = 0; + return body; + } + } + private record struct TodoStruct(int Id, string? Name, bool IsComplete) : ITodo; private interface ITodo diff --git a/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs b/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs index 362388c4c577..0e86a6204cbe 100644 --- a/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs +++ b/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs @@ -137,6 +137,73 @@ public void FindTryParseMethodForEnumsWhenNonGenericEnumParseIsUsed() Assert.Equal(Choice.Three, parseEnum("Three")); } + [Fact] + public async Task FindBindAsyncMethod_FindsCorrectMethodOnClass() + { + var type = typeof(BindAsyncRecord); + var cache = new TryParseMethodCache(); + var methodFound = cache.FindBindAsyncMethod(type); + + Assert.NotNull(methodFound); + + var parsedValue = Expression.Variable(type, "parsedValue"); + var call = methodFound as MethodCallExpression; + Assert.NotNull(call); + var method = call!.Method; + var parameters = method.GetParameters(); + + Assert.Single(parameters); + Assert.Equal(typeof(HttpContext), parameters[0].ParameterType); + + var parseHttpContext = Expression.Lambda>>(Expression.Block(new[] { parsedValue }, + call), cache.HttpContextExpr).Compile(); + + var httpContext = new DefaultHttpContext + { + Request = + { + Headers = + { + ["ETag"] = "42", + }, + }, + }; + + Assert.Equal(new BindAsyncRecord(42), await parseHttpContext(httpContext)); + } + + public static IEnumerable BindAsyncParameterInfoData + { + get + { + return new[] + { + new[] + { + GetFirstParameter((BindAsyncRecord arg) => BindAsyncRecordMethod(arg)), + }, + new[] + { + GetFirstParameter((BindAsyncStruct arg) => BindAsyncStructMethod(arg)), + }, + }; + } + } + + [Theory] + [MemberData(nameof(BindAsyncParameterInfoData))] + public void HasBindAsyncMethod_ReturnsTrueWhenMethodExists(ParameterInfo parameterInfo) + { + Assert.True(new TryParseMethodCache().HasBindAsyncMethod(parameterInfo)); + } + + [Fact] + public void FindBindAsyncMethod_DoesNotFindMethodGivenNullableType() + { + var parameterInfo = GetFirstParameter((BindAsyncStruct? arg) => BindAsyncNullableStructMethod(arg)); + Assert.False(new TryParseMethodCache().HasBindAsyncMethod(parameterInfo)); + } + enum Choice { One, @@ -144,9 +211,24 @@ enum Choice Three } - private record TryParsableInvariantRecord(int value) + private static void TryParseStringRecordMethod(TryParseStringRecord arg) { } + private static void TryParseStringStructMethod(TryParseStringStruct arg) { } + private static void TryParseStringNullableStructMethod(TryParseStringStruct? arg) { } + + private static void BindAsyncRecordMethod(BindAsyncRecord arg) { } + private static void BindAsyncStructMethod(BindAsyncStruct arg) { } + private static void BindAsyncNullableStructMethod(BindAsyncStruct? arg) { } + + + private static ParameterInfo GetFirstParameter(Expression> expr) { - public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParsableInvariantRecord? result) + var mc = (MethodCallExpression)expr.Body; + return mc.Method.GetParameters()[0]; + } + + private record TryParseStringRecord(int Value) + { + public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParseStringRecord? result) { if (!int.TryParse(value, NumberStyles.Integer, formatProvider, out var val)) { @@ -154,10 +236,50 @@ public static bool TryParse(string? value, IFormatProvider formatProvider, out T return false; } - result = new TryParsableInvariantRecord(val); + result = new TryParseStringRecord(val); + return true; + } + } + + private record struct TryParseStringStruct(int Value) + { + public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParsableInvariantRecord? result) + { + if (!int.TryParse(value, NumberStyles.Integer, formatProvider, out var val)) + { + result = default; + return false; + } + + result = new TryParseStringStruct(val); return true; } } + private record BindAsyncRecord(int Value) + { + public static ValueTask BindAsync(HttpContext context) + { + if (!int.TryParse(context.Request.Headers.ETag, out var val)) + { + return ValueTask.FromResult(null); + } + + return ValueTask.FromResult(new BindAsyncRecord(val)); + } + } + + private record struct BindAsyncStruct(int Value) + { + public static ValueTask BindAsync(HttpContext context) + { + if (!int.TryParse(context.Request.Headers.ETag, out var val)) + { + return ValueTask.FromResult(null); + } + + return ValueTask.FromResult(new BindAsyncRecord(val)); + } + } } } diff --git a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs index b27c729a747c..d6a2a515e129 100644 --- a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs +++ b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs @@ -184,6 +184,7 @@ private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string parameter.ParameterType == typeof(HttpResponse) || parameter.ParameterType == typeof(ClaimsPrincipal) || parameter.ParameterType == typeof(CancellationToken) || + TryParseMethodCache.HasBindAsyncMethod(parameter) || _serviceProviderIsService?.IsService(parameter.ParameterType) == true) { return (BindingSource.Services, parameter.Name ?? string.Empty, false); diff --git a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs index 3452974d8a2c..063199b39513 100644 --- a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs +++ b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs @@ -291,6 +291,7 @@ public void DoesNotAddFromServiceParameterAsService() Assert.Empty(GetApiDescription((HttpResponse response) => { }).ParameterDescriptions); Assert.Empty(GetApiDescription((ClaimsPrincipal user) => { }).ParameterDescriptions); Assert.Empty(GetApiDescription((CancellationToken token) => { }).ParameterDescriptions); + Assert.Empty(GetApiDescription((BindAsyncRecord context) => { }).ParameterDescriptions); } [Fact] @@ -665,5 +666,19 @@ public TestEndpointRouteBuilder(IApplicationBuilder applicationBuilder) public IServiceProvider ServiceProvider => ApplicationBuilder.ApplicationServices; } + + private record TryParseStringRecord(int Value) + { + public static bool TryParse(string value, out TryParseStringRecord result) => + throw new NotImplementedException(); + } + + private record BindAsyncRecord(int Value) + { + public static ValueTask BindAsync(HttpContext context) => + throw new NotImplementedException(); + public static bool TryParse(string value, out BindAsyncRecord result) => + throw new NotImplementedException(); + } } } diff --git a/src/Shared/TryParseMethodCache.cs b/src/Shared/TryParseMethodCache.cs index 7b3393f24821..875385fc0117 100644 --- a/src/Shared/TryParseMethodCache.cs +++ b/src/Shared/TryParseMethodCache.cs @@ -19,7 +19,8 @@ internal sealed class TryParseMethodCache private readonly MethodInfo _enumTryParseMethod; // Since this is shared source, the cache won't be shared between RequestDelegateFactory and the ApiDescriptionProvider sadly :( - private readonly ConcurrentDictionary?> _methodCallCache = new(); + private readonly ConcurrentDictionary?> _stringMethodCallCache = new(); + private readonly ConcurrentDictionary _bindAsyncMethodCallCache = new(); internal readonly ParameterExpression TempSourceStringExpr = Expression.Variable(typeof(string), "tempSourceString"); @@ -41,7 +42,10 @@ public bool HasTryParseMethod(ParameterInfo parameter) return FindTryParseMethod(nonNullableParameterType) is not null; } - public Func? FindTryParseMethod(Type type) + public bool HasBindAsyncMethod(ParameterInfo parameter) => + FindBindAsyncMethod(parameter.ParameterType) is not null; + + public Func? FindTryParseStringMethod(Type type) { Func? Finder(Type type) { @@ -117,7 +121,26 @@ public bool HasTryParseMethod(ParameterInfo parameter) return null; } - return _methodCallCache.GetOrAdd(type, Finder); + return _stringMethodCallCache.GetOrAdd(type, Finder); + } + + public Expression? FindBindAsyncMethod(Type type) + { + Expression? Finder(Type type) + { + var methodInfo = type.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext) }); + + // We're looking for a method with the following signature: + // static ValueTask BindAsync(HttpContext context) + if (methodInfo is not null && methodInfo.ReturnType == typeof(ValueTask)) + { + return Expression.Call(methodInfo, HttpContextExpr); + } + + return null; + } + + return _bindAsyncMethodCallCache.GetOrAdd(type, Finder); } private static MethodInfo GetEnumTryParseMethod(bool preferNonGenericEnumParseOverload) From e5450197621fe91a3bd85f2c7531187dac6dba5f Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 19 Aug 2021 17:38:34 -0700 Subject: [PATCH 2/3] Fixed bad merge --- .../src/RequestDelegateFactory.cs | 40 ++++++++++++------- .../EndpointMetadataApiDescriptionProvider.cs | 2 +- src/Shared/TryParseMethodCache.cs | 5 ++- 3 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 3f4cd4cc79ec..94d04695dc99 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -35,18 +36,18 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteWriteStringResponseAsync), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default)); - private static readonly MethodInfo LogParameterBindingFailureMethod = GetMethodInfo>((httpContext, parameterType, parameterName, sourceValue) => - Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue)); + private static readonly MethodInfo LogParameterBindingFailedMethod = GetMethodInfo>((httpContext, parameterType, parameterName, sourceValue) => + Log.ParameterBindingFailed(httpContext, parameterType, parameterName, sourceValue)); private static readonly MethodInfo LogRequiredParameterNotProvidedMethod = GetMethodInfo>((httpContext, parameterType, parameterName) => Log.RequiredParameterNotProvided(httpContext, parameterType, parameterName)); private static readonly ParameterExpression TargetExpr = Expression.Parameter(typeof(object), "target"); - private static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); private static readonly ParameterExpression BodyValueExpr = Expression.Parameter(typeof(object), "bodyValue"); private static readonly ParameterExpression WasParamCheckFailureExpr = Expression.Variable(typeof(bool), "wasParamCheckFailure"); private static readonly ParameterExpression BoundValuesArrayExpr = Expression.Parameter(typeof(object[]), "boundValues"); + private static ParameterExpression HttpContextExpr => TryParseMethodCache.HttpContextExpr; private static readonly MemberExpression RequestServicesExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.RequestServices))!); private static readonly MemberExpression HttpRequestExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Request))!); private static readonly MemberExpression HttpResponseExpr = Expression.Property(HttpContextExpr, typeof(HttpContext).GetProperty(nameof(HttpContext.Response))!); @@ -58,8 +59,9 @@ public static partial class RequestDelegateFactory private static readonly MemberExpression StatusCodeExpr = Expression.Property(HttpResponseExpr, typeof(HttpResponse).GetProperty(nameof(HttpResponse.StatusCode))!); private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo>(() => Task.CompletedTask)); - private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TryParseMethodCache.TempSourceStringExpr, Expression.Constant(null)); - private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TryParseMethodCache.TempSourceStringExpr, Expression.Constant(null)); + private static ParameterExpression TempSourceStringExpr => TryParseMethodCache.TempSourceStringExpr; + private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null)); + private static readonly BinaryExpression TempSourceStringNullExpr = Expression.Equal(TempSourceStringExpr, Expression.Constant(null)); /// /// Creates a implementation for . @@ -171,7 +173,7 @@ public static RequestDelegate Create(MethodInfo methodInfo, Func ParameterBindingFailed(GetLogger(httpContext), parameterTypeName, parameterName, sourceValue); - public static void RequiredParameterNotProvided(HttpContext httpContext, string parameterTypeName, string parameterName) - => RequiredParameterNotProvided(GetLogger(httpContext), parameterTypeName, parameterName); - [LoggerMessage(3, LogLevel.Debug, @"Failed to bind parameter ""{ParameterType} {ParameterName}"" from ""{SourceValue}"".", EventName = "ParamaterBindingFailed")] private static partial void ParameterBindingFailed(ILogger logger, string parameterType, string parameterName, string sourceValue); + public static void RequiredParameterNotProvided(HttpContext httpContext, string parameterTypeName, string parameterName) + => RequiredParameterNotProvided(GetLogger(httpContext), parameterTypeName, parameterName); + [LoggerMessage(4, LogLevel.Debug, @"Required parameter ""{ParameterType} {ParameterName}"" was not provided.", EventName = "RequiredParameterNotProvided")] private static partial void RequiredParameterNotProvided(ILogger logger, string parameterType, string parameterName); + public static void ParameterBindingFromHttpContextFailed(HttpContext httpContext, string parameterTypeName, string parameterName) + => ParameterBindingFromHttpContextFailed(GetLogger(httpContext), parameterTypeName, parameterName); + + [LoggerMessage(5, LogLevel.Debug, + @"Failed to bind parameter ""{ParameterType} {ParameterName}"" from HttpContext.", + EventName = "ParameterBindingFromHttpContextFailed")] + private static partial void ParameterBindingFromHttpContextFailed(ILogger logger, string parameterType, string parameterName); + private static ILogger GetLogger(HttpContext httpContext) { var loggerFactory = httpContext.RequestServices.GetRequiredService(); diff --git a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs index d6a2a515e129..815d1b269259 100644 --- a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs +++ b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs @@ -189,7 +189,7 @@ private ApiDescription CreateApiDescription(RouteEndpoint routeEndpoint, string { return (BindingSource.Services, parameter.Name ?? string.Empty, false); } - else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseMethod(parameter)) + else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseStringMethod(parameter)) { // Path vs query cannot be determined by RequestDelegateFactory at startup currently because of the layering, but can be done here. if (parameter.Name is { } name && pattern.GetParameter(name) is not null) diff --git a/src/Shared/TryParseMethodCache.cs b/src/Shared/TryParseMethodCache.cs index 875385fc0117..91aff2f4a6cc 100644 --- a/src/Shared/TryParseMethodCache.cs +++ b/src/Shared/TryParseMethodCache.cs @@ -23,6 +23,7 @@ internal sealed class TryParseMethodCache private readonly ConcurrentDictionary _bindAsyncMethodCallCache = new(); internal readonly ParameterExpression TempSourceStringExpr = Expression.Variable(typeof(string), "tempSourceString"); + internal readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext"); // If IsDynamicCodeSupported is false, we can't use the static Enum.TryParse since there's no easy way for // this code to generate the specific instantiation for any enums used @@ -36,10 +37,10 @@ public TryParseMethodCache(bool preferNonGenericEnumParseOverload) _enumTryParseMethod = GetEnumTryParseMethod(preferNonGenericEnumParseOverload); } - public bool HasTryParseMethod(ParameterInfo parameter) + public bool HasTryParseStringMethod(ParameterInfo parameter) { var nonNullableParameterType = Nullable.GetUnderlyingType(parameter.ParameterType) ?? parameter.ParameterType; - return FindTryParseMethod(nonNullableParameterType) is not null; + return FindTryParseStringMethod(nonNullableParameterType) is not null; } public bool HasBindAsyncMethod(ParameterInfo parameter) => From 56d72161f4c93c71f9d5e7edba507195da006935 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 19 Aug 2021 18:30:49 -0700 Subject: [PATCH 3/3] More fixes --- .../test/TryParseMethodCacheTests.cs | 56 ++++++++++++++----- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs b/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs index 0e86a6204cbe..85b8792f7661 100644 --- a/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs +++ b/src/Http/Http.Extensions/test/TryParseMethodCacheTests.cs @@ -6,7 +6,7 @@ using System; using System.Globalization; using System.Linq.Expressions; - +using System.Reflection; using Xunit; namespace Microsoft.AspNetCore.Http.Extensions.Tests @@ -25,9 +25,9 @@ public class TryParseMethodCacheTests [InlineData(typeof(ushort))] [InlineData(typeof(uint))] [InlineData(typeof(ulong))] - public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCulture(Type @type) + public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvariantCulture(Type type) { - var methodFound = new TryParseMethodCache().FindTryParseMethod(@type); + var methodFound = new TryParseMethodCache().FindTryParseStringMethod(@type); Assert.NotNull(methodFound); @@ -48,9 +48,9 @@ public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCult [InlineData(typeof(DateTimeOffset))] [InlineData(typeof(TimeOnly))] [InlineData(typeof(TimeSpan))] - public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureDateType(Type @type) + public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureDateType(Type type) { - var methodFound = new TryParseMethodCache().FindTryParseMethod(@type); + var methodFound = new TryParseMethodCache().FindTryParseStringMethod(@type); Assert.NotNull(methodFound); @@ -76,10 +76,11 @@ public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCult } [Theory] - [InlineData(typeof(TryParsableInvariantRecord))] - public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureCustomType(Type @type) + [InlineData(typeof(TryParseStringRecord))] + [InlineData(typeof(TryParseStringStruct))] + public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureCustomType(Type type) { - var methodFound = new TryParseMethodCache().FindTryParseMethod(@type); + var methodFound = new TryParseMethodCache().FindTryParseStringMethod(@type); Assert.NotNull(methodFound); @@ -94,11 +95,40 @@ public void FindTryParseMethod_ReturnsTheExpectedTryParseMethodWithInvariantCult Assert.True(((call.Arguments[1] as ConstantExpression)!.Value as CultureInfo)!.Equals(CultureInfo.InvariantCulture)); } + public static IEnumerable TryParseStringParameterInfoData + { + get + { + return new[] + { + new[] + { + GetFirstParameter((TryParseStringRecord arg) => TryParseStringRecordMethod(arg)), + }, + new[] + { + GetFirstParameter((TryParseStringStruct arg) => TryParseStringStructMethod(arg)), + }, + new[] + { + GetFirstParameter((TryParseStringStruct? arg) => TryParseStringNullableStructMethod(arg)), + }, + }; + } + } + + [Theory] + [MemberData(nameof(TryParseStringParameterInfoData))] + public void HasTryParseStringMethod_ReturnsTrueWhenMethodExists(ParameterInfo parameterInfo) + { + Assert.True(new TryParseMethodCache().HasTryParseStringMethod(parameterInfo)); + } + [Fact] - public void FindTryParseMethodForEnums() + public void FindTryParseStringMethod_WorksForEnums() { var type = typeof(Choice); - var methodFound = new TryParseMethodCache().FindTryParseMethod(type); + var methodFound = new TryParseMethodCache().FindTryParseStringMethod(type); Assert.NotNull(methodFound); @@ -115,11 +145,11 @@ public void FindTryParseMethodForEnums() } [Fact] - public void FindTryParseMethodForEnumsWhenNonGenericEnumParseIsUsed() + public void FindTryParseStringMethod_WorksForEnumsWhenNonGenericEnumParseIsUsed() { var type = typeof(Choice); var cache = new TryParseMethodCache(preferNonGenericEnumParseOverload: true); - var methodFound = cache.FindTryParseMethod(type); + var methodFound = cache.FindTryParseStringMethod(type); Assert.NotNull(methodFound); @@ -243,7 +273,7 @@ public static bool TryParse(string? value, IFormatProvider formatProvider, out T private record struct TryParseStringStruct(int Value) { - public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParsableInvariantRecord? result) + public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParseStringStruct result) { if (!int.TryParse(value, NumberStyles.Integer, formatProvider, out var val)) {