diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs index 8747eda3d4fc..433c783bb900 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs @@ -44,7 +44,7 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami public override bool TryReadInternal(out ReadResult readResult) { - TryStart(); + TryStartAsync(); var boolResult = _requestBodyPipe.Reader.TryRead(out _readResult); @@ -61,7 +61,7 @@ public override bool TryReadInternal(out ReadResult readResult) public override async ValueTask ReadAsyncInternal(CancellationToken cancellationToken = default) { - TryStart(); + await TryStartAsync(); try { @@ -101,7 +101,7 @@ private async Task PumpAsync() if (!awaitable.IsCompleted) { - TryProduceContinue(); + await TryProduceContinueAsync(); } while (true) @@ -171,7 +171,7 @@ protected override ValueTask OnStopAsync() // call complete here on the reader _requestBodyPipe.Reader.Complete(); - Debug.Assert(_pumpTask != null, "OnReadStarted must have been called."); + Debug.Assert(_pumpTask != null, "OnReadStartedAsync must have been called."); // PumpTask catches all Exceptions internally. if (_pumpTask.IsCompleted) @@ -195,9 +195,10 @@ private async ValueTask StopAsyncAwaited(Task pumpTask) _requestBodyPipe.Reset(); } - protected override void OnReadStarted() + protected override Task OnReadStartedAsync() { _pumpTask = PumpAsync(); + return Task.CompletedTask; } private bool Read(ReadOnlySequence readableBuffer, PipeWriter writableBuffer, out SequencePosition consumed, out SequencePosition examined) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs index dd9fbfe6675d..4b92e6f2cc2e 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs @@ -45,7 +45,7 @@ public override async ValueTask ReadAsyncInternal(CancellationToken KestrelBadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); } - TryStart(); + await TryStartAsync(); // The while(true) loop is required because the Http1 connection calls CancelPendingRead to unblock // the call to StartTimingReadAsync to check if the request timed out. @@ -132,7 +132,7 @@ public override bool TryReadInternal(out ReadResult readResult) KestrelBadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTimeout); } - TryStart(); + TryStartAsync(); // The while(true) because we don't want to return a canceled ReadResult if the user themselves didn't cancel it. while (true) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index 7b6c3b025789..a45a972fb139 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -909,27 +909,21 @@ protected bool VerifyResponseContentLength([NotNullWhen(false)] out Exception? e return true; } - public void ProduceContinue() + public ValueTask ProduceContinueAsync() { if (HasResponseStarted) { - return; + return default; } if (_httpVersion != Http.HttpVersion.Http10 && ((IHeaderDictionary)HttpRequestHeaders).TryGetValue(HeaderNames.Expect, out var expect) && (expect.FirstOrDefault() ?? "").Equals("100-continue", StringComparison.OrdinalIgnoreCase)) { - ValueTask vt = Output.Write100ContinueAsync(); - if (vt.IsCompleted) - { - vt.GetAwaiter().GetResult(); - } - else - { - vt.AsTask().GetAwaiter().GetResult(); - } + return Output.Write100ContinueAsync(); } + + return default; } public Task InitializeResponseAsync(int firstWriteByteCount) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs index 4bd37e02f88c..232726bfed02 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpResponseControl.cs @@ -10,7 +10,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { internal interface IHttpResponseControl { - void ProduceContinue(); + ValueTask ProduceContinueAsync(); Memory GetMemory(int sizeHint = 0); Span GetSpan(int sizeHint = 0); void Advance(int bytes); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs index fd1e1376c09e..06668bd872dc 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs @@ -66,11 +66,21 @@ public virtual ValueTask CompleteAsync(Exception? exception) public virtual Task ConsumeAsync() { - TryStart(); + Task startTask = TryStartAsync(); + if (!startTask.IsCompletedSuccessfully) + { + return ConsumeAwaited(startTask); + } return OnConsumeAsync(); } + private async Task ConsumeAwaited(Task startTask) + { + await startTask; + await OnConsumeAsync(); + } + public virtual ValueTask StopAsync() { TryStop(); @@ -93,20 +103,22 @@ public virtual void Reset() _examinedUnconsumedBytes = 0; } - protected void TryProduceContinue() + protected ValueTask TryProduceContinueAsync() { if (_send100Continue) { - _context.HttpResponseControl.ProduceContinue(); _send100Continue = false; + return _context.HttpResponseControl.ProduceContinueAsync(); } + + return default; } - protected void TryStart() + protected Task TryStartAsync() { if (_context.HasStartedConsumingRequestBody) { - return; + return Task.CompletedTask; } OnReadStarting(); @@ -128,7 +140,7 @@ protected void TryStart() } } - OnReadStarted(); + return OnReadStartedAsync(); } protected void TryStop() @@ -165,8 +177,9 @@ protected virtual void OnReadStarting() { } - protected virtual void OnReadStarted() + protected virtual Task OnReadStartedAsync() { + return Task.CompletedTask; } protected void AddAndCheckObservedBytes(long observedBytes) @@ -183,7 +196,15 @@ protected ValueTask StartTimingReadAsync(ValueTask readA { if (!readAwaitable.IsCompleted) { - TryProduceContinue(); + ValueTask continueTask = TryProduceContinueAsync(); + if (!continueTask.IsCompletedSuccessfully) + { + return StartTimingReadAwaited(continueTask, readAwaitable, cancellationToken); + } + else + { + continueTask.GetAwaiter().GetResult(); + } if (_timingEnabled) { @@ -195,6 +216,19 @@ protected ValueTask StartTimingReadAsync(ValueTask readA return readAwaitable; } + protected async ValueTask StartTimingReadAwaited(ValueTask continueTask, ValueTask readAwaitable, CancellationToken cancellationToken) + { + await continueTask; + + if (_timingEnabled) + { + _backpressure = true; + _context.TimeoutControl.StartTimingRead(); + } + + return await readAwaitable; + } + protected void CountBytesRead(long bytesInReadResult) { var numFirstSeenBytes = bytesInReadResult - _alreadyTimedBytes; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs index b89b89350451..59af4bed2f92 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs @@ -2,10 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 @@ -30,13 +32,19 @@ protected override void OnReadStarting() } } - protected override void OnReadStarted() + protected override Task OnReadStartedAsync() { // Produce 100-continue if no request body data for the stream has arrived yet. if (!_context.RequestBodyStarted) { - TryProduceContinue(); + ValueTask continueTask = TryProduceContinueAsync(); + if (!continueTask.IsCompletedSuccessfully) + { + return continueTask.GetAsTask(); + } } + + return Task.CompletedTask; } public override void Reset() @@ -59,7 +67,7 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami public override bool TryRead(out ReadResult readResult) { - TryStart(); + TryStartAsync(); var hasResult = _context.RequestBodyPipe.Reader.TryRead(out readResult); @@ -80,7 +88,7 @@ public override bool TryRead(out ReadResult readResult) public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) { - TryStart(); + await TryStartAsync(); try { diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3MessageBody.cs index 2210d4b0799d..c710593880a2 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3MessageBody.cs @@ -46,7 +46,7 @@ public override void AdvanceTo(SequencePosition consumed, SequencePosition exami public override bool TryRead(out ReadResult readResult) { - TryStart(); + TryStartAsync(); var hasResult = _context.RequestBodyPipe.Reader.TryRead(out readResult); @@ -67,7 +67,7 @@ public override bool TryRead(out ReadResult readResult) public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) { - TryStart(); + await TryStartAsync(); try { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs index aa05268499f2..a00408708b3b 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs @@ -628,6 +628,7 @@ public async Task RemoveConnectionSpecificHeaders() } [Fact] + [QuarantinedTest("https://github.com/dotnet/aspnetcore/issues/31777")] public async Task ContentLength_Received_NoDataFrames_Reset() { var headers = new[]