diff --git a/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs b/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs
index 30f567851c78..6f53a07297ef 100644
--- a/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs
+++ b/src/Http/Http.Features/ref/Microsoft.AspNetCore.Http.Features.netstandard2.0.cs
@@ -200,6 +200,10 @@ public partial interface IHttpRequestTrailersFeature
bool Available { get; }
Microsoft.AspNetCore.Http.IHeaderDictionary Trailers { get; }
}
+ public partial interface IHttpResponseCompletionFeature
+ {
+ System.Threading.Tasks.Task CompleteAsync();
+ }
public partial interface IHttpResponseFeature
{
System.IO.Stream Body { get; set; }
diff --git a/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs b/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs
new file mode 100644
index 000000000000..eed45e40364b
--- /dev/null
+++ b/src/Http/Http.Features/src/IHttpResponseCompletionFeature.cs
@@ -0,0 +1,20 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System.Threading.Tasks;
+
+namespace Microsoft.AspNetCore.Http.Features
+{
+ ///
+ /// A feature to gracefully end a response.
+ ///
+ public interface IHttpResponseCompletionFeature
+ {
+ ///
+ /// Flush any remaining response headers, data, or trailers.
+ /// This may throw if the response is in an invalid state such as a Content-Length mismatch.
+ ///
+ ///
+ Task CompleteAsync();
+ }
+}
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
index 82b1ebfe6b1a..f922de8a995f 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
@@ -277,6 +277,7 @@ protected void ResetHttp1Features()
protected void ResetHttp2Features()
{
_currentIHttp2StreamIdFeature = this;
+ _currentIHttpResponseCompletionFeature = this;
_currentIHttpResponseTrailersFeature = this;
}
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
index b9b3e26905e7..f594feed0fa3 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
@@ -29,6 +29,7 @@ internal partial class HttpProtocol : IFeatureCollection
private static readonly Type IFormFeatureType = typeof(IFormFeature);
private static readonly Type IHttpUpgradeFeatureType = typeof(IHttpUpgradeFeature);
private static readonly Type IHttp2StreamIdFeatureType = typeof(IHttp2StreamIdFeature);
+ private static readonly Type IHttpResponseCompletionFeatureType = typeof(IHttpResponseCompletionFeature);
private static readonly Type IHttpResponseTrailersFeatureType = typeof(IHttpResponseTrailersFeature);
private static readonly Type IResponseCookiesFeatureType = typeof(IResponseCookiesFeature);
private static readonly Type IItemsFeatureType = typeof(IItemsFeature);
@@ -58,6 +59,7 @@ internal partial class HttpProtocol : IFeatureCollection
private object _currentIFormFeature;
private object _currentIHttpUpgradeFeature;
private object _currentIHttp2StreamIdFeature;
+ private object _currentIHttpResponseCompletionFeature;
private object _currentIHttpResponseTrailersFeature;
private object _currentIResponseCookiesFeature;
private object _currentIItemsFeature;
@@ -98,6 +100,7 @@ private void FastReset()
_currentIQueryFeature = null;
_currentIFormFeature = null;
_currentIHttp2StreamIdFeature = null;
+ _currentIHttpResponseCompletionFeature = null;
_currentIHttpResponseTrailersFeature = null;
_currentIResponseCookiesFeature = null;
_currentIItemsFeature = null;
@@ -224,6 +227,10 @@ object IFeatureCollection.this[Type key]
{
feature = _currentIHttp2StreamIdFeature;
}
+ else if (key == IHttpResponseCompletionFeatureType)
+ {
+ feature = _currentIHttpResponseCompletionFeature;
+ }
else if (key == IHttpResponseTrailersFeatureType)
{
feature = _currentIHttpResponseTrailersFeature;
@@ -348,6 +355,10 @@ object IFeatureCollection.this[Type key]
{
_currentIHttp2StreamIdFeature = value;
}
+ else if (key == IHttpResponseCompletionFeatureType)
+ {
+ _currentIHttpResponseCompletionFeature = value;
+ }
else if (key == IHttpResponseTrailersFeatureType)
{
_currentIHttpResponseTrailersFeature = value;
@@ -470,6 +481,10 @@ TFeature IFeatureCollection.Get()
{
feature = (TFeature)_currentIHttp2StreamIdFeature;
}
+ else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature))
+ {
+ feature = (TFeature)_currentIHttpResponseCompletionFeature;
+ }
else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature))
{
feature = (TFeature)_currentIHttpResponseTrailersFeature;
@@ -598,6 +613,10 @@ void IFeatureCollection.Set(TFeature feature)
{
_currentIHttp2StreamIdFeature = feature;
}
+ else if (typeof(TFeature) == typeof(IHttpResponseCompletionFeature))
+ {
+ _currentIHttpResponseCompletionFeature = feature;
+ }
else if (typeof(TFeature) == typeof(IHttpResponseTrailersFeature))
{
_currentIHttpResponseTrailersFeature = feature;
@@ -718,6 +737,10 @@ private IEnumerable> FastEnumerable()
{
yield return new KeyValuePair(IHttp2StreamIdFeatureType, _currentIHttp2StreamIdFeature);
}
+ if (_currentIHttpResponseCompletionFeature != null)
+ {
+ yield return new KeyValuePair(IHttpResponseCompletionFeatureType, _currentIHttpResponseCompletionFeature);
+ }
if (_currentIHttpResponseTrailersFeature != null)
{
yield return new KeyValuePair(IHttpResponseTrailersFeatureType, _currentIHttpResponseTrailersFeature);
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
index 41a4f2e6aa54..bc053825c40c 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
@@ -210,6 +210,7 @@ private void HttpVersionSetSlow(string value)
public bool RequestTrailersAvailable { get; set; }
public Stream RequestBody { get; set; }
public PipeReader RequestBodyPipeReader { get; set; }
+ public HttpResponseTrailers ResponseTrailers { get; set; }
private int _statusCode;
public int StatusCode
@@ -287,7 +288,9 @@ public CancellationToken RequestAborted
public bool HasResponseStarted => _requestProcessingStatus >= RequestProcessingStatus.HeadersCommitted;
- public bool HasFlushedHeaders => _requestProcessingStatus == RequestProcessingStatus.HeadersFlushed;
+ public bool HasFlushedHeaders => _requestProcessingStatus >= RequestProcessingStatus.HeadersFlushed;
+
+ public bool HasResponseCompleted => _requestProcessingStatus == RequestProcessingStatus.ResponseCompleted;
protected HttpRequestHeaders HttpRequestHeaders { get; }
@@ -632,9 +635,18 @@ private async Task ProcessRequests(IHttpApplication applicat
// Run the application code for this request
await application.ProcessRequestAsync(context);
- if (!_connectionAborted)
+ // Trigger OnStarting if it hasn't been called yet and the app hasn't
+ // already failed. If an OnStarting callback throws we can go through
+ // our normal error handling in ProduceEnd.
+ // https://github.com/aspnet/KestrelHttpServer/issues/43
+ if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0)
{
- VerifyResponseContentLength();
+ await FireOnStarting();
+ }
+
+ if (!_connectionAborted && !VerifyResponseContentLength(out var lengthException))
+ {
+ ReportApplicationError(lengthException);
}
}
catch (BadHttpRequestException ex)
@@ -652,15 +664,6 @@ private async Task ProcessRequests(IHttpApplication applicat
KestrelEventSource.Log.RequestStop(this);
- // Trigger OnStarting if it hasn't been called yet and the app hasn't
- // already failed. If an OnStarting callback throws we can go through
- // our normal error handling in ProduceEnd.
- // https://github.com/aspnet/KestrelHttpServer/issues/43
- if (!HasResponseStarted && _applicationException == null && _onStarting?.Count > 0)
- {
- await FireOnStarting();
- }
-
// At this point all user code that needs use to the request or response streams has completed.
// Using these streams in the OnCompleted callback is not allowed.
StopBodies();
@@ -898,7 +901,7 @@ private void CheckLastWrite()
}
}
- protected void VerifyResponseContentLength()
+ protected bool VerifyResponseContentLength(out Exception ex)
{
var responseHeaders = HttpResponseHeaders;
@@ -915,9 +918,13 @@ protected void VerifyResponseContentLength()
_keepAlive = false;
}
- ReportApplicationError(new InvalidOperationException(
- CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value)));
+ ex = new InvalidOperationException(
+ CoreStrings.FormatTooFewBytesWritten(_responseBytesWritten, responseHeaders.ContentLength.Value));
+ return false;
}
+
+ ex = null;
+ return true;
}
public void ProduceContinue()
@@ -1045,6 +1052,11 @@ protected Task ProduceEnd()
private Task WriteSuffix()
{
+ if (HasResponseCompleted)
+ {
+ return Task.CompletedTask;
+ }
+
// _autoChunk should be checked after we are sure ProduceStart() has been called
// since ProduceStart() may set _autoChunk to true.
if (_autoChunk || _httpVersion == Http.HttpVersion.Http2)
@@ -1064,7 +1076,7 @@ private Task WriteSuffix()
if (!HasFlushedHeaders)
{
- _requestProcessingStatus = RequestProcessingStatus.HeadersFlushed;
+ _requestProcessingStatus = RequestProcessingStatus.ResponseCompleted;
return FlushAsyncInternal();
}
@@ -1080,6 +1092,8 @@ private async Task WriteSuffixAwaited()
await Output.WriteStreamSuffixAsync();
+ _requestProcessingStatus = RequestProcessingStatus.ResponseCompleted;
+
if (_keepAlive)
{
Log.ConnectionKeepAlive(ConnectionId);
@@ -1244,6 +1258,7 @@ private void SetErrorResponseHeaders(int statusCode)
var responseHeaders = HttpResponseHeaders;
responseHeaders.Reset();
+ ResponseTrailers?.Reset();
var dateHeaderValues = DateHeaderValueManager.GetDateHeaderValues();
responseHeaders.SetRawDate(dateHeaderValues.String, dateHeaderValues.Bytes);
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs
index 61832dc34bdf..6e27fb5dc807 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/RequestProcessingStatus.cs
@@ -10,6 +10,7 @@ internal enum RequestProcessingStatus
ParsingHeaders,
AppStarted,
HeadersCommitted,
- HeadersFlushed
+ HeadersFlushed,
+ ResponseCompleted
}
}
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
index 4f481850d703..5f0d80a37217 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs
@@ -151,7 +151,7 @@ public void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpRespon
// 2. There is no trailing HEADERS frame.
Http2HeadersFrameFlags http2HeadersFrame;
- if (appCompleted && !_startedWritingDataFrames && (_stream.Trailers == null || _stream.Trailers.Count == 0))
+ if (appCompleted && !_startedWritingDataFrames && (_stream.ResponseTrailers == null || _stream.ResponseTrailers.Count == 0))
{
_streamEnded = true;
http2HeadersFrame = Http2HeadersFrameFlags.END_STREAM;
@@ -313,7 +313,7 @@ private async ValueTask ProcessDataWrites()
{
readResult = await _dataPipe.Reader.ReadAsync();
- if (readResult.IsCompleted && _stream.Trailers?.Count > 0)
+ if (readResult.IsCompleted && _stream.ResponseTrailers?.Count > 0)
{
// Output is ending and there are trailers to write
// Write any remaining content then write trailers
@@ -322,7 +322,8 @@ private async ValueTask ProcessDataWrites()
flushResult = await _frameWriter.WriteDataAsync(_streamId, _flowControl, readResult.Buffer, endStream: false);
}
- flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.Trailers);
+ _stream.ResponseTrailers.SetReadOnly();
+ flushResult = await _frameWriter.WriteResponseTrailers(_streamId, _stream.ResponseTrailers);
}
else if (readResult.IsCompleted && _streamEnded)
{
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs
index 7187fc846240..fb27e387d3a4 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
@@ -11,21 +12,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
{
internal partial class Http2Stream : IHttp2StreamIdFeature,
IHttpMinRequestBodyDataRateFeature,
+ IHttpResponseCompletionFeature,
IHttpResponseTrailersFeature
{
- internal HttpResponseTrailers Trailers { get; set; }
private IHeaderDictionary _userTrailers;
IHeaderDictionary IHttpResponseTrailersFeature.Trailers
{
get
{
- if (Trailers == null)
+ if (ResponseTrailers == null)
{
- Trailers = new HttpResponseTrailers();
+ ResponseTrailers = new HttpResponseTrailers();
+ if (HasResponseCompleted)
+ {
+ ResponseTrailers.SetReadOnly();
+ }
}
- return _userTrailers ?? Trailers;
+ return _userTrailers ?? ResponseTrailers;
}
set
{
@@ -48,5 +53,25 @@ MinDataRate IHttpMinRequestBodyDataRateFeature.MinDataRate
MinRequestBodyDataRate = value;
}
}
+
+ async Task IHttpResponseCompletionFeature.CompleteAsync()
+ {
+ // Finalize headers
+ if (!HasResponseStarted)
+ {
+ await FireOnStarting();
+ }
+
+ // Flush headers, body, trailers...
+ if (!HasResponseCompleted)
+ {
+ if (!VerifyResponseContentLength(out var lengthException))
+ {
+ throw lengthException;
+ }
+
+ await ProduceEnd();
+ }
+ }
}
}
diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs
index 09c9dafae6ff..7bdaa893a90b 100644
--- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs
+++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs
@@ -1839,6 +1839,32 @@ await InitializeConnectionAsync(context =>
Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
}
+ [Fact]
+ public async Task ResponseTrailers_WithExeption500_Cleared()
+ {
+ await InitializeConnectionAsync(context =>
+ {
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+ throw new NotImplementedException("Test Exception");
+ });
+
+ await StartStreamAsync(1, _browserRequestHeaders, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM | Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("500", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
+ }
+
[Fact]
public async Task ResponseTrailers_WithData_Sent()
{
@@ -3307,5 +3333,779 @@ await InitializeConnectionAsync(async context =>
Assert.Contains(TestSink.Writes, w => w.EventId.Id == 13 && w.LogLevel == LogLevel.Error
&& w.Exception is ConnectionAbortedException && w.Exception.InnerException == expectedException);
}
+
+ [Fact]
+ public async Task CompleteAsync_BeforeBodyStarted_SendsHeadersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders["content-length"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_SendsHeadersAndTrailersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+ await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders["content-length"]);
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_BeforeBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAnd500()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ context.Response.ContentLength = 25;
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout());
+ Assert.Equal(CoreStrings.FormatTooFewBytesWritten(0, 25), ex.Message);
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully);
+ Assert.False(context.Response.Headers.IsReadOnly);
+ Assert.False(context.Features.Get().Trailers.IsReadOnly);
+
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("500", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterBodyStarted_SendsBodyWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+ await completionFeature.CompleteAsync().DefaultTimeout(); // Can be called twice, no-ops
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ await ExpectAsync(Http2FrameType.DATA,
+ withLength: 0,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+ }
+
+ [Fact]
+ public async Task CompleteAsync_WriteAfterComplete_Throws()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ var ex = await Assert.ThrowsAsync(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout());
+ Assert.Equal("Writing is not allowed after writer was completed.", ex.Message);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 55,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_WriteAgainAfterComplete_Throws()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World").DefaultTimeout();
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ var ex = await Assert.ThrowsAsync(() => context.Response.WriteAsync("2 Hello World").DefaultTimeout());
+ Assert.Equal("Writing is not allowed after writer was completed.", ex.Message);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ await ExpectAsync(Http2FrameType.DATA,
+ withLength: 0,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterPipeWrite_WithTrailers_SendsBodyAndTrailersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ var buffer = context.Response.BodyWriter.GetMemory();
+ var length = Encoding.UTF8.GetBytes("Hello World", buffer.Span);
+ context.Response.BodyWriter.Advance(length);
+
+ Assert.False(startingTcs.Task.IsCompletedSuccessfully); // OnStarting did not get called.
+ Assert.False(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterBodyStarted_WithTrailers_SendsBodyAndTrailersWithEndStream()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task CompleteAsync_AfterBodyStarted_WithTrailers_TruncatedContentLength_ThrowsAndReset()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ context.Response.ContentLength = 25;
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ var ex = await Assert.ThrowsAsync(() => completionFeature.CompleteAsync().DefaultTimeout());
+ Assert.Equal(CoreStrings.FormatTooFewBytesWritten(11, 25), ex.Message);
+
+ Assert.False(context.Features.Get().Trailers.IsReadOnly);
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 56,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+
+ clientTcs.SetResult(0);
+
+ await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR,
+ expectedErrorMessage: CoreStrings.FormatTooFewBytesWritten(11, 25));
+
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(3, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+ Assert.Equal("25", _decodedHeaders[HeaderNames.ContentLength]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+ }
+
+ [Fact]
+ public async Task AbortAfterCompleteAsync_GETWithResponseBodyAndTrailers_ResetsAfterResponse()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "GET"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // RequestAborted will no longer fire after CompleteAsync.
+ Assert.False(context.RequestAborted.CanBeCanceled);
+ context.Abort();
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: true);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+ await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
+
+ [Fact]
+ public async Task AbortAfterCompleteAsync_POSTWithResponseBodyAndTrailers_RequestBodyThrows()
+ {
+ var startingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var clientTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ var headers = new[]
+ {
+ new KeyValuePair(HeaderNames.Method, "POST"),
+ new KeyValuePair(HeaderNames.Path, "/"),
+ new KeyValuePair(HeaderNames.Scheme, "http"),
+ };
+ await InitializeConnectionAsync(async context =>
+ {
+ try
+ {
+ var requestBodyTask = context.Request.BodyReader.ReadAsync();
+
+ context.Response.OnStarting(() => { startingTcs.SetResult(0); return Task.CompletedTask; });
+ var completionFeature = context.Features.Get();
+ Assert.NotNull(completionFeature);
+
+ await context.Response.WriteAsync("Hello World");
+ Assert.True(startingTcs.Task.IsCompletedSuccessfully); // OnStarting got called.
+ Assert.True(context.Response.Headers.IsReadOnly);
+
+ context.Response.AppendTrailer("CustomName", "Custom Value");
+
+ await completionFeature.CompleteAsync().DefaultTimeout();
+
+ Assert.True(context.Features.Get().Trailers.IsReadOnly);
+
+ // RequestAborted will no longer fire after CompleteAsync.
+ Assert.False(context.RequestAborted.CanBeCanceled);
+ context.Abort();
+
+ await Assert.ThrowsAsync(async () => await requestBodyTask);
+ await Assert.ThrowsAsync(async () => await context.Request.BodyReader.ReadAsync());
+
+ // Make sure the client gets our results from CompleteAsync instead of from the request delegate exiting.
+ await clientTcs.Task.DefaultTimeout();
+ appTcs.SetResult(0);
+ }
+ catch (Exception ex)
+ {
+ appTcs.SetException(ex);
+ }
+ });
+
+ await StartStreamAsync(1, headers, endStream: false);
+
+ var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 37,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS),
+ withStreamId: 1);
+ var bodyFrame = await ExpectAsync(Http2FrameType.DATA,
+ withLength: 11,
+ withFlags: (byte)(Http2HeadersFrameFlags.NONE),
+ withStreamId: 1);
+ var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS,
+ withLength: 25,
+ withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM),
+ withStreamId: 1);
+ await WaitForStreamErrorAsync(1, Http2ErrorCode.INTERNAL_ERROR, expectedErrorMessage: null);
+
+ clientTcs.SetResult(0);
+ await appTcs.Task;
+
+ await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
+
+ _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
+
+ Assert.Equal(2, _decodedHeaders.Count);
+ Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
+ Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
+
+ Assert.Equal("Hello World", Encoding.UTF8.GetString(bodyFrame.Payload.Span));
+
+ _decodedHeaders.Clear();
+
+ _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this);
+
+ Assert.Single(_decodedHeaders);
+ Assert.Equal("Custom Value", _decodedHeaders["CustomName"]);
+ }
}
}
diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
index 826677071924..d30a30a9eb34 100644
--- a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
+++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
@@ -35,6 +35,7 @@ public static string GenerateFile()
{
"IHttpUpgradeFeature",
"IHttp2StreamIdFeature",
+ "IHttpResponseCompletionFeature",
"IHttpResponseTrailersFeature",
"IResponseCookiesFeature",
"IItemsFeature",