From 2bac3016454cc8689fdd081d434bf67c94cbff15 Mon Sep 17 00:00:00 2001 From: John Luo Date: Wed, 30 Nov 2016 13:18:28 -0800 Subject: [PATCH] Always add IResponseCachingFeature before calling the next middleware --- .../ResponseCachingMiddleware.cs | 32 ++++++++--- .../ResponseCachingMiddlewareTests.cs | 53 +++++++++++++++++-- .../TestUtils.cs | 7 ++- 3 files changed, 79 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs index dc01df4..2ae5ef0 100644 --- a/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs +++ b/src/Microsoft.AspNetCore.ResponseCaching/ResponseCachingMiddleware.cs @@ -106,7 +106,17 @@ public async Task Invoke(HttpContext httpContext) } else { - await _next(httpContext); + // Add IResponseCachingFeature which may be required when the response is generated + AddResponseCachingFeature(httpContext); + + try + { + await _next(httpContext); + } + finally + { + RemoveResponseCachingFeature(httpContext); + } } } @@ -318,6 +328,15 @@ internal Task OnResponseStartingAsync(ResponseCachingContext context) } } + internal static void AddResponseCachingFeature(HttpContext context) + { + if (context.Features.Get() != null) + { + throw new InvalidOperationException($"Another instance of {nameof(ResponseCachingFeature)} already exists. Only one instance of {nameof(ResponseCachingMiddleware)} can be configured for an application."); + } + context.Features.Set(new ResponseCachingFeature()); + } + internal void ShimResponseStream(ResponseCachingContext context) { // Shim response stream @@ -333,13 +352,12 @@ internal void ShimResponseStream(ResponseCachingContext context) } // Add IResponseCachingFeature - if (context.HttpContext.Features.Get() != null) - { - throw new InvalidOperationException($"Another instance of {nameof(ResponseCachingFeature)} already exists. Only one instance of {nameof(ResponseCachingMiddleware)} can be configured for an application."); - } - context.HttpContext.Features.Set(new ResponseCachingFeature()); + AddResponseCachingFeature(context.HttpContext); } + internal static void RemoveResponseCachingFeature(HttpContext context) => + context.Features.Set(null); + internal static void UnshimResponseStream(ResponseCachingContext context) { // Unshim response stream @@ -349,7 +367,7 @@ internal static void UnshimResponseStream(ResponseCachingContext context) context.HttpContext.Features.Set(context.OriginalSendFileFeature); // Remove IResponseCachingFeature - context.HttpContext.Features.Set(null); + RemoveResponseCachingFeature(context.HttpContext); } internal static bool ContentIsNotModified(ResponseCachingContext context) diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs index fb74241..f81bd25 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/ResponseCachingMiddlewareTests.cs @@ -5,8 +5,10 @@ using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Headers; using Microsoft.AspNetCore.ResponseCaching.Internal; +using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging.Testing; using Microsoft.Extensions.Primitives; using Microsoft.Net.Http.Headers; @@ -715,16 +717,57 @@ public async Task FinalizeCacheBody_DoNotCache_IfBufferingDisabled() } [Fact] - public void ShimResponseStream_SecondInvocation_Throws() + public void AddResponseCachingFeature_SecondInvocation_Throws() { - var middleware = TestUtils.CreateTestMiddleware(); - var context = TestUtils.CreateTestContext(); + var httpContext = new DefaultHttpContext(); // Should not throw - middleware.ShimResponseStream(context); + ResponseCachingMiddleware.AddResponseCachingFeature(httpContext); // Should throw - Assert.ThrowsAny(() => middleware.ShimResponseStream(context)); + Assert.ThrowsAny(() => ResponseCachingMiddleware.AddResponseCachingFeature(httpContext)); + } + + private class FakeResponseFeature : HttpResponseFeature + { + public override void OnStarting(Func callback, object state) { } + } + + [Fact] + public async Task Invoke_CacheableRequest_AddsResponseCachingFeature() + { + var responseCachingFeatureAdded = false; + var middleware = TestUtils.CreateTestMiddleware(next: httpContext => + { + responseCachingFeatureAdded = httpContext.Features.Get() != null; + return TaskCache.CompletedTask; + }, + policyProvider: new ResponseCachingPolicyProvider()); + + var context = new DefaultHttpContext(); + context.Request.Method = HttpMethods.Get; + context.Features.Set(new FakeResponseFeature()); + await middleware.Invoke(context); + + Assert.True(responseCachingFeatureAdded); + } + + [Fact] + public async Task Invoke_NonCacheableRequest_AddsResponseCachingFeature() + { + var responseCachingFeatureAdded = false; + var middleware = TestUtils.CreateTestMiddleware(next: httpContext => + { + responseCachingFeatureAdded = httpContext.Features.Get() != null; + return TaskCache.CompletedTask; + }, + policyProvider: new ResponseCachingPolicyProvider()); + + var context = new DefaultHttpContext(); + context.Request.Method = HttpMethods.Post; + await middleware.Invoke(context); + + Assert.True(responseCachingFeatureAdded); } [Fact] diff --git a/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs b/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs index b50be5b..a572466 100644 --- a/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs +++ b/test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs @@ -103,12 +103,17 @@ internal static IEnumerable CreateBuildersWithResponseCaching( } internal static ResponseCachingMiddleware CreateTestMiddleware( + RequestDelegate next = null, IResponseCache cache = null, ResponseCachingOptions options = null, TestSink testSink = null, IResponseCachingKeyProvider keyProvider = null, IResponseCachingPolicyProvider policyProvider = null) { + if (next == null) + { + next = httpContext => TaskCache.CompletedTask; + } if (cache == null) { cache = new TestResponseCache(); @@ -127,7 +132,7 @@ internal static ResponseCachingMiddleware CreateTestMiddleware( } return new ResponseCachingMiddleware( - httpContext => TaskCache.CompletedTask, + next, Options.Create(options), testSink == null ? (ILoggerFactory)NullLoggerFactory.Instance : new TestLoggerFactory(testSink, true), policyProvider,