diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs index 961a08c38883..b51b1de32694 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpRequestHeaders.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { internal sealed partial class HttpRequestHeaders : HttpHeaders { + private EnumeratorCache? _enumeratorCache; private long _previousBits; public bool ReuseHeaderValues { get; set; } @@ -65,6 +66,7 @@ protected override void ClearFast() // Clear ContentLength and any unknown headers as we will never reuse them _contentLength = null; MaybeUnknown?.Clear(); + _enumeratorCache?.Reset(); } private static long ParseContentLength(string value) @@ -148,7 +150,73 @@ public Enumerator GetEnumerator() protected override IEnumerator> GetEnumeratorFast() { - return GetEnumerator(); + // Get or create the cache. + var cache = _enumeratorCache ??= new(); + + EnumeratorBox enumerator; + if (cache.CachedEnumerator is not null) + { + // Previous enumerator, reuse that one. + enumerator = cache.InUseEnumerator = cache.CachedEnumerator; + // Set previous to null so if there is a second enumerator call + // during the same request it doesn't get the same one. + cache.CachedEnumerator = null; + } + else + { + // Create new enumerator box and store as in use. + enumerator = cache.InUseEnumerator = new(); + } + + // Set the underlying struct enumerator to a new one. + enumerator.Enumerator = new Enumerator(this); + return enumerator; + } + + private class EnumeratorCache + { + /// + /// Enumerator created from previous request + /// + public EnumeratorBox? CachedEnumerator { get; set; } + /// + /// Enumerator used on this request + /// + public EnumeratorBox? InUseEnumerator { get; set; } + + /// + /// Moves InUseEnumerator to CachedEnumerator + /// + public void Reset() + { + var enumerator = InUseEnumerator; + if (enumerator is not null) + { + InUseEnumerator = null; + enumerator.Enumerator = default; + CachedEnumerator = enumerator; + } + } + } + + /// + /// Strong box enumerator for the IEnumerator interface to cache and amortizate the + /// IEnumerator allocations across requests if the header collection is commonly + /// enumerated for forwarding in a reverse-proxy type situation. + /// + private class EnumeratorBox : IEnumerator> + { + public Enumerator Enumerator; + + public KeyValuePair Current => Enumerator.Current; + + public bool MoveNext() => Enumerator.MoveNext(); + + object IEnumerator.Current => Current; + + public void Dispose() { } + + public void Reset() => throw new NotSupportedException(); } public partial struct Enumerator : IEnumerator> diff --git a/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs b/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs index 491d6345bf95..ad924c44d55d 100644 --- a/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs +++ b/src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs @@ -99,9 +99,64 @@ public void SameExceptionThrownForMissingKey() } [Fact] - public void EntriesCanBeEnumerated() + public void EntriesCanBeEnumeratedAfterResets() + { + HttpRequestHeaders headers = new HttpRequestHeaders(); + + EnumerateEntries((IHeaderDictionary)headers); + headers.Reset(); + EnumerateEntries((IDictionary)headers); + headers.Reset(); + EnumerateEntries((IHeaderDictionary)headers); + headers.Reset(); + EnumerateEntries((IDictionary)headers); + } + + [Fact] + public void EnumeratorNotReusedBeforeReset() + { + HttpRequestHeaders headers = new HttpRequestHeaders(); + IEnumerable> enumerable = headers; + + var enumerator0 = enumerable.GetEnumerator(); + var enumerator1 = enumerable.GetEnumerator(); + + Assert.NotSame(enumerator0, enumerator1); + } + + [Fact] + public void EnumeratorReusedAfterReset() + { + HttpRequestHeaders headers = new HttpRequestHeaders(); + IEnumerable> enumerable = headers; + + var enumerator0 = enumerable.GetEnumerator(); + headers.Reset(); + var enumerator1 = enumerable.GetEnumerator(); + + Assert.Same(enumerator0, enumerator1); + } + + private static void EnumerateEntries(IHeaderDictionary headers) + { + var v1 = new[] { "localhost" }; + var v2 = new[] { "0" }; + var v3 = new[] { "value" }; + headers.Host = v1; + headers.ContentLength = 0; + headers["custom"] = v3; + + Assert.Equal( + new[] { + new KeyValuePair("Host", v1), + new KeyValuePair("Content-Length", v2), + new KeyValuePair("custom", v3), + }, + headers); + } + + private static void EnumerateEntries(IDictionary headers) { - IDictionary headers = new HttpRequestHeaders(); var v1 = new[] { "localhost" }; var v2 = new[] { "0" }; var v3 = new[] { "value" };