diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index e45c415db021..7a742a79dbbe 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Metadata; @@ -45,6 +46,7 @@ public static class RequestDelegateFactory private static readonly MemberExpression HttpRequestExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.Request)); private static readonly MemberExpression HttpResponseExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.Response)); private static readonly MemberExpression RequestAbortedExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.RequestAborted)); + private static readonly MemberExpression UserExpr = Expression.Property(HttpContextExpr, nameof(HttpContext.User)); private static readonly MemberExpression RouteValuesExpr = Expression.Property(HttpRequestExpr, nameof(HttpRequest.RouteValues)); private static readonly MemberExpression QueryExpr = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Query)); private static readonly MemberExpression HeadersExpr = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Headers)); @@ -221,6 +223,18 @@ private static Expression CreateArgument(ParameterInfo parameter, FactoryContext { return HttpContextExpr; } + else if (parameter.ParameterType == typeof(HttpRequest)) + { + return HttpRequestExpr; + } + else if (parameter.ParameterType == typeof(HttpResponse)) + { + return HttpResponseExpr; + } + else if (parameter.ParameterType == typeof(ClaimsPrincipal)) + { + return UserExpr; + } else if (parameter.ParameterType == typeof(CancellationToken)) { return RequestAbortedExpr; diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 73c73cb6378d..2a890dfd11db 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -14,6 +14,7 @@ using System.Numerics; using System.Reflection; using System.Reflection.Metadata; +using System.Security.Claims; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -96,7 +97,7 @@ public async Task RequestDelegateInvokesAction(Delegate @delegate) { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -116,7 +117,7 @@ public async Task StaticMethodInfoOverloadWorksWithBasicReflection() BindingFlags.NonPublic | BindingFlags.Static, new[] { typeof(HttpContext) }); - var requestDelegate = RequestDelegateFactory.Create(methodInfo!, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(methodInfo!, new EmptyServiceProvider()); var httpContext = new DefaultHttpContext(); @@ -162,7 +163,7 @@ object GetTarget() return new TestNonStaticActionClass(2); } - var requestDelegate = RequestDelegateFactory.Create(methodInfo!, new EmptyServiceProvdier(), _ => GetTarget()); + var requestDelegate = RequestDelegateFactory.Create(methodInfo!, new EmptyServiceProvider(), _ => GetTarget()); var httpContext = new DefaultHttpContext(); @@ -185,7 +186,7 @@ public void BuildRequestDelegateThrowsArgumentNullExceptions() BindingFlags.NonPublic | BindingFlags.Static, new[] { typeof(HttpContext) }); - var serviceProvider = new EmptyServiceProvdier(); + var serviceProvider = new EmptyServiceProvider(); var exNullAction = Assert.Throws(() => RequestDelegateFactory.Create(action: null!, serviceProvider)); var exNullMethodInfo1 = Assert.Throws(() => RequestDelegateFactory.Create(methodInfo: null!, serviceProvider)); @@ -204,7 +205,7 @@ public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName const string paramName = "value"; const int originalRouteParam = 42; - void TestAction(HttpContext httpContext, [FromRoute] int value) + static void TestAction(HttpContext httpContext, [FromRoute] int value) { httpContext.Items.Add("input", value); } @@ -212,7 +213,7 @@ void TestAction(HttpContext httpContext, [FromRoute] int value) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -239,7 +240,7 @@ public async Task RequestDelegatePopulatesFromRouteOptionalParameter() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestOptional, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestOptional, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -251,7 +252,7 @@ public async Task RequestDelegatePopulatesFromNullableOptionalParameter() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestOptional, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestOptional, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -263,7 +264,7 @@ public async Task RequestDelegatePopulatesFromOptionalStringParameter() { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestOptionalString, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestOptionalString, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -280,7 +281,7 @@ public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParam httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action)TestOptional, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestOptional, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -303,7 +304,7 @@ void TestAction([FromRoute(Name = specifiedName)] int foo) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[specifiedName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -326,7 +327,7 @@ void TestAction([FromRoute] int foo) var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[unmatchedName] = unmatchedRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -405,7 +406,7 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues["tryParsable"] = routeValue; - var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -422,7 +423,7 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromQ ["tryParsable"] = routeValue }); - var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -445,7 +446,7 @@ public async Task RequestDelegatePopulatesUnattributedTryParsableParametersFromR { httpContext.Items["tryParsable"] = tryParsable; }), - new EmptyServiceProvdier()); + new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -473,7 +474,7 @@ void InvalidFromHeader([FromHeader] object notTryParsable) { } [MemberData(nameof(DelegatesWithAttributesOnNotTryParsableParameters))] public void CreateThrowsInvalidOperationExceptionWhenAttributeRequiresTryParseMethodThatDoesNotExist(Delegate action) { - var ex = Assert.Throws(() => RequestDelegateFactory.Create(action, new EmptyServiceProvdier())); + var ex = Assert.Throws(() => RequestDelegateFactory.Create(action, new EmptyServiceProvider())); Assert.Equal("No public static bool Object.TryParse(string, out Object) method found for notTryParsable.", ex.Message); } @@ -482,7 +483,7 @@ public void CreateThrowsInvalidOperationExceptionGivenUnnamedArgument() { var unnamedParameter = Expression.Parameter(typeof(int)); var lambda = Expression.Lambda(Expression.Block(), unnamedParameter); - var ex = Assert.Throws(() => RequestDelegateFactory.Create((Action)lambda.Compile(), new EmptyServiceProvdier())); + var ex = Assert.Throws(() => RequestDelegateFactory.Create((Action)lambda.Compile(), new EmptyServiceProvider())); Assert.Equal("A parameter does not have a name! Was it generated? All parameters must be named.", ex.Message); } @@ -505,7 +506,7 @@ void TestAction([FromRoute] int tryParsable, [FromRoute] int tryParsable2) httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -547,7 +548,7 @@ void TestAction([FromQuery] int value) var httpContext = new DefaultHttpContext(); httpContext.Request.Query = query; - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -570,7 +571,7 @@ void TestAction([FromHeader(Name = customHeaderName)] int value) var httpContext = new DefaultHttpContext(); httpContext.Request.Headers[customHeaderName] = originalHeaderParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -634,7 +635,7 @@ public async Task RequestDelegatePopulatesFromBodyParameter(Delegate action) }); httpContext.RequestServices = mock.Object; - var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -651,7 +652,7 @@ public async Task RequestDelegateRejectsEmptyBodyGivenFromBodyParameter(Delegate httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvider()); await Assert.ThrowsAsync(() => requestDelegate(httpContext)); } @@ -670,7 +671,7 @@ void TestAction([FromBody(AllowEmpty = true)] Todo todo) httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -694,7 +695,7 @@ void TestAction([FromBody(AllowEmpty = true)] BodyStruct bodyStruct) httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -721,7 +722,7 @@ void TestAction([FromBody] Todo todo) httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -754,7 +755,7 @@ void TestAction([FromBody] Todo todo) httpContext.Features.Set(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -775,9 +776,9 @@ void TestAttributedInvalidAction([FromBody] int value1, [FromBody] int value2) { void TestInferredInvalidAction(Todo value1, Todo value2) { } void TestBothInvalidAction(Todo value1, [FromBody] int value2) { } - Assert.Throws(() => RequestDelegateFactory.Create((Action)TestAttributedInvalidAction, new EmptyServiceProvdier())); - Assert.Throws(() => RequestDelegateFactory.Create((Action)TestInferredInvalidAction, new EmptyServiceProvdier())); - Assert.Throws(() => RequestDelegateFactory.Create((Action)TestBothInvalidAction, new EmptyServiceProvdier())); + Assert.Throws(() => RequestDelegateFactory.Create((Action)TestAttributedInvalidAction, new EmptyServiceProvider())); + Assert.Throws(() => RequestDelegateFactory.Create((Action)TestInferredInvalidAction, new EmptyServiceProvider())); + Assert.Throws(() => RequestDelegateFactory.Create((Action)TestBothInvalidAction, new EmptyServiceProvider())); } public static object[][] FromServiceActions @@ -849,9 +850,9 @@ public async Task RequestDelegatePopulatesParametersFromServiceWithAndWithoutAtt public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action) { var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = new EmptyServiceProvdier(); + httpContext.RequestServices = new EmptyServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvider()); await Assert.ThrowsAsync(() => requestDelegate(httpContext)); } @@ -868,7 +869,7 @@ void TestAction(HttpContext httpContext) var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -876,7 +877,7 @@ void TestAction(HttpContext httpContext) } [Fact] - public async Task RequestDelegatePassHttpContextRequestAbortedAsCancelationToken() + public async Task RequestDelegatePassHttpContextRequestAbortedAsCancellationToken() { CancellationToken? cancellationTokenArgument = null; @@ -891,13 +892,73 @@ void TestAction(CancellationToken cancellationToken) RequestAborted = cts.Token }; - var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); await requestDelegate(httpContext); Assert.Equal(httpContext.RequestAborted, cancellationTokenArgument); } + [Fact] + public async Task RequestDelegatePassHttpContextUserAsClaimsPrincipal() + { + ClaimsPrincipal? userArgument = null; + + void TestAction(ClaimsPrincipal user) + { + userArgument = user; + } + + var httpContext = new DefaultHttpContext + { + User = new ClaimsPrincipal() + }; + + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); + + await requestDelegate(httpContext); + + Assert.Equal(httpContext.User, userArgument); + } + + [Fact] + public async Task RequestDelegatePassHttpContextRequestAsHttpRequest() + { + HttpRequest? httpRequestArgument = null; + + void TestAction(HttpRequest httpRequest) + { + httpRequestArgument = httpRequest; + } + + var httpContext = new DefaultHttpContext(); + + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); + + await requestDelegate(httpContext); + + Assert.Equal(httpContext.Request, httpRequestArgument); + } + + [Fact] + public async Task RequestDelegatePassesHttpContextRresponseAsHttpResponse() + { + HttpResponse? httpResponseArgument = null; + + void TestAction(HttpResponse httpResponse) + { + httpResponseArgument = httpResponse; + } + + var httpContext = new DefaultHttpContext(); + + var requestDelegate = RequestDelegateFactory.Create((Action)TestAction, new EmptyServiceProvider()); + + await requestDelegate(httpContext); + + Assert.Equal(httpContext.Response, httpResponseArgument); + } + public static IEnumerable ComplexResult { get @@ -935,7 +996,7 @@ public async Task RequestDelegateWritesComplexReturnValueAsJsonResponseBody(Dele var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -984,7 +1045,7 @@ public async Task RequestDelegateUsesCustomIResult(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -1027,7 +1088,7 @@ public async Task RequestDelegateWritesStringReturnValueAsJsonResponseBody(Deleg var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -1068,7 +1129,7 @@ public async Task RequestDelegateWritesIntReturnValue(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -1109,7 +1170,7 @@ public async Task RequestDelegateWritesBoolReturnValue(Delegate @delegate) var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvider()); await requestDelegate(httpContext); @@ -1227,7 +1288,7 @@ class TodoJsonConverter : JsonConverter break; } - string property = reader.GetString()!; + var property = reader.GetString()!; reader.Read(); switch (property.ToLowerInvariant()) @@ -1352,13 +1413,13 @@ public override void Write(byte[] buffer, int offset, int count) } } - private class EmptyServiceProvdier : IServiceScope, IServiceProvider, IServiceScopeFactory + private class EmptyServiceProvider : IServiceScope, IServiceProvider, IServiceScopeFactory { public IServiceProvider ServiceProvider => this; public IServiceScope CreateScope() { - return new EmptyServiceProvdier(); + return new EmptyServiceProvider(); } public void Dispose() @@ -1378,7 +1439,7 @@ public void Dispose() private class TestHttpRequestLifetimeFeature : IHttpRequestLifetimeFeature { - private readonly CancellationTokenSource _requestAbortedCts = new CancellationTokenSource(); + private readonly CancellationTokenSource _requestAbortedCts = new(); public CancellationToken RequestAborted { get => _requestAbortedCts.Token; set => throw new NotImplementedException(); }