diff --git a/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs b/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs index 57df73b27d27..852478e751bc 100644 --- a/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs +++ b/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs @@ -77,6 +77,7 @@ public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvaria [Theory] [InlineData(typeof(TryParseStringRecord))] [InlineData(typeof(TryParseStringStruct))] + [InlineData(typeof(TryParseInheritClassWithFormatProvider))] public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureCustomType(Type type) { var methodFound = new ParameterBindingMethodCache().FindTryParseMethod(@type); @@ -94,6 +95,24 @@ public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvaria Assert.True(((call.Arguments[1] as ConstantExpression)!.Value as CultureInfo)!.Equals(CultureInfo.InvariantCulture)); } + [Theory] + [InlineData(typeof(TryParseNoFormatProviderRecord))] + [InlineData(typeof(TryParseNoFormatProviderStruct))] + [InlineData(typeof(TryParseInheritClass))] + public void FindTryParseMethod_WithNoFormatProvider(Type type) + { + var methodFound = new ParameterBindingMethodCache().FindTryParseMethod(@type); + Assert.NotNull(methodFound); + + var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression; + Assert.NotNull(call); + var parameters = call!.Method.GetParameters(); + + Assert.Equal(2, parameters.Length); + Assert.Equal(typeof(string), parameters[0].ParameterType); + Assert.True(parameters[1].IsOut); + } + public static IEnumerable TryParseStringParameterInfoData { get @@ -249,6 +268,14 @@ public static IEnumerable BindAsyncParameterInfoData new[] { GetFirstParameter((BindAsyncSingleArgStruct arg) => BindAsyncSingleArgStructMethod(arg)), + }, + new[] + { + GetFirstParameter((InheritBindAsync arg) => InheritBindAsyncMethod(arg)) + }, + new[] + { + GetFirstParameter((InheritBindAsyncWithParameterInfo arg) => InheritBindAsyncWithParameterInfoMethod(arg)) } }; } @@ -285,6 +312,7 @@ public void FindBindAsyncMethod_FindsNonNullableReturningBindAsyncMethodGivenNul [InlineData(typeof(InvalidTooFewArgsTryParseClass))] [InlineData(typeof(InvalidNonStaticTryParseStruct))] [InlineData(typeof(InvalidNonStaticTryParseClass))] + [InlineData(typeof(TryParseWrongTypeInheritClass))] public void FindTryParseMethod_ThrowsIfInvalidTryParseOnType(Type type) { var ex = Assert.Throws( @@ -308,6 +336,8 @@ public void FindTryParseMethod_IgnoresInvalidTryParseIfGoodOneFound(Type type) [InlineData(typeof(InvalidWrongReturnBindAsyncClass))] [InlineData(typeof(InvalidWrongParamBindAsyncStruct))] [InlineData(typeof(InvalidWrongParamBindAsyncClass))] + [InlineData(typeof(BindAsyncWrongTypeInherit))] + [InlineData(typeof(BindAsyncWithParameterInfoWrongTypeInherit))] public void FindBindAsyncMethod_ThrowsIfInvalidBindAsyncOnType(Type type) { var cache = new ParameterBindingMethodCache(); @@ -350,6 +380,8 @@ private static void NullableReturningBindAsyncStructMethod(NullableReturningBind private static void BindAsyncSingleArgRecordMethod(BindAsyncSingleArgRecord arg) { } private static void BindAsyncSingleArgStructMethod(BindAsyncSingleArgStruct arg) { } + private static void InheritBindAsyncMethod(InheritBindAsync arg) { } + private static void InheritBindAsyncWithParameterInfoMethod(InheritBindAsyncWithParameterInfo args) { } private static ParameterInfo GetFirstParameter(Expression> expr) { @@ -538,6 +570,67 @@ public bool TryParse(string? value, IFormatProvider formatProvider, out InvalidN } } + private record TryParseNoFormatProviderRecord(int Value) + { + public static bool TryParse(string? value, out TryParseNoFormatProviderRecord? result) + { + if (!int.TryParse(value, out var val)) + { + result = null; + return false; + } + + result = new TryParseNoFormatProviderRecord(val); + return true; + } + } + + private record struct TryParseNoFormatProviderStruct(int Value) + { + public static bool TryParse(string? value, out TryParseNoFormatProviderStruct result) + { + if (!int.TryParse(value, out var val)) + { + result = default; + return false; + } + + result = new TryParseNoFormatProviderStruct(val); + return true; + } + } + + private class BaseTryParseClass + { + public static bool TryParse(string? value, out T? result) + { + result = default(T); + return false; + } + } + + private class TryParseInheritClass : BaseTryParseClass + { + } + + // using wrong T on purpose + private class TryParseWrongTypeInheritClass : BaseTryParseClass + { + } + + private class BaseTryParseClassWithFormatProvider + { + public static bool TryParse(string? value, IFormatProvider formatProvider, out T? result) + { + result = default(T); + return false; + } + } + + private class TryParseInheritClassWithFormatProvider : BaseTryParseClassWithFormatProvider + { + } + private record BindAsyncRecord(int Value) { public static ValueTask BindAsync(HttpContext context, ParameterInfo parameter) @@ -644,6 +737,40 @@ public static ValueTask BindAsync(ParameterInfo pa throw new NotImplementedException(); } + private class BaseBindAsync + { + public static ValueTask BindAsync(HttpContext context) + { + return new(default(T)); + } + } + + private class InheritBindAsync : BaseBindAsync + { + } + + // Using wrong T on purpose + private class BindAsyncWrongTypeInherit : BaseBindAsync + { + } + + private class BaseBindAsyncWithParameterInfo + { + public static ValueTask BindAsync(HttpContext context, ParameterInfo parameter) + { + return new(default(T)); + } + } + + private class InheritBindAsyncWithParameterInfo : BaseBindAsyncWithParameterInfo + { + } + + // Using wrong T on purpose + private class BindAsyncWithParameterInfoWrongTypeInherit : BaseBindAsyncWithParameterInfo + { + } + private class MockParameterInfo : ParameterInfo { public MockParameterInfo(Type type, string name) diff --git a/src/Shared/ParameterBindingMethodCache.cs b/src/Shared/ParameterBindingMethodCache.cs index 1f7f304188f5..be7f48c77f2f 100644 --- a/src/Shared/ParameterBindingMethodCache.cs +++ b/src/Shared/ParameterBindingMethodCache.cs @@ -106,7 +106,7 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) => expression); } - methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(string), typeof(IFormatProvider), type.MakeByRefType() }); + methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(string), typeof(IFormatProvider), type.MakeByRefType() }); if (methodInfo is not null && methodInfo.ReturnType == typeof(bool)) { @@ -117,14 +117,14 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) => expression); } - methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(string), type.MakeByRefType() }); + methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(string), type.MakeByRefType() }); if (methodInfo is not null && methodInfo.ReturnType == typeof(bool)) { return (expression) => Expression.Call(methodInfo, TempSourceStringExpr, expression); } - if (type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance) is MethodInfo invalidMethod) + if (type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance | BindingFlags.FlattenHierarchy) is MethodInfo invalidMethod) { var stringBuilder = new StringBuilder(); stringBuilder.AppendLine(CultureInfo.InvariantCulture, $"TryParse method found on {TypeNameHelper.GetTypeDisplayName(type, fullName: false)} with incorrect format. Must be a static method with format"); @@ -149,11 +149,11 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) => { var hasParameterInfo = true; // There should only be one BindAsync method with these parameters since C# does not allow overloading on return type. - var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext), typeof(ParameterInfo) }); + var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(HttpContext), typeof(ParameterInfo) }); if (methodInfo is null) { hasParameterInfo = false; - methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext) }); + methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(HttpContext) }); } // We're looking for a method with the following signatures: @@ -207,7 +207,7 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) => } } - if (nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance) is MethodInfo invalidBindMethod) + if (nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance | BindingFlags.FlattenHierarchy) is MethodInfo invalidBindMethod) { var stringBuilder = new StringBuilder(); stringBuilder.AppendLine(CultureInfo.InvariantCulture, $"BindAsync method found on {TypeNameHelper.GetTypeDisplayName(nonNullableParameterType, fullName: false)} with incorrect format. Must be a static method with format");