diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index 2911ef0b1d19..9c3a3e52f600 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -364,7 +364,7 @@ public async Task ProcessRequestAsync(IHttpApplication appli Log.Http3FrameReceived(ConnectionId, _streamIdFeature.StreamId, _incomingFrame); consumed = examined = framePayload.End; - await ProcessHttp3Stream(application, framePayload); + await ProcessHttp3Stream(application, framePayload, result.IsCompleted && readableBuffer.IsEmpty); } } @@ -448,14 +448,14 @@ private ValueTask OnEndStreamReceived() return RequestBodyPipe.Writer.CompleteAsync(); } - private Task ProcessHttp3Stream(IHttpApplication application, in ReadOnlySequence payload) where TContext : notnull + private Task ProcessHttp3Stream(IHttpApplication application, in ReadOnlySequence payload, bool isCompleted) where TContext : notnull { switch (_incomingFrame.Type) { case Http3FrameType.Data: return ProcessDataFrameAsync(payload); case Http3FrameType.Headers: - return ProcessHeadersFrameAsync(application, payload); + return ProcessHeadersFrameAsync(application, payload, isCompleted); case Http3FrameType.Settings: case Http3FrameType.CancelPush: case Http3FrameType.GoAway: @@ -478,7 +478,7 @@ private Task ProcessUnknownFrameAsync() return Task.CompletedTask; } - private Task ProcessHeadersFrameAsync(IHttpApplication application, ReadOnlySequence payload) where TContext : notnull + private async Task ProcessHeadersFrameAsync(IHttpApplication application, ReadOnlySequence payload, bool isCompleted) where TContext : notnull { // HEADERS frame after trailing headers is invalid. // https://quicwg.org/base-drafts/draft-ietf-quic-http.html#section-4.1 @@ -506,7 +506,7 @@ private Task ProcessHeadersFrameAsync(IHttpApplication appli case RequestHeaderParsingState.Trailers: // trailers // TODO figure out if there is anything else to do here. - return Task.CompletedTask; + return; default: Debug.Fail("Unexpected header parsing state."); break; @@ -514,11 +514,16 @@ private Task ProcessHeadersFrameAsync(IHttpApplication appli InputRemaining = HttpRequestHeaders.ContentLength; + // If the stream is complete after receiving the headers then run OnEndStreamReceived. + // If there is a bad content length then this will throw before the request delegate is called. + if (isCompleted) + { + await OnEndStreamReceived(); + } + _appCompleted = new TaskCompletionSource(); ThreadPool.UnsafeQueueUserWorkItem(this, preferLocal: false); - - return Task.CompletedTask; } private Task ProcessDataFrameAsync(in ReadOnlySequence payload) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs index a00408708b3b..0e86c84104ea 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3StreamTests.cs @@ -639,11 +639,64 @@ public async Task ContentLength_Received_NoDataFrames_Reset() new KeyValuePair(HeaderNames.ContentLength, "12"), }; - var requestStream = await InitializeConnectionAndStreamsAsync(_noopApplication); + var requestDelegateCalled = false; + var requestStream = await InitializeConnectionAndStreamsAsync(c => + { + // Bad content-length + end stream means the request delegate + // is never called by the server. + requestDelegateCalled = true; + return Task.CompletedTask; + }); await requestStream.SendHeadersAsync(headers, endStream: true); await requestStream.WaitForStreamErrorAsync(Http3ErrorCode.ProtocolError, CoreStrings.Http3StreamErrorLessDataThanLength); + + Assert.False(requestDelegateCalled); + } + + [Fact] + public async Task EndRequestStream_ContinueReadingFromResponse() + { + var headersTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + + var data = new byte[] { 1, 2, 3, 4, 5, 6 }; + + var requestStream = await InitializeConnectionAndStreamsAsync(async context => + { + await context.Response.BodyWriter.FlushAsync(); + + await headersTcs.Task; + + for (var i = 0; i < data.Length; i++) + { + await Task.Delay(50); + await context.Response.BodyWriter.WriteAsync(new byte[] { data[i] }); + } + }); + + await requestStream.SendHeadersAsync(headers, endStream: true); + await requestStream.ExpectHeadersAsync(); + + headersTcs.SetResult(); + + var receivedData = new List(); + while (receivedData.Count < data.Length) + { + var frameData = await requestStream.ExpectDataAsync(); + receivedData.AddRange(frameData.ToArray()); + } + + Assert.Equal(data, receivedData); + + await requestStream.ExpectReceiveEndOfStream(); } [Fact]