diff --git a/src/Hosting/TestHost/ref/Microsoft.AspNetCore.TestHost.netcoreapp3.0.cs b/src/Hosting/TestHost/ref/Microsoft.AspNetCore.TestHost.netcoreapp3.0.cs index 3a11a01b93f1..ed7536b4905e 100644 --- a/src/Hosting/TestHost/ref/Microsoft.AspNetCore.TestHost.netcoreapp3.0.cs +++ b/src/Hosting/TestHost/ref/Microsoft.AspNetCore.TestHost.netcoreapp3.0.cs @@ -14,6 +14,11 @@ public static partial class HostBuilderTestServerExtensions public static System.Net.Http.HttpClient GetTestClient(this Microsoft.Extensions.Hosting.IHost host) { throw null; } public static Microsoft.AspNetCore.TestHost.TestServer GetTestServer(this Microsoft.Extensions.Hosting.IHost host) { throw null; } } + public partial class HttpResetTestException : System.Exception + { + public HttpResetTestException(int errorCode) { } + public int ErrorCode { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + } public partial class RequestBuilder { public RequestBuilder(Microsoft.AspNetCore.TestHost.TestServer server, string path) { } diff --git a/src/Hosting/TestHost/src/ClientHandler.cs b/src/Hosting/TestHost/src/ClientHandler.cs index a34c3f7158c6..14b26869792a 100644 --- a/src/Hosting/TestHost/src/ClientHandler.cs +++ b/src/Hosting/TestHost/src/ClientHandler.cs @@ -70,7 +70,15 @@ protected override async Task SendAsync( { var req = context.Request; - req.Protocol = "HTTP/" + request.Version.ToString(fieldCount: 2); + if (request.Version == HttpVersion.Version20) + { + // https://tools.ietf.org/html/rfc7540 + req.Protocol = "HTTP/2"; + } + else + { + req.Protocol = "HTTP/" + request.Version.ToString(fieldCount: 2); + } req.Method = request.Method.ToString(); req.Scheme = request.RequestUri.Scheme; diff --git a/src/Hosting/TestHost/src/HttpContextBuilder.cs b/src/Hosting/TestHost/src/HttpContextBuilder.cs index 06d353f8a9a1..7e8d336bd870 100644 --- a/src/Hosting/TestHost/src/HttpContextBuilder.cs +++ b/src/Hosting/TestHost/src/HttpContextBuilder.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; @@ -10,7 +11,7 @@ namespace Microsoft.AspNetCore.TestHost { - internal class HttpContextBuilder : IHttpBodyControlFeature + internal class HttpContextBuilder : IHttpBodyControlFeature, IHttpResetFeature { private readonly ApplicationWrapper _application; private readonly bool _preserveExecutionContext; @@ -20,7 +21,7 @@ internal class HttpContextBuilder : IHttpBodyControlFeature private readonly ResponseBodyReaderStream _responseReaderStream; private readonly ResponseBodyPipeWriter _responsePipeWriter; private readonly ResponseFeature _responseFeature; - private readonly RequestLifetimeFeature _requestLifetimeFeature = new RequestLifetimeFeature(); + private readonly RequestLifetimeFeature _requestLifetimeFeature; private readonly ResponseTrailersFeature _responseTrailersFeature = new ResponseTrailersFeature(); private bool _pipelineFinished; private bool _returningResponse; @@ -34,13 +35,14 @@ internal HttpContextBuilder(ApplicationWrapper application, bool allowSynchronou _preserveExecutionContext = preserveExecutionContext; _httpContext = new DefaultHttpContext(); _responseFeature = new ResponseFeature(Abort); + _requestLifetimeFeature = new RequestLifetimeFeature(Abort); var request = _httpContext.Request; request.Protocol = "HTTP/1.1"; request.Method = HttpMethods.Get; var pipe = new Pipe(); - _responseReaderStream = new ResponseBodyReaderStream(pipe, AbortRequest, () => _responseReadCompleteCallback?.Invoke(_httpContext)); + _responseReaderStream = new ResponseBodyReaderStream(pipe, ClientInitiatedAbort, () => _responseReadCompleteCallback?.Invoke(_httpContext)); _responsePipeWriter = new ResponseBodyPipeWriter(pipe, ReturnResponseMessageAsync); _responseFeature.Body = new ResponseBodyWriterStream(_responsePipeWriter, () => AllowSynchronousIO); _responseFeature.BodySnapshot = _responseFeature.Body; @@ -77,11 +79,17 @@ internal void RegisterResponseReadCompleteCallback(Action responseR /// internal Task SendAsync(CancellationToken cancellationToken) { - var registration = cancellationToken.Register(AbortRequest); + var registration = cancellationToken.Register(ClientInitiatedAbort); // Everything inside this function happens in the SERVER's execution context (unless PreserveExecutionContext is true) async Task RunRequestAsync() { + // HTTP/2 specific features must be added after the request has been configured. + if (string.Equals("HTTP/2", _httpContext.Request.Protocol, StringComparison.OrdinalIgnoreCase)) + { + _httpContext.Features.Set(this); + } + // This will configure IHttpContextAccessor so it needs to happen INSIDE this function, // since we are now inside the Server's execution context. If it happens outside this cont // it will be lost when we abandon the execution context. @@ -120,13 +128,16 @@ async Task RunRequestAsync() return _responseTcs.Task; } - internal void AbortRequest() + // Triggered by request CancellationToken canceling or response stream Disposal. + internal void ClientInitiatedAbort() { if (!_pipelineFinished) { - _requestLifetimeFeature.Abort(); + // We don't want to trigger the token for already completed responses. + _requestLifetimeFeature.Cancel(); } - _responsePipeWriter.Complete(); + // Writes will still succeed, the app will only get an error if they check the CT. + _responseReaderStream.Abort(new IOException("The client aborted the request.")); } internal async Task CompleteResponseAsync() @@ -178,10 +189,15 @@ internal async Task ReturnResponseMessageAsync() internal void Abort(Exception exception) { - _pipelineFinished = true; _responsePipeWriter.Abort(exception); _responseReaderStream.Abort(exception); + _requestLifetimeFeature.Cancel(); _responseTcs.TrySetException(exception); } + + void IHttpResetFeature.Reset(int errorCode) + { + Abort(new HttpResetTestException(errorCode)); + } } } diff --git a/src/Hosting/TestHost/src/HttpResetTestException.cs b/src/Hosting/TestHost/src/HttpResetTestException.cs new file mode 100644 index 000000000000..2b2258511ef2 --- /dev/null +++ b/src/Hosting/TestHost/src/HttpResetTestException.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.TestHost +{ + /// + /// Used to surface to the test client that the application invoked + /// + public class HttpResetTestException : Exception + { + /// + /// Creates a new test exception + /// + /// The error code passed to + public HttpResetTestException(int errorCode) + : base($"The application reset the request with error code {errorCode}.") + { + ErrorCode = errorCode; + } + + /// + /// The error code passed to + /// + public int ErrorCode { get; } + } +} diff --git a/src/Hosting/TestHost/src/RequestLifetimeFeature.cs b/src/Hosting/TestHost/src/RequestLifetimeFeature.cs index 7593f83306b1..be6daf86a1b7 100644 --- a/src/Hosting/TestHost/src/RequestLifetimeFeature.cs +++ b/src/Hosting/TestHost/src/RequestLifetimeFeature.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Threading; using Microsoft.AspNetCore.Http.Features; @@ -9,14 +10,25 @@ namespace Microsoft.AspNetCore.TestHost internal class RequestLifetimeFeature : IHttpRequestLifetimeFeature { private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); + private readonly Action _abort; - public RequestLifetimeFeature() + public RequestLifetimeFeature(Action abort) { RequestAborted = _cancellationTokenSource.Token; + _abort = abort; } public CancellationToken RequestAborted { get; set; } - public void Abort() => _cancellationTokenSource.Cancel(); + internal void Cancel() + { + _cancellationTokenSource.Cancel(); + } + + void IHttpRequestLifetimeFeature.Abort() + { + _abort(new Exception("The application aborted the request.")); + _cancellationTokenSource.Cancel(); + } } } diff --git a/src/Hosting/TestHost/src/ResponseBodyReaderStream.cs b/src/Hosting/TestHost/src/ResponseBodyReaderStream.cs index b1311fdfb83f..0ec24a1c9aee 100644 --- a/src/Hosting/TestHost/src/ResponseBodyReaderStream.cs +++ b/src/Hosting/TestHost/src/ResponseBodyReaderStream.cs @@ -80,6 +80,11 @@ public async override Task ReadAsync(byte[] buffer, int offset, int count, using var registration = cancellationToken.Register(Cancel); var result = await _pipe.Reader.ReadAsync(cancellationToken); + if (result.IsCanceled) + { + throw new OperationCanceledException(); + } + if (result.Buffer.IsEmpty && result.IsCompleted) { _pipe.Reader.Complete(); @@ -114,9 +119,7 @@ private static void VerifyBuffer(byte[] buffer, int offset, int count) internal void Cancel() { - _aborted = true; - _abortException = new OperationCanceledException(); - _pipe.Writer.Complete(_abortException); + Abort(new OperationCanceledException()); } internal void Abort(Exception innerException) @@ -124,6 +127,8 @@ internal void Abort(Exception innerException) Contract.Requires(innerException != null); _aborted = true; _abortException = innerException; + _pipe.Reader.CancelPendingRead(); + _pipe.Reader.Complete(); } private void CheckAborted() diff --git a/src/Hosting/TestHost/test/ClientHandlerTests.cs b/src/Hosting/TestHost/test/ClientHandlerTests.cs index 01408813f3ea..ddf73ec5c9df 100644 --- a/src/Hosting/TestHost/test/ClientHandlerTests.cs +++ b/src/Hosting/TestHost/test/ClientHandlerTests.cs @@ -301,8 +301,7 @@ public async Task ClientDisposalCloses() Task readTask = responseStream.ReadAsync(new byte[100], 0, 100); Assert.False(readTask.IsCompleted); responseStream.Dispose(); - var read = await readTask.WithTimeout(); - Assert.Equal(0, read); + await Assert.ThrowsAsync(() => readTask.WithTimeout()); block.SetResult(0); } diff --git a/src/Hosting/TestHost/test/HttpContextBuilderTests.cs b/src/Hosting/TestHost/test/HttpContextBuilderTests.cs index 5411ad582d6f..b4731e1572f7 100644 --- a/src/Hosting/TestHost/test/HttpContextBuilderTests.cs +++ b/src/Hosting/TestHost/test/HttpContextBuilderTests.cs @@ -194,8 +194,7 @@ public async Task ClientDisposalCloses() Task readTask = responseStream.ReadAsync(new byte[100], 0, 100); Assert.False(readTask.IsCompleted); responseStream.Dispose(); - var read = await readTask.WithTimeout(); - Assert.Equal(0, read); + await Assert.ThrowsAsync(() => readTask.WithTimeout()); block.SetResult(0); } @@ -313,19 +312,22 @@ public async Task ClientHandlerCreateContextWithDefaultRequestParameters() [Fact] public async Task CallingAbortInsideHandlerShouldSetRequestAborted() { + var requestAborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var builder = new WebHostBuilder() .Configure(app => { app.Run(context => { + context.RequestAborted.Register(() => requestAborted.SetResult(0)); context.Abort(); return Task.CompletedTask; }); }); var server = new TestServer(builder); - var ctx = await server.SendAsync(c => { }); - Assert.True(ctx.RequestAborted.IsCancellationRequested); + var ex = await Assert.ThrowsAsync(() => server.SendAsync(c => { })); + Assert.Equal("The application aborted the request.", ex.Message); + await requestAborted.Task.WithTimeout(); } private class VerifierLogger : ILogger diff --git a/src/Hosting/TestHost/test/RequestLifetimeTests.cs b/src/Hosting/TestHost/test/RequestLifetimeTests.cs new file mode 100644 index 000000000000..e73f2b74582e --- /dev/null +++ b/src/Hosting/TestHost/test/RequestLifetimeTests.cs @@ -0,0 +1,115 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Microsoft.AspNetCore.TestHost +{ + public class RequestLifetimeTests + { + [Fact] + public async Task LifetimeFeature_Abort_TriggersRequestAbortedToken() + { + var requestAborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + httpContext.RequestAborted.Register(() => requestAborted.SetResult(0)); + httpContext.Abort(); + + await requestAborted.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + var ex = await Assert.ThrowsAsync(() => client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead)); + Assert.Equal("The application aborted the request.", ex.Message); + await requestAborted.Task.WithTimeout(); + } + + [Fact] + public async Task LifetimeFeature_AbortBeforeHeadersSent_ClientThrows() + { + var abortReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + httpContext.Abort(); + await abortReceived.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + var ex = await Assert.ThrowsAsync(() => client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead)); + Assert.Equal("The application aborted the request.", ex.Message); + abortReceived.SetResult(0); + } + + [Fact] + public async Task LifetimeFeature_AbortAfterHeadersSent_ClientBodyThrows() + { + var responseReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var abortReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + await httpContext.Response.Body.FlushAsync(); + await responseReceived.Task.WithTimeout(); + httpContext.Abort(); + await abortReceived.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + var response = await client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead); + responseReceived.SetResult(0); + response.EnsureSuccessStatusCode(); + var ex = await Assert.ThrowsAsync(() => response.Content.ReadAsByteArrayAsync()); + var rex = ex.GetBaseException(); + Assert.Equal("The application aborted the request.", rex.Message); + abortReceived.SetResult(0); + } + + [Fact] + public async Task LifetimeFeature_AbortAfterSomeDataSent_ClientBodyThrows() + { + var responseReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var abortReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + await httpContext.Response.WriteAsync("Hello World"); + await responseReceived.Task.WithTimeout(); + httpContext.Abort(); + await abortReceived.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + var response = await client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead); + responseReceived.SetResult(0); + response.EnsureSuccessStatusCode(); + var ex = await Assert.ThrowsAsync(() => response.Content.ReadAsByteArrayAsync()); + var rex = ex.GetBaseException(); + Assert.Equal("The application aborted the request.", rex.Message); + abortReceived.SetResult(0); + } + + // TODO: Abort after CompleteAsync - No-op, the request is already complete. + + private Task CreateHost(RequestDelegate appDelegate) + { + return new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .Configure(app => + { + app.Run(appDelegate); + }); + }) + .StartAsync(); + } + } +} diff --git a/src/Hosting/TestHost/test/ResponseResetTests.cs b/src/Hosting/TestHost/test/ResponseResetTests.cs new file mode 100644 index 000000000000..d169f019a348 --- /dev/null +++ b/src/Hosting/TestHost/test/ResponseResetTests.cs @@ -0,0 +1,161 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Hosting; +using Xunit; + +namespace Microsoft.AspNetCore.TestHost +{ + public class ResponseResetTests + { + [Fact] + // Reset is only present for HTTP/2 + public async Task ResetFeature_Http11_Missing() + { + using var host = await CreateHost(httpContext => + { + var feature = httpContext.Features.Get(); + Assert.Null(feature); + return Task.CompletedTask; + }); + + var client = host.GetTestServer().CreateClient(); + client.DefaultRequestVersion = HttpVersion.Version11; + var response = await client.GetAsync("/"); + response.EnsureSuccessStatusCode(); + } + + [Fact] + public async Task ResetFeature_Http2_Present() + { + using var host = await CreateHost(httpContext => + { + var feature = httpContext.Features.Get(); + Assert.NotNull(feature); + return Task.CompletedTask; + }); + + var client = host.GetTestServer().CreateClient(); + client.DefaultRequestVersion = HttpVersion.Version20; + var response = await client.GetAsync("/"); + response.EnsureSuccessStatusCode(); + } + + [Fact] + public async Task ResetFeature_Reset_TriggersRequestAbortedToken() + { + var requestAborted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + httpContext.RequestAborted.Register(() => requestAborted.SetResult(0)); + + var feature = httpContext.Features.Get(); + feature.Reset(12345); + await requestAborted.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + client.DefaultRequestVersion = HttpVersion.Version20; + var rex = await Assert.ThrowsAsync(() => client.GetAsync("/")); + Assert.Equal("The application reset the request with error code 12345.", rex.Message); + Assert.Equal(12345, rex.ErrorCode); + await requestAborted.Task.WithTimeout(); + } + + [Fact] + public async Task ResetFeature_ResetBeforeHeadersSent_ClientThrows() + { + var resetReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + var feature = httpContext.Features.Get(); + feature.Reset(12345); + await resetReceived.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + client.DefaultRequestVersion = HttpVersion.Version20; + var rex = await Assert.ThrowsAsync(() => client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead)); + Assert.Equal("The application reset the request with error code 12345.", rex.Message); + Assert.Equal(12345, rex.ErrorCode); + resetReceived.SetResult(0); + } + + [Fact] + public async Task ResetFeature_ResetAfterHeadersSent_ClientBodyThrows() + { + var responseReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var resetReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + await httpContext.Response.Body.FlushAsync(); + await responseReceived.Task.WithTimeout(); + var feature = httpContext.Features.Get(); + feature.Reset(12345); + await resetReceived.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + client.DefaultRequestVersion = HttpVersion.Version20; + var response = await client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead); + responseReceived.SetResult(0); + response.EnsureSuccessStatusCode(); + var ex = await Assert.ThrowsAsync(() => response.Content.ReadAsByteArrayAsync()); + var rex = Assert.IsAssignableFrom(ex.GetBaseException()); + Assert.Equal("The application reset the request with error code 12345.", rex.Message); + Assert.Equal(12345, rex.ErrorCode); + resetReceived.SetResult(0); + } + + [Fact] + public async Task ResetFeature_ResetAfterSomeDataSent_ClientBodyThrows() + { + var responseReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var resetReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + using var host = await CreateHost(async httpContext => + { + await httpContext.Response.WriteAsync("Hello World"); + await responseReceived.Task.WithTimeout(); + var feature = httpContext.Features.Get(); + feature.Reset(12345); + await resetReceived.Task.WithTimeout(); + }); + + var client = host.GetTestServer().CreateClient(); + client.DefaultRequestVersion = HttpVersion.Version20; + var response = await client.GetAsync("/", HttpCompletionOption.ResponseHeadersRead); + responseReceived.SetResult(0); + response.EnsureSuccessStatusCode(); + var ex = await Assert.ThrowsAsync(() => response.Content.ReadAsByteArrayAsync()); + var rex = Assert.IsAssignableFrom(ex.GetBaseException()); + Assert.Equal("The application reset the request with error code 12345.", rex.Message); + Assert.Equal(12345, rex.ErrorCode); + resetReceived.SetResult(0); + } + + // TODO: Reset after CompleteAsync - Not sure how to surface this. CompleteAsync hasn't been implemented yet anyways. + + private Task CreateHost(RequestDelegate appDelegate) + { + return new HostBuilder() + .ConfigureWebHost(webBuilder => + { + webBuilder + .UseTestServer() + .Configure(app => + { + app.Run(appDelegate); + }); + }) + .StartAsync(); + } + } +} diff --git a/src/Hosting/TestHost/test/TestClientTests.cs b/src/Hosting/TestHost/test/TestClientTests.cs index 20a45706dd8e..9fc257d843f2 100644 --- a/src/Hosting/TestHost/test/TestClientTests.cs +++ b/src/Hosting/TestHost/test/TestClientTests.cs @@ -420,7 +420,7 @@ public async Task ClientCancellationAbortsRequest() var client = server.CreateClient(); var cts = new CancellationTokenSource(); cts.CancelAfter(500); - var response = await client.GetAsync("http://localhost:12345", cts.Token); + var response = await Assert.ThrowsAnyAsync(() => client.GetAsync("http://localhost:12345", cts.Token)); // Assert var exception = await Assert.ThrowsAnyAsync(async () => await tcs.Task);