Skip to content

Fix Kestrel psuedo header reuse #38585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7102,6 +7102,11 @@ protected override bool CopyToFast(KeyValuePair<string, StringValues>[] array, i
return true;
}

internal void ClearPseudoRequestHeaders()
{
_pseudoBits = _bits & 240;
_bits &= ~240;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe ushort ReadUnalignedLittleEndian_ushort(ref byte source)
Expand Down Expand Up @@ -17014,4 +17019,4 @@ public bool MoveNext()
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal sealed partial class HttpRequestHeaders : HttpHeaders
{
private EnumeratorCache? _enumeratorCache;
private long _previousBits;
private long _pseudoBits;

public bool ReuseHeaderValues { get; set; }
public Func<string, Encoding?> EncodingSelector { get; set; }
Expand Down Expand Up @@ -54,16 +55,19 @@ protected override void ClearFast()
if (!ReuseHeaderValues)
{
// If we aren't reusing headers clear them all
Clear(_bits);
Clear(_bits | _pseudoBits);
}
else
{
// If we are reusing headers, store the currently set headers for comparison later
_previousBits = _bits;
// Pseudo header bits were cleared at the start of a request to hide from the user.
// Keep those values for reuse.
_previousBits = _bits | _pseudoBits;
}

// Mark no headers as currently in use
_bits = 0;
_pseudoBits = 0;
// Clear ContentLength and any unknown headers as we will never reuse them
_contentLength = null;
MaybeUnknown?.Clear();
Expand Down
8 changes: 4 additions & 4 deletions src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,10 @@ protected override bool TryParseRequest(ReadResult result, out bool endConnectio
// We don't need any of the parameters because we don't implement BeginRead to actually
// do the reading from a pipeline, nor do we use endConnection to report connection-level errors.
endConnection = !TryValidatePseudoHeaders();

// Suppress pseudo headers from the public headers collection.
HttpRequestHeaders.ClearPseudoRequestHeaders();

return true;
}

Expand Down Expand Up @@ -249,7 +253,6 @@ private bool TryValidatePseudoHeaders()
// enabling the use of HTTP to interact with non - HTTP services.
// A common example is TLS termination.
var headerScheme = HttpRequestHeaders.HeaderScheme.ToString();
HttpRequestHeaders.HeaderScheme = default; // Suppress pseduo headers from the public headers collection.
if (!ReferenceEquals(headerScheme, Scheme) &&
!string.Equals(headerScheme, Scheme, StringComparison.OrdinalIgnoreCase))
{
Expand All @@ -266,7 +269,6 @@ private bool TryValidatePseudoHeaders()
// :path (and query) - Required
// Must start with / except may be * for OPTIONS
var path = HttpRequestHeaders.HeaderPath.ToString();
HttpRequestHeaders.HeaderPath = default; // Suppress pseduo headers from the public headers collection.
RawTarget = path;

// OPTIONS - https://tools.ietf.org/html/rfc7540#section-8.1.2.3
Expand Down Expand Up @@ -304,7 +306,6 @@ private bool TryValidateMethod()
{
// :method
_methodText = HttpRequestHeaders.HeaderMethod.ToString();
HttpRequestHeaders.HeaderMethod = default; // Suppress pseduo headers from the public headers collection.
Method = HttpUtilities.GetKnownMethod(_methodText);

if (Method == HttpMethod.None)
Expand All @@ -331,7 +332,6 @@ private bool TryValidateAuthorityAndHost(out string hostText)
// Prefer this over Host

var authority = HttpRequestHeaders.HeaderAuthority;
HttpRequestHeaders.HeaderAuthority = default; // Suppress pseduo headers from the public headers collection.
var host = HttpRequestHeaders.HeaderHost;
if (!StringValues.IsNullOrEmpty(authority))
{
Expand Down
8 changes: 4 additions & 4 deletions src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,10 @@ protected override MessageBody CreateMessageBody()
protected override bool TryParseRequest(ReadResult result, out bool endConnection)
{
endConnection = !TryValidatePseudoHeaders();

// Suppress pseudo headers from the public headers collection.
HttpRequestHeaders.ClearPseudoRequestHeaders();

return true;
}

Expand Down Expand Up @@ -791,7 +795,6 @@ private bool TryValidatePseudoHeaders()
// proxy or gateway can translate requests for non - HTTP schemes,
// enabling the use of HTTP to interact with non - HTTP services.
var headerScheme = HttpRequestHeaders.HeaderScheme.ToString();
HttpRequestHeaders.HeaderScheme = default; // Suppress pseduo headers from the public headers collection.
if (!ReferenceEquals(headerScheme, Scheme) &&
!string.Equals(headerScheme, Scheme, StringComparison.OrdinalIgnoreCase))
{
Expand All @@ -808,7 +811,6 @@ private bool TryValidatePseudoHeaders()
// :path (and query) - Required
// Must start with / except may be * for OPTIONS
var path = HttpRequestHeaders.HeaderPath.ToString();
HttpRequestHeaders.HeaderPath = default; // Suppress pseduo headers from the public headers collection.
RawTarget = path;

// OPTIONS - https://tools.ietf.org/html/rfc7540#section-8.1.2.3
Expand Down Expand Up @@ -847,7 +849,6 @@ private bool TryValidateMethod()
{
// :method
_methodText = HttpRequestHeaders.HeaderMethod.ToString();
HttpRequestHeaders.HeaderMethod = default; // Suppress pseduo headers from the public headers collection.
Method = HttpUtilities.GetKnownMethod(_methodText);

if (Method == Http.HttpMethod.None)
Expand All @@ -874,7 +875,6 @@ private bool TryValidateAuthorityAndHost(out string hostText)
// Prefer this over Host

var authority = HttpRequestHeaders.HeaderAuthority;
HttpRequestHeaders.HeaderAuthority = default; // Suppress pseduo headers from the public headers collection.
var host = HttpRequestHeaders.HeaderHost;
if (!StringValues.IsNullOrEmpty(authority))
{
Expand Down
18 changes: 18 additions & 0 deletions src/Servers/Kestrel/Core/test/HttpRequestHeadersTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
Expand Down Expand Up @@ -112,6 +113,23 @@ public void EntriesCanBeEnumeratedAfterResets()
EnumerateEntries((IDictionary<string, StringValues>)headers);
}

[Fact]
public void ClearPseudoRequestHeadersPlusResetClearsHeaderReferenceValue()
{
const BindingFlags privateFlags = BindingFlags.NonPublic | BindingFlags.Instance;

HttpRequestHeaders headers = new HttpRequestHeaders(reuseHeaderValues: false);
headers.HeaderMethod = "GET";
headers.ClearPseudoRequestHeaders();
headers.Reset();

// Hacky but required because header references is private.
var headerReferences = typeof(HttpRequestHeaders).GetField("_headers", privateFlags).GetValue(headers);
var methodValue = (StringValues)headerReferences.GetType().GetField("_Method").GetValue(headerReferences);

Assert.Equal(StringValues.Empty, methodValue);
}

[Fact]
public void EnumeratorNotReusedBeforeReset()
{
Expand Down
15 changes: 13 additions & 2 deletions src/Servers/Kestrel/shared/KnownHeaders.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class KnownHeaders
HeaderNames.DNT,
};

public static readonly string[] PsuedoHeaderNames = new[]
public static readonly string[] PseudoHeaderNames = new[]
{
"Authority", // :authority
"Method", // :method
Expand All @@ -50,7 +50,7 @@ public class KnownHeaders

public static readonly string[] NonApiHeaders =
ObsoleteHeaderNames
.Concat(PsuedoHeaderNames)
.Concat(PseudoHeaderNames)
.ToArray();

public static readonly string[] ApiHeaderNames =
Expand All @@ -59,6 +59,7 @@ public class KnownHeaders
.ToArray();

public static readonly long InvalidH2H3ResponseHeadersBits;
public static readonly long PseudoRequestHeadersBits;

static KnownHeaders()
{
Expand Down Expand Up @@ -263,6 +264,11 @@ static KnownHeaders()
.Where(header => invalidH2H3ResponseHeaders.Contains(header.Name))
.Select(header => 1L << header.Index)
.Aggregate((a, b) => a | b);

PseudoRequestHeadersBits = RequestHeaders
.Where(header => PseudoHeaderNames.Contains(header.Identifier))
.Select(header => 1L << header.Index)
.Aggregate((a, b) => a | b);
}

static string Each<T>(IEnumerable<T> values, Func<T, string> formatter)
Expand Down Expand Up @@ -1249,6 +1255,11 @@ internal unsafe void CopyToFast(ref BufferWriter<PipeWriter> output)
}}
}} while (tempBits != 0);
}}" : "")}{(loop.ClassName == "HttpRequestHeaders" ? $@"
internal void ClearPseudoRequestHeaders()
{{
_pseudoBits = _bits & {PseudoRequestHeadersBits};
_bits &= ~{PseudoRequestHeadersBits};
}}
{Each(new string[] { "ushort", "uint", "ulong" }, type => $@"
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static unsafe {type} ReadUnalignedLittleEndian_{type}(ref byte source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ await ExpectAsync(Http2FrameType.HEADERS,
withStreamId: 1);

var contentType1 = _receivedHeaders["Content-Type"];
var authority1 = _receivedRequestFields.Authority;
var path1 = _receivedRequestFields.Path;

// TriggerTick will trigger the stream to be returned to the pool so we can assert it
TriggerTick();
Expand All @@ -194,8 +196,12 @@ await ExpectAsync(Http2FrameType.HEADERS,
withStreamId: 3);

var contentType2 = _receivedHeaders["Content-Type"];
var authority2 = _receivedRequestFields.Authority;
var path2 = _receivedRequestFields.Path;

Assert.Same(contentType1, contentType2);
Assert.Same(authority1, authority2);
Assert.Same(path1, path2);

await StopConnectionAsync(expectedLastStreamId: 3, ignoreNonGoAwayFrames: false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ public Http2TestBase()
_receivedRequestFields.Scheme = context.Request.Scheme;
_receivedRequestFields.Path = context.Request.Path.Value;
_receivedRequestFields.RawTarget = context.Features.Get<IHttpRequestFeature>().RawTarget;
_receivedRequestFields.Authority = context.Request.Host.Value;
foreach (var header in context.Request.Headers)
{
_receivedHeaders[header.Key] = header.Value.ToString();
Expand Down Expand Up @@ -1413,5 +1414,6 @@ public class RequestFields
public string Scheme { get; set; }
public string Path { get; set; }
public string RawTarget { get; set; }
public string Authority { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,42 @@ public async Task StreamPool_MultipleStreamsInSequence_PooledStreamReused()
Assert.Same(streamContext1, streamContext2);
}

[Fact]
public async Task StreamPool_MultipleStreamsInSequence_KnownHeaderReused()
{
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "Custom"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Authority, "localhost:80"),
new KeyValuePair<string, string>(HeaderNames.ContentType, "application/json"),
};

string contentType = null;
string authority = null;
await Http3Api.InitializeConnectionAsync(async context =>
{
contentType = context.Request.ContentType;
authority = context.Request.Host.Value;
await _echoApplication(context);
});

var streamContext1 = await MakeRequestAsync(0, headers, sendData: true, waitForServerDispose: true);
var contentType1 = contentType;
var authority1 = authority;

var streamContext2 = await MakeRequestAsync(1, headers, sendData: true, waitForServerDispose: true);
var contentType2 = contentType;
var authority2 = authority;

Assert.NotNull(contentType1);
Assert.NotNull(authority1);

Assert.Same(contentType1, contentType2);
Assert.Same(authority1, authority2);
}

[Theory]
[InlineData(10)]
[InlineData(100)]
Expand Down