diff --git a/src/Hosting/TestHost/src/HttpContextBuilder.cs b/src/Hosting/TestHost/src/HttpContextBuilder.cs index 710ffa3d8612..287e4aacf7b2 100644 --- a/src/Hosting/TestHost/src/HttpContextBuilder.cs +++ b/src/Hosting/TestHost/src/HttpContextBuilder.cs @@ -20,10 +20,11 @@ internal class HttpContextBuilder : IHttpBodyControlFeature private readonly TaskCompletionSource _responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly ResponseStream _responseStream; - private readonly ResponseFeature _responseFeature = new ResponseFeature(); + private readonly ResponseFeature _responseFeature; private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature(); private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature(); private bool _pipelineFinished; + private bool _returningResponse; private Context _testContext; private Action _responseReadCompleteCallback; @@ -33,6 +34,7 @@ internal HttpContextBuilder(IHttpApplication application, bool allowSyn AllowSynchronousIO = allowSynchronousIO; _preserveExecutionContext = preserveExecutionContext; _httpContext = new DefaultHttpContext(); + _responseFeature = new ResponseFeature(Abort); var request = _httpContext.Request; request.Protocol = "HTTP/1.1"; @@ -40,6 +42,7 @@ internal HttpContextBuilder(IHttpApplication application, bool allowSyn _httpContext.Features.Set(this); _httpContext.Features.Set(_responseFeature); + _httpContext.Features.Set(_responseFeature); _httpContext.Features.Set(_requestLifetimeFeature); _httpContext.Features.Set(_responseTrailersFeature); @@ -132,12 +135,13 @@ internal async Task CompleteResponseAsync() internal async Task ReturnResponseMessageAsync() { - // Check if the response has already started because the TrySetResult below could happen a bit late + // Check if the response is already returning because the TrySetResult below could happen a bit late // (as it happens on a different thread) by which point the CompleteResponseAsync could run and calls this // method again. - if (!_responseFeature.HasStarted) + if (!_returningResponse) { - // Sets HasStarted + _returningResponse = true; + try { await _responseFeature.FireOnSendingHeadersAsync(); diff --git a/src/Hosting/TestHost/src/ResponseFeature.cs b/src/Hosting/TestHost/src/ResponseFeature.cs index c6c7b47e18dd..2af0e3de365b 100644 --- a/src/Hosting/TestHost/src/ResponseFeature.cs +++ b/src/Hosting/TestHost/src/ResponseFeature.cs @@ -3,21 +3,24 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.TestHost { - internal class ResponseFeature : IHttpResponseFeature + internal class ResponseFeature : IHttpResponseFeature, IHttpResponseStartFeature { + private readonly HeaderDictionary _headers = new HeaderDictionary(); + private readonly Action _abort; + private Func _responseStartingAsync = () => Task.FromResult(true); private Func _responseCompletedAsync = () => Task.FromResult(true); - private HeaderDictionary _headers = new HeaderDictionary(); private int _statusCode; private string _reasonPhrase; - public ResponseFeature() + public ResponseFeature(Action abort) { Headers = _headers; Body = new MemoryStream(); @@ -25,6 +28,7 @@ public ResponseFeature() // 200 is the default status code all the way down to the host, so we set it // here to be consistent with the rest of the hosts when writing tests. StatusCode = 200; + _abort = abort; } public int StatusCode @@ -98,14 +102,36 @@ public void OnCompleted(Func callback, object state) public async Task FireOnSendingHeadersAsync() { - await _responseStartingAsync(); - HasStarted = true; - _headers.IsReadOnly = true; + if (!HasStarted) + { + try + { + await _responseStartingAsync(); + } + finally + { + HasStarted = true; + _headers.IsReadOnly = true; + } + } } public Task FireOnResponseCompletedAsync() { return _responseCompletedAsync(); } + + public async Task StartAsync(CancellationToken token = default) + { + try + { + await FireOnSendingHeadersAsync(); + } + catch (Exception ex) + { + _abort(ex); + throw; + } + } } } diff --git a/src/Hosting/TestHost/test/ClientHandlerTests.cs b/src/Hosting/TestHost/test/ClientHandlerTests.cs index 73d37c0159b9..75399284602c 100644 --- a/src/Hosting/TestHost/test/ClientHandlerTests.cs +++ b/src/Hosting/TestHost/test/ClientHandlerTests.cs @@ -153,6 +153,48 @@ public async Task ServerTrailersSetOnResponseAfterContentRead() }); } + [Fact] + public async Task ResponseStartAsync() + { + var hasStartedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var hasAssertedResponseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + bool? preHasStarted = null; + bool? postHasStarted = null; + var handler = new ClientHandler(PathString.Empty, new DummyApplication(async context => + { + preHasStarted = context.Response.HasStarted; + + await context.Response.StartAsync(); + + postHasStarted = context.Response.HasStarted; + + hasStartedTcs.TrySetResult(null); + + await hasAssertedResponseTcs.Task; + })); + + var invoker = new HttpMessageInvoker(handler); + var message = new HttpRequestMessage(HttpMethod.Post, "https://example.com/"); + + var responseTask = invoker.SendAsync(message, CancellationToken.None); + + // Ensure StartAsync has been called in response + await hasStartedTcs.Task; + + // Delay so async thread would have had time to attempt to return response + await Task.Delay(100); + Assert.False(responseTask.IsCompleted, "HttpResponse.StartAsync does not return response"); + + // Asserted that response return was checked, allow response to finish + hasAssertedResponseTcs.TrySetResult(null); + + await responseTask; + + Assert.False(preHasStarted); + Assert.True(postHasStarted); + } + [Fact] public async Task ResubmitRequestWorks() { diff --git a/src/Hosting/TestHost/test/ResponseFeatureTests.cs b/src/Hosting/TestHost/test/ResponseFeatureTests.cs index f8af4cf64d24..cea2f6121485 100644 --- a/src/Hosting/TestHost/test/ResponseFeatureTests.cs +++ b/src/Hosting/TestHost/test/ResponseFeatureTests.cs @@ -13,7 +13,7 @@ public class ResponseFeatureTests public async Task StatusCode_DefaultsTo200() { // Arrange & Act - var responseInformation = new ResponseFeature(); + var responseInformation = CreateResponseFeature(); // Assert Assert.Equal(200, responseInformation.StatusCode); @@ -25,11 +25,27 @@ public async Task StatusCode_DefaultsTo200() Assert.True(responseInformation.Headers.IsReadOnly); } + [Fact] + public async Task StartAsync_StartsResponse() + { + // Arrange & Act + var responseInformation = CreateResponseFeature(); + + // Assert + Assert.Equal(200, responseInformation.StatusCode); + Assert.False(responseInformation.HasStarted); + + await responseInformation.StartAsync(); + + Assert.True(responseInformation.HasStarted); + Assert.True(responseInformation.Headers.IsReadOnly); + } + [Fact] public void OnStarting_ThrowsWhenHasStarted() { // Arrange - var responseInformation = new ResponseFeature(); + var responseInformation = CreateResponseFeature(); responseInformation.HasStarted = true; // Act & Assert @@ -45,7 +61,7 @@ public void OnStarting_ThrowsWhenHasStarted() [Fact] public void StatusCode_ThrowsWhenHasStarted() { - var responseInformation = new ResponseFeature(); + var responseInformation = CreateResponseFeature(); responseInformation.HasStarted = true; Assert.Throws(() => responseInformation.StatusCode = 400); @@ -55,7 +71,7 @@ public void StatusCode_ThrowsWhenHasStarted() [Fact] public void StatusCode_MustBeGreaterThan99() { - var responseInformation = new ResponseFeature(); + var responseInformation = CreateResponseFeature(); Assert.Throws(() => responseInformation.StatusCode = 99); Assert.Throws(() => responseInformation.StatusCode = 0); @@ -64,5 +80,10 @@ public void StatusCode_MustBeGreaterThan99() responseInformation.StatusCode = 200; responseInformation.StatusCode = 1000; } + + private ResponseFeature CreateResponseFeature() + { + return new ResponseFeature(ex => { }); + } } -} \ No newline at end of file +} diff --git a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs index 0ee748add761..2ec721730ce7 100644 --- a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs +++ b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerTest.cs @@ -7,11 +7,13 @@ using System.IO; using System.Linq; using System.Net; +using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -114,6 +116,8 @@ public async Task ClearsResponseBuffer_BeforeRequestIsReexecuted() // add response buffering app.Use(async (httpContext, next) => { + httpContext.Features.Set(null); + var response = httpContext.Response; var originalResponseBody = response.Body; var bufferingStream = new MemoryStream();