Skip to content
This repository was archived by the owner on Nov 22, 2018. It is now read-only.

Always add IResponseCachingFeature before calling the next middleware #82

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,17 @@ public async Task Invoke(HttpContext httpContext)
}
else
{
await _next(httpContext);
// Add IResponseCachingFeature which may be required when the response is generated
AddResponseCachingFeature(httpContext);

try
{
await _next(httpContext);
}
finally
{
RemoveResponseCachingFeature(httpContext);
}
}
}

Expand Down Expand Up @@ -318,6 +328,15 @@ internal Task OnResponseStartingAsync(ResponseCachingContext context)
}
}

internal static void AddResponseCachingFeature(HttpContext context)
{
if (context.Features.Get<IResponseCachingFeature>() != null)
{
throw new InvalidOperationException($"Another instance of {nameof(ResponseCachingFeature)} already exists. Only one instance of {nameof(ResponseCachingMiddleware)} can be configured for an application.");
}
context.Features.Set<IResponseCachingFeature>(new ResponseCachingFeature());
}

internal void ShimResponseStream(ResponseCachingContext context)
{
// Shim response stream
Expand All @@ -333,13 +352,12 @@ internal void ShimResponseStream(ResponseCachingContext context)
}

// Add IResponseCachingFeature
if (context.HttpContext.Features.Get<IResponseCachingFeature>() != null)
{
throw new InvalidOperationException($"Another instance of {nameof(ResponseCachingFeature)} already exists. Only one instance of {nameof(ResponseCachingMiddleware)} can be configured for an application.");
}
context.HttpContext.Features.Set<IResponseCachingFeature>(new ResponseCachingFeature());
AddResponseCachingFeature(context.HttpContext);
}

internal static void RemoveResponseCachingFeature(HttpContext context) =>
context.Features.Set<IResponseCachingFeature>(null);

internal static void UnshimResponseStream(ResponseCachingContext context)
{
// Unshim response stream
Expand All @@ -349,7 +367,7 @@ internal static void UnshimResponseStream(ResponseCachingContext context)
context.HttpContext.Features.Set(context.OriginalSendFileFeature);

// Remove IResponseCachingFeature
context.HttpContext.Features.Set<IResponseCachingFeature>(null);
RemoveResponseCachingFeature(context.HttpContext);
}

internal static bool ContentIsNotModified(ResponseCachingContext context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Headers;
using Microsoft.AspNetCore.ResponseCaching.Internal;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging.Testing;
using Microsoft.Extensions.Primitives;
using Microsoft.Net.Http.Headers;
Expand Down Expand Up @@ -715,16 +717,57 @@ public async Task FinalizeCacheBody_DoNotCache_IfBufferingDisabled()
}

[Fact]
public void ShimResponseStream_SecondInvocation_Throws()
public void AddResponseCachingFeature_SecondInvocation_Throws()
{
var middleware = TestUtils.CreateTestMiddleware();
var context = TestUtils.CreateTestContext();
var httpContext = new DefaultHttpContext();

// Should not throw
middleware.ShimResponseStream(context);
ResponseCachingMiddleware.AddResponseCachingFeature(httpContext);

// Should throw
Assert.ThrowsAny<InvalidOperationException>(() => middleware.ShimResponseStream(context));
Assert.ThrowsAny<InvalidOperationException>(() => ResponseCachingMiddleware.AddResponseCachingFeature(httpContext));
}

private class FakeResponseFeature : HttpResponseFeature
{
public override void OnStarting(Func<object, Task> callback, object state) { }
}

[Fact]
public async Task Invoke_CacheableRequest_AddsResponseCachingFeature()
{
var responseCachingFeatureAdded = false;
var middleware = TestUtils.CreateTestMiddleware(next: httpContext =>
{
responseCachingFeatureAdded = httpContext.Features.Get<IResponseCachingFeature>() != null;
return TaskCache.CompletedTask;
},
policyProvider: new ResponseCachingPolicyProvider());

var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Get;
context.Features.Set<IHttpResponseFeature>(new FakeResponseFeature());
await middleware.Invoke(context);

Assert.True(responseCachingFeatureAdded);
}

[Fact]
public async Task Invoke_NonCacheableRequest_AddsResponseCachingFeature()
{
var responseCachingFeatureAdded = false;
var middleware = TestUtils.CreateTestMiddleware(next: httpContext =>
{
responseCachingFeatureAdded = httpContext.Features.Get<IResponseCachingFeature>() != null;
return TaskCache.CompletedTask;
},
policyProvider: new ResponseCachingPolicyProvider());

var context = new DefaultHttpContext();
context.Request.Method = HttpMethods.Post;
await middleware.Invoke(context);

Assert.True(responseCachingFeatureAdded);
}

[Fact]
Expand Down
7 changes: 6 additions & 1 deletion test/Microsoft.AspNetCore.ResponseCaching.Tests/TestUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,17 @@ internal static IEnumerable<IWebHostBuilder> CreateBuildersWithResponseCaching(
}

internal static ResponseCachingMiddleware CreateTestMiddleware(
RequestDelegate next = null,
IResponseCache cache = null,
ResponseCachingOptions options = null,
TestSink testSink = null,
IResponseCachingKeyProvider keyProvider = null,
IResponseCachingPolicyProvider policyProvider = null)
{
if (next == null)
{
next = httpContext => TaskCache.CompletedTask;
}
if (cache == null)
{
cache = new TestResponseCache();
Expand All @@ -127,7 +132,7 @@ internal static ResponseCachingMiddleware CreateTestMiddleware(
}

return new ResponseCachingMiddleware(
httpContext => TaskCache.CompletedTask,
next,
Options.Create(options),
testSink == null ? (ILoggerFactory)NullLoggerFactory.Instance : new TestLoggerFactory(testSink, true),
policyProvider,
Expand Down