diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs index 8d67b56462c5..d75bb05d923d 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs @@ -14,6 +14,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http /// internal sealed class Http1UpgradeMessageBody : Http1MessageBody { + private int _userCanceled; + public Http1UpgradeMessageBody(Http1Connection context) : base(context) { @@ -26,13 +28,13 @@ public Http1UpgradeMessageBody(Http1Connection context) public override ValueTask ReadAsync(CancellationToken cancellationToken = default) { ThrowIfCompleted(); - return _context.Input.ReadAsync(cancellationToken); + return ReadAsyncInternal(cancellationToken); } public override bool TryRead(out ReadResult result) { ThrowIfCompleted(); - return _context.Input.TryRead(out result); + return TryReadInternal(out result); } public override void AdvanceTo(SequencePosition consumed) @@ -54,6 +56,7 @@ public override void Complete(Exception exception) public override void CancelPendingRead() { + Interlocked.Exchange(ref _userCanceled, 1); _context.Input.CancelPendingRead(); } @@ -69,12 +72,49 @@ public override Task StopAsync() public override bool TryReadInternal(out ReadResult readResult) { - return _context.Input.TryRead(out readResult); + // Ignore the canceled readResult unless it was canceled by the user. + do + { + if (!_context.Input.TryRead(out readResult)) + { + return false; + } + } while (readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 0); + + return true; } public override ValueTask ReadAsyncInternal(CancellationToken cancellationToken = default) { - return _context.Input.ReadAsync(cancellationToken); + ReadResult readResult; + + // Ignore the canceled readResult unless it was canceled by the user. + do + { + var readTask = _context.Input.ReadAsync(cancellationToken); + + if (!readTask.IsCompletedSuccessfully) + { + return ReadAsyncInternalAwaited(readTask, cancellationToken); + } + + readResult = readTask.GetAwaiter().GetResult(); + } while (readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 0); + + return new ValueTask(readResult); + } + + private async ValueTask ReadAsyncInternalAwaited(ValueTask readTask, CancellationToken cancellationToken = default) + { + var readResult = await readTask; + + // Ignore the canceled readResult unless it was canceled by the user. + while (readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 0) + { + readResult = await _context.Input.ReadAsync(cancellationToken); + } + + return readResult; } } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/UpgradeTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/UpgradeTests.cs index cf36bb8a1512..cfd81233bee2 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/UpgradeTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/UpgradeTests.cs @@ -3,7 +3,10 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; @@ -343,5 +346,50 @@ await connection.Receive("HTTP/1.1 101 Switching Protocols", await appCompletedTcs.Task.DefaultTimeout(); } } + + [Fact] + public async Task DoesNotThrowGivenCanceledReadResult() + { + var appCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await using var server = new TestServer(async context => + { + try + { + var upgradeFeature = context.Features.Get(); + var duplexStream = await upgradeFeature.UpgradeAsync(); + + // Kestrel will call Transport.Input.CancelPendingRead() during shutdown so idle connections + // can wake up and shutdown gracefully. We manually call CancelPendingRead() to simulate this and + // ensure the Stream returned by UpgradeAsync doesn't throw in this case. + // https://github.com/dotnet/aspnetcore/issues/26482 + var connectionTransportFeature = context.Features.Get(); + connectionTransportFeature.Transport.Input.CancelPendingRead(); + + // Use ReadAsync() instead of CopyToAsync() for this test since IsCanceled is only checked in + // HttpRequestStream.ReadAsync() and not HttpRequestStream.CopyToAsync() + Assert.Equal(0, await duplexStream.ReadAsync(new byte[1])); + appCompletedTcs.SetResult(null); + } + catch (Exception ex) + { + appCompletedTcs.SetException(ex); + throw; + } + }, + new TestServiceContext(LoggerFactory)); + + using (var connection = server.CreateConnection()) + { + await connection.SendEmptyGetWithUpgrade(); + await connection.Receive("HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + ""); + } + + await appCompletedTcs.Task.DefaultTimeout(); + } } }