diff --git a/src/Servers/Connections.Abstractions/src/Features/IPersistentStateFeature.cs b/src/Servers/Connections.Abstractions/src/Features/IPersistentStateFeature.cs new file mode 100644 index 000000000000..6b9c5c9ff187 --- /dev/null +++ b/src/Servers/Connections.Abstractions/src/Features/IPersistentStateFeature.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Connections.Features +{ + /// + /// Provides access to a key/value collection that can be used to persist state between connections and requests. + /// Whether a transport supports persisting state depends on the implementation. The transport must support + /// pooling and reusing connection instances for state to be persisted. + /// + /// Because values added to persistent state can live in memory until a connection is no longer pooled, + /// use caution when adding items to this collection to avoid excessive memory use. + /// + /// + public interface IPersistentStateFeature + { + /// + /// Gets a key/value collection that can be used to persist state between connections and requests. + /// + IDictionary State { get; } + } +} diff --git a/src/Servers/Connections.Abstractions/src/PublicAPI.Unshipped.txt b/src/Servers/Connections.Abstractions/src/PublicAPI.Unshipped.txt index 88c8c23149f2..fefd6f14e6df 100644 --- a/src/Servers/Connections.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Servers/Connections.Abstractions/src/PublicAPI.Unshipped.txt @@ -2,6 +2,8 @@ *REMOVED*Microsoft.AspNetCore.Connections.IConnectionListener.AcceptAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Connections.Features.IConnectionSocketFeature Microsoft.AspNetCore.Connections.Features.IConnectionSocketFeature.Socket.get -> System.Net.Sockets.Socket! +Microsoft.AspNetCore.Connections.Features.IPersistentStateFeature +Microsoft.AspNetCore.Connections.Features.IPersistentStateFeature.State.get -> System.Collections.Generic.IDictionary! Microsoft.AspNetCore.Connections.IConnectionListener.AcceptAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Connections.IMultiplexedConnectionBuilder Microsoft.AspNetCore.Connections.IMultiplexedConnectionBuilder.ApplicationServices.get -> System.IServiceProvider! diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs index c89165524c58..cb75fc80da29 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.FeatureCollection.cs @@ -1,13 +1,21 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { internal partial class Http1Connection : IHttpMinRequestBodyDataRateFeature, - IHttpMinResponseDataRateFeature + IHttpMinResponseDataRateFeature, + IPersistentStateFeature { + // Persistent state collection is not reset with a request by design. + // If SocketsConections are pooled in the future this state could be moved + // to the transport layer. + private IDictionary? _persistentState; + MinDataRate? IHttpMinRequestBodyDataRateFeature.MinDataRate { get => MinRequestBodyDataRate; @@ -19,5 +27,14 @@ internal partial class Http1Connection : IHttpMinRequestBodyDataRateFeature, get => MinResponseDataRate; set => MinResponseDataRate = value; } + + IDictionary IPersistentStateFeature.State + { + get + { + // Lazily allocate persistent state + return _persistentState ?? (_persistentState = new ConnectionItems()); + } + } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs index 9c2ef6a80078..3899e896e73d 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1Connection.cs @@ -616,7 +616,6 @@ private void ValidateNonOriginHostHeader(string hostText) protected override void OnReset() { - _requestTimedOut = false; _requestTargetForm = HttpRequestTarget.Unknown; _absoluteRequestTarget = null; @@ -628,6 +627,7 @@ protected override void OnReset() // Reset Http1 Features _currentIHttpMinRequestBodyDataRateFeature = this; _currentIHttpMinResponseDataRateFeature = this; + _currentIPersistentStateFeature = this; } protected override void OnRequestProcessingEnding() diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs index 0ba146deec2e..a06a8fd77fb4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features.Authentication; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -62,6 +63,7 @@ internal partial class HttpProtocol : IFeatureCollection, internal protected IHttpMinRequestBodyDataRateFeature? _currentIHttpMinRequestBodyDataRateFeature; internal protected IHttpMinResponseDataRateFeature? _currentIHttpMinResponseDataRateFeature; internal protected IHttpResetFeature? _currentIHttpResetFeature; + internal protected IPersistentStateFeature? _currentIPersistentStateFeature; private int _featureRevision; @@ -99,6 +101,7 @@ private void FastReset() _currentIHttpMinRequestBodyDataRateFeature = null; _currentIHttpMinResponseDataRateFeature = null; _currentIHttpResetFeature = null; + _currentIPersistentStateFeature = null; } // Internal for testing @@ -286,6 +289,10 @@ private void ExtraFeatureSet(Type key, object? value) { feature = _currentIHttpResetFeature; } + else if (key == typeof(IPersistentStateFeature)) + { + feature = _currentIPersistentStateFeature; + } else if (MaybeExtra != null) { feature = ExtraFeatureGet(key); @@ -414,6 +421,10 @@ private void ExtraFeatureSet(Type key, object? value) { _currentIHttpResetFeature = (IHttpResetFeature?)value; } + else if (key == typeof(IPersistentStateFeature)) + { + _currentIPersistentStateFeature = (IPersistentStateFeature?)value; + } else { ExtraFeatureSet(key, value); @@ -544,6 +555,10 @@ private void ExtraFeatureSet(Type key, object? value) { feature = Unsafe.As(ref _currentIHttpResetFeature); } + else if (typeof(TFeature) == typeof(IPersistentStateFeature)) + { + feature = Unsafe.As(ref _currentIPersistentStateFeature); + } else if (MaybeExtra != null) { feature = (TFeature?)(ExtraFeatureGet(typeof(TFeature))); @@ -680,6 +695,10 @@ private void ExtraFeatureSet(Type key, object? value) { _currentIHttpResetFeature = Unsafe.As(ref feature); } + else if (typeof(TFeature) == typeof(IPersistentStateFeature)) + { + _currentIPersistentStateFeature = Unsafe.As(ref feature); + } else { ExtraFeatureSet(typeof(TFeature), feature); @@ -804,6 +823,10 @@ private IEnumerable> FastEnumerable() { yield return new KeyValuePair(typeof(IHttpResetFeature), _currentIHttpResetFeature); } + if (_currentIPersistentStateFeature != null) + { + yield return new KeyValuePair(typeof(IPersistentStateFeature), _currentIPersistentStateFeature); + } if (MaybeExtra != null) { diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs index c38b525462f5..cdd16fe45da0 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.FeatureCollection.cs @@ -4,6 +4,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -14,11 +15,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 internal partial class Http2Stream : IHttp2StreamIdFeature, IHttpMinRequestBodyDataRateFeature, IHttpResetFeature, - IHttpResponseTrailersFeature - + IHttpResponseTrailersFeature, + IPersistentStateFeature { private IHeaderDictionary? _userTrailers; + // Persistent state collection is not reset with a stream by design. + private IDictionary? _persistentState; + IHeaderDictionary IHttpResponseTrailersFeature.Trailers { get @@ -65,5 +69,14 @@ void IHttpResetFeature.Reset(int errorCode) var abortReason = new ConnectionAbortedException(CoreStrings.FormatHttp2StreamResetByApplication((Http2ErrorCode)errorCode)); ApplicationAbort(abortReason, (Http2ErrorCode)errorCode); } + + IDictionary IPersistentStateFeature.State + { + get + { + // Lazily allocate persistent state + return _persistentState ?? (_persistentState = new ConnectionItems()); + } + } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs index 723ef36b891f..70c641da551f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -119,6 +119,7 @@ protected override void OnReset() _currentIHttp2StreamIdFeature = this; _currentIHttpResponseTrailersFeature = this; _currentIHttpResetFeature = this; + _currentIPersistentStateFeature = this; } protected override void OnRequestProcessingEnded() diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index f690efa1bcb2..0a24eb467697 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.Extensions.Logging; diff --git a/src/Servers/Kestrel/Core/test/Http1HttpProtocolFeatureCollectionTests.cs b/src/Servers/Kestrel/Core/test/Http1HttpProtocolFeatureCollectionTests.cs index d7d15d61250f..69e6d0673ae1 100644 --- a/src/Servers/Kestrel/Core/test/Http1HttpProtocolFeatureCollectionTests.cs +++ b/src/Servers/Kestrel/Core/test/Http1HttpProtocolFeatureCollectionTests.cs @@ -5,6 +5,7 @@ using System.IO.Pipelines; using System.Linq; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; @@ -123,6 +124,8 @@ public void FeaturesSetByTypeSameAsGeneric() _collection[typeof(IHttpBodyControlFeature)] = CreateHttp1Connection(); _collection[typeof(IRouteValuesFeature)] = CreateHttp1Connection(); _collection[typeof(IEndpointFeature)] = CreateHttp1Connection(); + _collection[typeof(IHttpUpgradeFeature)] = CreateHttp1Connection(); + _collection[typeof(IPersistentStateFeature)] = CreateHttp1Connection(); CompareGenericGetterToIndexer(); @@ -147,6 +150,8 @@ public void FeaturesSetByGenericSameAsByType() _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); + _collection.Set(CreateHttp1Connection()); CompareGenericGetterToIndexer(); @@ -190,13 +195,21 @@ private void CompareGenericGetterToIndexer() private int EachHttpProtocolFeatureSetAndUnique() { - int featureCount = 0; + var featureCount = 0; foreach (var item in _collection) { - Type type = item.Key; + var type = item.Key; if (type.IsAssignableFrom(typeof(HttpProtocol))) { - Assert.Equal(1, _collection.Count(kv => ReferenceEquals(kv.Value, item.Value))); + var matches = _collection.Where(kv => ReferenceEquals(kv.Value, item.Value)).ToList(); + try + { + Assert.Single(matches); + } + catch (Exception ex) + { + throw new Exception($"Error for feature {type}.", ex); + } featureCount++; } diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs new file mode 100644 index 000000000000..dcebdac0ad6d --- /dev/null +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.FeatureCollection.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net.Sockets; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Internal +{ + internal sealed partial class QuicStreamContext : IPersistentStateFeature + { + private IDictionary? _persistentState; + + IDictionary IPersistentStateFeature.State + { + get + { + // Lazily allocate persistent state + return _persistentState ?? (_persistentState = new ConnectionItems()); + } + } + + private void InitializeFeatures() + { + _currentIPersistentStateFeature = this; + } + } +} diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs index 9443cabba3df..45505649b237 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicStreamContext.cs @@ -16,7 +16,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Internal { - internal class QuicStreamContext : TransportConnection, IStreamDirectionFeature, IProtocolErrorCodeFeature, IStreamIdFeature, IPooledStream + internal partial class QuicStreamContext : TransportConnection, IStreamDirectionFeature, IProtocolErrorCodeFeature, IStreamIdFeature, IPooledStream { // Internal for testing. internal Task _processingTask = Task.CompletedTask; @@ -87,12 +87,16 @@ public void Initialize(QuicStream stream) } ConnectionClosed = _streamClosedTokenSource.Token; + + // TODO - add to generated features Features.Set(this); Features.Set(this); Features.Set(this); - // TODO populate the ITlsConnectionFeature (requires client certs). Features.Set(new FakeTlsConnectionFeature()); + + InitializeFeatures(); + CanRead = _stream.CanRead; CanWrite = _stream.CanWrite; Error = 0; @@ -132,6 +136,8 @@ public override string ConnectionId public void Start() { + Debug.Assert(_processingTask.IsCompletedSuccessfully); + _processingTask = StartAsync(); } diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs index ada6add50cbe..cbf601cd4a4c 100644 --- a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs +++ b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionContextTests.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.FunctionalTests; using Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Internal; using Microsoft.AspNetCore.Testing; @@ -403,7 +404,7 @@ public async Task StreamPool_ManyConcurrentStreams_StreamPoolFull() // TODO: Race condition in QUIC library. // Delay between sending streams to avoid // https://github.com/dotnet/runtime/issues/55249 - await Task.Delay(50); + await Task.Delay(100); streamTasks.Add(SendStream(requestState)); } @@ -451,6 +452,68 @@ static async Task SendStream(RequestState requestState) } } + [ConditionalFact] + [MsQuicSupported] + public async Task PersistentState_StreamsReused_StatePersisted() + { + // Arrange + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint); + using var clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, options); + await clientConnection.ConnectAsync().DefaultTimeout(); + + await using var serverConnection = await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout(); + + // Act + var clientStream1 = clientConnection.OpenBidirectionalStream(); + await clientStream1.WriteAsync(TestData, endStream: true).DefaultTimeout(); + var serverStream1 = await serverConnection.AcceptAsync().DefaultTimeout(); + var readResult1 = await serverStream1.Transport.Input.ReadAtLeastAsync(TestData.Length).DefaultTimeout(); + serverStream1.Transport.Input.AdvanceTo(readResult1.Buffer.End); + + serverStream1.Features.Get().State["test"] = true; + + // Input should be completed. + readResult1 = await serverStream1.Transport.Input.ReadAsync(); + Assert.True(readResult1.IsCompleted); + + // Complete reading and writing. + await serverStream1.Transport.Input.CompleteAsync(); + await serverStream1.Transport.Output.CompleteAsync(); + + var quicStreamContext1 = Assert.IsType(serverStream1); + await quicStreamContext1._processingTask.DefaultTimeout(); + await quicStreamContext1.DisposeAsync(); + + var clientStream2 = clientConnection.OpenBidirectionalStream(); + await clientStream2.WriteAsync(TestData, endStream: true).DefaultTimeout(); + var serverStream2 = await serverConnection.AcceptAsync().DefaultTimeout(); + var readResult2 = await serverStream2.Transport.Input.ReadAtLeastAsync(TestData.Length).DefaultTimeout(); + serverStream2.Transport.Input.AdvanceTo(readResult2.Buffer.End); + + object state = serverStream2.Features.Get().State["test"]; + + // Input should be completed. + readResult2 = await serverStream2.Transport.Input.ReadAsync(); + Assert.True(readResult2.IsCompleted); + + // Complete reading and writing. + await serverStream2.Transport.Input.CompleteAsync(); + await serverStream2.Transport.Output.CompleteAsync(); + + var quicStreamContext2 = Assert.IsType(serverStream2); + await quicStreamContext2._processingTask.DefaultTimeout(); + await quicStreamContext2.DisposeAsync(); + + Assert.Same(quicStreamContext1, quicStreamContext2); + + var quicConnectionContext = Assert.IsType(serverConnection); + Assert.Equal(1, quicConnectionContext.StreamPool.Count); + + Assert.Equal(true, state); + } + private record RequestState( QuicConnection QuicConnection, MultiplexedConnectionContext ServerConnection, diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.FeatureCollection.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.FeatureCollection.cs index 8645f7e6459c..c3e16df6cfde 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.FeatureCollection.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.FeatureCollection.cs @@ -10,7 +10,7 @@ internal sealed partial class SocketConnection : IConnectionSocketFeature { public Socket Socket => _socket; - private void InitiaizeFeatures() + private void InitializeFeatures() { _currentIConnectionSocketFeature = this; } diff --git a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs index 7eb4a649f103..49a2873c2d81 100644 --- a/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs +++ b/src/Servers/Kestrel/Transport.Sockets/src/Internal/SocketConnection.cs @@ -73,7 +73,7 @@ internal SocketConnection(Socket socket, Transport = new SocketDuplexPipe(this); - InitiaizeFeatures(); + InitializeFeatures(); } public IDuplexPipe InnerTransport => _originalTransport; diff --git a/src/Servers/Kestrel/shared/KnownHeaders.cs b/src/Servers/Kestrel/shared/KnownHeaders.cs index de949e005b52..e06ceb0037e3 100644 --- a/src/Servers/Kestrel/shared/KnownHeaders.cs +++ b/src/Servers/Kestrel/shared/KnownHeaders.cs @@ -768,7 +768,7 @@ public static string GeneratedFile() offset += header.BytesCount; } } - return $@"// Copyright (c) .NET Foundation. All rights reserved. + var s = $@"// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -1337,6 +1337,9 @@ public bool MoveNext() }} }} ")}}}"; + + // Temporary workaround for https://github.com/dotnet/runtime/issues/55688 + return s.Replace("{{", "{").Replace("}}", "}"); } private static string GetHeaderLookup() diff --git a/src/Servers/Kestrel/shared/PooledStreamStack.cs b/src/Servers/Kestrel/shared/PooledStreamStack.cs index 486b983f36df..c4d36038ef73 100644 --- a/src/Servers/Kestrel/shared/PooledStreamStack.cs +++ b/src/Servers/Kestrel/shared/PooledStreamStack.cs @@ -14,7 +14,7 @@ internal interface IPooledStream void DisposeCore(); } - // See https://github.com/dotnet/runtime/blob/master/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegmentStack.cs + // See https://github.com/dotnet/runtime/blob/da9b16f2804e87c9c1ca9dcd9036e7b53e724f5d/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegmentStack.cs internal struct PooledStreamStack where TValue : class, IPooledStream { // Internal for testing @@ -139,6 +139,7 @@ private static int CalculateRemoveCount(long now, int size, StreamAsValueType[] return size; } + // See https://github.com/dotnet/runtime/blob/da9b16f2804e87c9c1ca9dcd9036e7b53e724f5d/src/libraries/System.IO.Pipelines/src/System/IO/Pipelines/BufferSegmentStack.cs#L68-L79 internal readonly struct StreamAsValueType { private readonly TValue _value; diff --git a/src/Servers/Kestrel/shared/TransportConnection.FeatureCollection.cs b/src/Servers/Kestrel/shared/TransportConnection.FeatureCollection.cs index 757927e2f433..7e5547f6b620 100644 --- a/src/Servers/Kestrel/shared/TransportConnection.FeatureCollection.cs +++ b/src/Servers/Kestrel/shared/TransportConnection.FeatureCollection.cs @@ -6,6 +6,7 @@ using System.IO.Pipelines; using System.Threading; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; #nullable enable diff --git a/src/Servers/Kestrel/shared/TransportConnection.Generated.cs b/src/Servers/Kestrel/shared/TransportConnection.Generated.cs index 18b69ce0e46d..9003975b00ce 100644 --- a/src/Servers/Kestrel/shared/TransportConnection.Generated.cs +++ b/src/Servers/Kestrel/shared/TransportConnection.Generated.cs @@ -28,6 +28,7 @@ internal partial class TransportConnection : IFeatureCollection, internal protected IConnectionLifetimeFeature? _currentIConnectionLifetimeFeature; // Other reserved feature slots + internal protected IPersistentStateFeature? _currentIPersistentStateFeature; internal protected IConnectionSocketFeature? _currentIConnectionSocketFeature; private int _featureRevision; @@ -42,6 +43,7 @@ private void FastReset() _currentIMemoryPoolFeature = this; _currentIConnectionLifetimeFeature = this; + _currentIPersistentStateFeature = null; _currentIConnectionSocketFeature = null; } @@ -126,6 +128,10 @@ private void ExtraFeatureSet(Type key, object? value) { feature = _currentIConnectionItemsFeature; } + else if (key == typeof(IPersistentStateFeature)) + { + feature = _currentIPersistentStateFeature; + } else if (key == typeof(IMemoryPoolFeature)) { feature = _currentIMemoryPoolFeature; @@ -162,6 +168,10 @@ private void ExtraFeatureSet(Type key, object? value) { _currentIConnectionItemsFeature = (IConnectionItemsFeature?)value; } + else if (key == typeof(IPersistentStateFeature)) + { + _currentIPersistentStateFeature = (IPersistentStateFeature?)value; + } else if (key == typeof(IMemoryPoolFeature)) { _currentIMemoryPoolFeature = (IMemoryPoolFeature?)value; @@ -200,6 +210,10 @@ private void ExtraFeatureSet(Type key, object? value) { feature = Unsafe.As(ref _currentIConnectionItemsFeature); } + else if (typeof(TFeature) == typeof(IPersistentStateFeature)) + { + feature = Unsafe.As(ref _currentIPersistentStateFeature); + } else if (typeof(TFeature) == typeof(IMemoryPoolFeature)) { feature = Unsafe.As(ref _currentIMemoryPoolFeature); @@ -239,6 +253,10 @@ private void ExtraFeatureSet(Type key, object? value) { _currentIConnectionItemsFeature = Unsafe.As(ref feature); } + else if (typeof(TFeature) == typeof(IPersistentStateFeature)) + { + _currentIPersistentStateFeature = Unsafe.As(ref feature); + } else if (typeof(TFeature) == typeof(IMemoryPoolFeature)) { _currentIMemoryPoolFeature = Unsafe.As(ref feature); @@ -271,6 +289,10 @@ private IEnumerable> FastEnumerable() { yield return new KeyValuePair(typeof(IConnectionItemsFeature), _currentIConnectionItemsFeature); } + if (_currentIPersistentStateFeature != null) + { + yield return new KeyValuePair(typeof(IPersistentStateFeature), _currentIPersistentStateFeature); + } if (_currentIMemoryPoolFeature != null) { yield return new KeyValuePair(typeof(IMemoryPoolFeature), _currentIMemoryPoolFeature); diff --git a/src/Servers/Kestrel/shared/TransportMultiplexedConnection.FeatureCollection.cs b/src/Servers/Kestrel/shared/TransportMultiplexedConnection.FeatureCollection.cs index 918903b4bbe0..2f0eaa5876e4 100644 --- a/src/Servers/Kestrel/shared/TransportMultiplexedConnection.FeatureCollection.cs +++ b/src/Servers/Kestrel/shared/TransportMultiplexedConnection.FeatureCollection.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.Threading; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Connections { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs index 1d407e402fc5..197bf2033cfa 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2ConnectionTests.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Collections; using System.Collections.Generic; using System.Globalization; using System.IO; @@ -13,7 +14,9 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; @@ -208,6 +211,123 @@ await ExpectAsync(Http2FrameType.HEADERS, await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); } + private class ResponseTrailersWrapper : IHeaderDictionary + { + readonly IHeaderDictionary _innerHeaders; + + public ResponseTrailersWrapper(IHeaderDictionary headers) + { + _innerHeaders = headers; + } + + public StringValues this[string key] { get => _innerHeaders[key]; set => _innerHeaders[key] = value; } + public long? ContentLength { get => _innerHeaders.ContentLength; set => _innerHeaders.ContentLength = value; } + public ICollection Keys => _innerHeaders.Keys; + public ICollection Values => _innerHeaders.Values; + public int Count => _innerHeaders.Count; + public bool IsReadOnly => _innerHeaders.IsReadOnly; + public void Add(string key, StringValues value) => _innerHeaders.Add(key, value); + public void Add(KeyValuePair item) => _innerHeaders.Add(item); + public void Clear() => _innerHeaders.Clear(); + public bool Contains(KeyValuePair item) => _innerHeaders.Contains(item); + public bool ContainsKey(string key) => _innerHeaders.ContainsKey(key); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => _innerHeaders.CopyTo(array, arrayIndex); + public IEnumerator> GetEnumerator() => _innerHeaders.GetEnumerator(); + public bool Remove(string key) => _innerHeaders.Remove(key); + public bool Remove(KeyValuePair item) => _innerHeaders.Remove(item); + public bool TryGetValue(string key, out StringValues value) => _innerHeaders.TryGetValue(key, out value); + IEnumerator IEnumerable.GetEnumerator() => _innerHeaders.GetEnumerator(); + } + + [Fact] + public async Task ResponseTrailers_MultipleStreams_Reset() + { + IEnumerable> requestHeaders = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/hello"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.Authority, "localhost:80"), + new KeyValuePair(HeaderNames.ContentType, "application/json") + }; + + var requestCount = 0; + await InitializeConnectionAsync(context => + { + requestCount++; + + var trailersFeature = context.Features.Get(); + + IHeaderDictionary trailers; + if (requestCount == 1) + { + trailers = new ResponseTrailersWrapper(trailersFeature.Trailers); + trailersFeature.Trailers = trailers; + } + else + { + trailers = trailersFeature.Trailers; + } + trailers["trailer-" + requestCount] = "true"; + return Task.CompletedTask; + }); + + await StartStreamAsync(1, requestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 36, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 1); + + var trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 1); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("true", _decodedHeaders["trailer-1"]); + + _decodedHeaders.Clear(); + + // Ping will trigger the stream to be returned to the pool so we can assert it + await SendPingAsync(Http2PingFrameFlags.NONE); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + await SendPingAsync(Http2PingFrameFlags.NONE); + await ExpectAsync(Http2FrameType.PING, + withLength: 8, + withFlags: (byte)Http2PingFrameFlags.ACK, + withStreamId: 0); + + // Stream has been returned to the pool + Assert.Equal(1, _connection.StreamPool.Count); + + await StartStreamAsync(3, requestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 6, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS), + withStreamId: 3); + + trailersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 16, + withFlags: (byte)(Http2HeadersFrameFlags.END_HEADERS | Http2HeadersFrameFlags.END_STREAM), + withStreamId: 3); + + _hpackDecoder.Decode(trailersFrame.PayloadSequence, endHeaders: true, handler: this); + + Assert.Single(_decodedHeaders); + Assert.Equal("true", _decodedHeaders["trailer-2"]); + + _decodedHeaders.Clear(); + + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); + } + [Fact] public async Task StreamPool_SingleStream_ReturnedToPool() { @@ -304,9 +424,18 @@ await ExpectAsync(Http2FrameType.PING, public async Task StreamPool_MultipleStreamsInSequence_PooledStreamReused() { TaskCompletionSource appDelegateTcs = null; + object persistedState = null; + var requestCount = 0; await InitializeConnectionAsync(async context => { + requestCount++; + var persistentStateCollection = context.Features.Get().State; + if (persistentStateCollection.TryGetValue("Counter", out var value)) + { + persistedState = value; + } + persistentStateCollection["Counter"] = requestCount; await appDelegateTcs.Task; }); @@ -330,6 +459,9 @@ await ExpectAsync(Http2FrameType.HEADERS, Assert.True(_connection.StreamPool.TryPeek(out var pooledStream)); Assert.Equal(stream, pooledStream); + // First request has no persisted state + Assert.Null(persistedState); + appDelegateTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); await StartStreamAsync(3, _browserRequestHeaders, endStream: true); @@ -348,6 +480,9 @@ await ExpectAsync(Http2FrameType.HEADERS, Assert.True(_connection.StreamPool.TryPeek(out pooledStream)); Assert.Equal(stream, pooledStream); + // State persisted on first request was available on the second request + Assert.Equal(1, (int)persistedState); + await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false); async Task PingUntilStreamPooled(int expectedCount) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs index cb7a21fe86d9..c806025b24f6 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs @@ -9,6 +9,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; @@ -2055,6 +2056,68 @@ await connection.Receive( } } + [Fact] + public async Task PersistentStateBetweenRequests() + { + var testContext = new TestServiceContext(LoggerFactory); + object persistedState = null; + var requestCount = 0; + + await using (var server = new TestServer(context => + { + requestCount++; + var persistentStateCollection = context.Features.Get().State; + if (persistentStateCollection.TryGetValue("Counter", out var value)) + { + persistedState = value; + } + persistentStateCollection["Counter"] = requestCount; + return Task.CompletedTask; + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + // First request + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Content-Type: application/test", + "X-CustomHeader: customvalue", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + "Content-Length: 0", + $"Date: {testContext.DateHeaderValue}", + "", + ""); + var firstRequestState = persistedState; + + // Second request + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "Content-Type: application/test", + "X-CustomHeader: customvalue", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + "Content-Length: 0", + $"Date: {testContext.DateHeaderValue}", + "", + ""); + var secondRequestState = persistedState; + + // First request has no persisted state + Assert.Null(firstRequestState); + + // State persisted on first request was available on the second request + Assert.Equal(1, secondRequestState); + } + } + } + [Fact] public async Task Latin1HeaderValueAcceptedWhenLatin1OptionIsConfigured() { diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs index dcc2ccb12afa..ff4d33359069 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs @@ -6,6 +6,7 @@ using System.Text; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core; @@ -54,41 +55,24 @@ public void CompleteStream() public async Task POST_ServerCompletsWithoutReadingRequestBody_ClientGetsResponse() { // Arrange - var builder = GetHostBuilder() - .ConfigureWebHost(webHostBuilder => + var builder = CreateHttp3HostBuilder(async context => + { + var body = context.Request.Body; + + var data = new List(); + var buffer = new byte[1024]; + var readCount = 0; + while ((readCount = await body.ReadAsync(buffer).DefaultTimeout()) != -1) { - webHostBuilder - .UseKestrel(o => - { - o.ConfigureEndpointDefaults(listenOptions => - { - listenOptions.Protocols = HttpProtocols.Http3; - }); - }) - .UseUrls("https://127.0.0.1:0") - .Configure(app => - { - app.Run(async context => - { - var body = context.Request.Body; - - var data = new List(); - var buffer = new byte[1024]; - var readCount = 0; - while ((readCount = await body.ReadAsync(buffer).DefaultTimeout()) != -1) - { - data.AddRange(buffer.AsMemory(0, readCount).ToArray()); - if (data.Count == TestData.Length) - { - break; - } - } - - await context.Response.Body.WriteAsync(buffer.AsMemory(0, TestData.Length)); - }); - }); - }) - .ConfigureServices(AddTestLogging); + data.AddRange(buffer.AsMemory(0, readCount).ToArray()); + if (data.Count == TestData.Length) + { + break; + } + } + + await context.Response.Body.WriteAsync(buffer.AsMemory(0, TestData.Length)); + }); using (var host = builder.Build()) using (var client = new HttpClient()) @@ -124,6 +108,85 @@ public async Task POST_ServerCompletsWithoutReadingRequestBody_ClientGetsRespons } } + [ConditionalFact] + [MsQuicSupported] + public async Task GET_MultipleRequestsInSequence_ReusedState() + { + // Arrange + object persistedState = null; + var requestCount = 0; + + var builder = CreateHttp3HostBuilder(context => + { + requestCount++; + var persistentStateCollection = context.Features.Get().State; + if (persistentStateCollection.TryGetValue("Counter", out var value)) + { + persistedState = value; + } + persistentStateCollection["Counter"] = requestCount; + + return Task.CompletedTask; + }); + + using (var host = builder.Build()) + using (var client = new HttpClient()) + { + await host.StartAsync(); + + // Act + var request1 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request1.Version = HttpVersion.Version30; + request1.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response1 = await client.SendAsync(request1); + response1.EnsureSuccessStatusCode(); + var firstRequestState = persistedState; + + // Delay to ensure the stream has enough time to return to pool + await Task.Delay(100); + + var request2 = new HttpRequestMessage(HttpMethod.Get, $"https://127.0.0.1:{host.GetPort()}/"); + request2.Version = HttpVersion.Version30; + request2.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + + var response2 = await client.SendAsync(request2); + response2.EnsureSuccessStatusCode(); + var secondRequestState = persistedState; + + // Assert + // First request has no persisted state + Assert.Null(firstRequestState); + + // State persisted on first request was available on the second request + Assert.Equal(1, secondRequestState); + + await host.StopAsync(); + } + } + + private IHostBuilder CreateHttp3HostBuilder(RequestDelegate requestDelegate) + { + return GetHostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .UseKestrel(o => + { + o.ConfigureEndpointDefaults(listenOptions => + { + listenOptions.Protocols = HttpProtocols.Http3; + }); + }) + .UseUrls("https://127.0.0.1:0") + .Configure(app => + { + app.Run(requestDelegate); + }); + }) + .ConfigureServices(AddTestLogging); + } + public static IHostBuilder GetHostBuilder(long? maxReadBufferSize = null) { return new HostBuilder() diff --git a/src/Servers/Kestrel/tools/CodeGenerator/FeatureCollectionGenerator.cs b/src/Servers/Kestrel/tools/CodeGenerator/FeatureCollectionGenerator.cs index f8f4ad91b64a..ba6e36ae2df1 100644 --- a/src/Servers/Kestrel/tools/CodeGenerator/FeatureCollectionGenerator.cs +++ b/src/Servers/Kestrel/tools/CodeGenerator/FeatureCollectionGenerator.cs @@ -19,7 +19,7 @@ public static string GenerateFile(string namespaceName, string className, string Index = index }); - return $@"// Copyright (c) .NET Foundation. All rights reserved. + var s = $@"// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -211,6 +211,9 @@ private IEnumerable> FastEnumerable() }} }} "; + + // Temporary workaround for https://github.com/dotnet/runtime/issues/55688 + return s.Replace("{{", "{").Replace("}}", "}"); } static string Each(IEnumerable values, Func formatter) diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs index a081abeac55d..ed2b6da17dea 100644 --- a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs +++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs @@ -50,7 +50,8 @@ public static string GenerateFile() "IHttpMinResponseDataRateFeature", "IHttpBodyControlFeature", "IHttpRequestBodyDetectionFeature", - "IHttpResetFeature" + "IHttpResetFeature", + "IPersistentStateFeature" }; var allFeatures = alwaysFeatures @@ -80,6 +81,7 @@ public static string GenerateFile() }; var usings = $@" +using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features.Authentication; using Microsoft.AspNetCore.Server.Kestrel.Core.Features;"; diff --git a/src/Servers/Kestrel/tools/CodeGenerator/ReadOnlySpanStaticDataGenerator.cs b/src/Servers/Kestrel/tools/CodeGenerator/ReadOnlySpanStaticDataGenerator.cs index 9360db249197..fc56ec4ccfdc 100644 --- a/src/Servers/Kestrel/tools/CodeGenerator/ReadOnlySpanStaticDataGenerator.cs +++ b/src/Servers/Kestrel/tools/CodeGenerator/ReadOnlySpanStaticDataGenerator.cs @@ -19,7 +19,7 @@ public static string GenerateFile(string namespaceName, string className, IEnume Index = index }); - return $@"// Copyright (c) .NET Foundation. All rights reserved. + var s = $@"// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -36,6 +36,9 @@ internal partial class {className} }} }} "; + + // Temporary workaround for https://github.com/dotnet/runtime/issues/55688 + return s.Replace("{{", "{").Replace("}}", "}"); } private static string Each(IEnumerable values, Func formatter) diff --git a/src/Servers/Kestrel/tools/CodeGenerator/TransportConnectionFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/TransportConnectionFeatureCollection.cs index 912615803957..3c60a700e813 100644 --- a/src/Servers/Kestrel/tools/CodeGenerator/TransportConnectionFeatureCollection.cs +++ b/src/Servers/Kestrel/tools/CodeGenerator/TransportConnectionFeatureCollection.cs @@ -17,6 +17,7 @@ public static string GenerateFile() "IConnectionIdFeature", "IConnectionTransportFeature", "IConnectionItemsFeature", + "IPersistentStateFeature", "IMemoryPoolFeature", "IConnectionLifetimeFeature", "IConnectionSocketFeature"