diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs
index fa3c7803ed78..15ec7d20d9c8 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/HttpUtilities.cs
@@ -2,8 +2,10 @@
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
+using System.Buffers;
 using System.Diagnostics;
 using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
 using System.Text;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
@@ -30,6 +32,8 @@ internal static partial class HttpUtilities
         private const ulong _http11VersionLong = 3543824036068086856; // GetAsciiStringAsLong("HTTP/1.1"); const results in better codegen
 
         private static readonly UTF8EncodingSealed HeaderValueEncoding = new UTF8EncodingSealed();
+        private static readonly SpanAction<char, IntPtr> _getHeaderName = GetHeaderName;
+        private static readonly SpanAction<char, IntPtr> _getAsciiStringNonNullCharacters = GetAsciiStringNonNullCharacters;
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static void SetKnownMethod(ulong mask, ulong knownMethodUlong, HttpMethod knownMethod, int length)
@@ -86,6 +90,7 @@ private static unsafe ulong GetMaskAsLong(byte[] bytes)
         }
 
         // The same as GetAsciiStringNonNullCharacters but throws BadRequest
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public static unsafe string GetHeaderName(this ReadOnlySpan<byte> span)
         {
             if (span.IsEmpty)
@@ -93,25 +98,29 @@ public static unsafe string GetHeaderName(this ReadOnlySpan<byte> span)
                 return string.Empty;
             }
 
-            var asciiString = new string('\0', span.Length);
+            fixed (byte* source = &MemoryMarshal.GetReference(span))
+            {
+                return string.Create(span.Length, new IntPtr(source), _getHeaderName);
+            }
+        }
 
-            fixed (char* output = asciiString)
-            fixed (byte* buffer = span)
+        private static unsafe void GetHeaderName(Span<char> buffer, IntPtr state)
+        {
+            fixed (char* output = &MemoryMarshal.GetReference(buffer))
             {
                 // This version if AsciiUtilities returns null if there are any null (0 byte) characters
                 // in the string
-                if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
+                if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
                 {
                     BadHttpRequestException.Throw(RequestRejectionReason.InvalidCharactersInHeaderName);
                 }
             }
-
-            return asciiString;
         }
 
         public static string GetAsciiStringNonNullCharacters(this Span<byte> span)
             => GetAsciiStringNonNullCharacters((ReadOnlySpan<byte>)span);
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
         public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<byte> span)
         {
             if (span.IsEmpty)
@@ -119,19 +128,23 @@ public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<by
                 return string.Empty;
             }
 
-            var asciiString = new string('\0', span.Length);
+            fixed (byte* source = &MemoryMarshal.GetReference(span))
+            {
+                return string.Create(span.Length, new IntPtr(source), _getAsciiStringNonNullCharacters);
+            }
+        }
 
-            fixed (char* output = asciiString)
-            fixed (byte* buffer = span)
+        private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
+        {
+            fixed (char* output = &MemoryMarshal.GetReference(buffer))
             {
                 // This version if AsciiUtilities returns null if there are any null (0 byte) characters
                 // in the string
-                if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
+                if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
                 {
                     throw new InvalidOperationException();
                 }
             }
-            return asciiString;
         }
 
         public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
@@ -144,14 +157,12 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlyS
                 return string.Empty;
             }
 
-            var resultString = new string('\0', span.Length);
-
-            fixed (char* output = resultString)
-            fixed (byte* buffer = span)
+            fixed (byte* source = &MemoryMarshal.GetReference(span))
             {
-                // This version if AsciiUtilities returns null if there are any null (0 byte) characters
-                // in the string
-                if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
+                var resultString = string.Create(span.Length, new IntPtr(source), s_getAsciiOrUtf8StringNonNullCharacters);
+
+                // If resultString is marked, perform UTF-8 encoding
+                if (resultString[0] == '\0')
                 {
                     // null characters are considered invalid
                     if (span.IndexOf((byte)0) != -1)
@@ -161,15 +172,32 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlyS
 
                     try
                     {
-                        resultString = HeaderValueEncoding.GetString(buffer, span.Length);
+                        resultString = HeaderValueEncoding.GetString(span);
                     }
                     catch (DecoderFallbackException)
                     {
                         throw new InvalidOperationException();
                     }
                 }
+
+                return resultString;
+            }
+        }
+
+        private static readonly SpanAction<char, IntPtr> s_getAsciiOrUtf8StringNonNullCharacters = GetAsciiOrUTF8StringNonNullCharacters;
+
+        private static unsafe void GetAsciiOrUTF8StringNonNullCharacters(Span<char> buffer, IntPtr state)
+        {
+            fixed (char* output = &MemoryMarshal.GetReference(buffer))
+            {
+                // This version if AsciiUtilities returns null if there are any null (0 byte) characters
+                // in the string
+                if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
+                {
+                    // Mark resultString for UTF-8 encoding
+                    output[0] = '\0';
+                }
             }
-            return resultString;
         }
 
         public static string GetAsciiStringEscaped(this Span<byte> span, int maxChars)
@@ -288,7 +316,7 @@ public static HttpMethod GetKnownMethod(string value)
                 {
                     method = HttpMethod.Head;
                 }
-                else if(firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal))
+                else if (firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal))
                 {
                     method = HttpMethod.Post;
                 }
@@ -299,7 +327,7 @@ public static HttpMethod GetKnownMethod(string value)
                 {
                     method = HttpMethod.Trace;
                 }
-                else if(firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal))
+                else if (firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal))
                 {
                     method = HttpMethod.Patch;
                 }
diff --git a/src/Shared/Http2cat/Http2Utilities.cs b/src/Shared/Http2cat/Http2Utilities.cs
index 14fe26a601a7..f28933c9c00a 100644
--- a/src/Shared/Http2cat/Http2Utilities.cs
+++ b/src/Shared/Http2cat/Http2Utilities.cs
@@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.Http2Cat
     internal class Http2Utilities : IHttpHeadersHandler
     {
         public static ReadOnlySpan<byte> ClientPreface => new byte[24] { (byte)'P', (byte)'R', (byte)'I', (byte)' ', (byte)'*', (byte)' ', (byte)'H', (byte)'T', (byte)'T', (byte)'P', (byte)'/', (byte)'2', (byte)'.', (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n', (byte)'S', (byte)'M', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
-        public static readonly int MaxRequestHeaderFieldSize = 16 * 1024;
+        public const int MaxRequestHeaderFieldSize = 16 * 1024;
         public static readonly string FourKHeaderValue = new string('a', 4096);
         private static readonly Encoding HeaderValueEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
 
diff --git a/src/Shared/ServerInfrastructure/StringUtilities.cs b/src/Shared/ServerInfrastructure/StringUtilities.cs
index 2a12b9893b21..856c3693e844 100644
--- a/src/Shared/ServerInfrastructure/StringUtilities.cs
+++ b/src/Shared/ServerInfrastructure/StringUtilities.cs
@@ -2,117 +2,228 @@
 // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
 
 using System;
-using System.Buffers.Binary;
 using System.Diagnostics;
 using System.Numerics;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
+using System.Runtime.Intrinsics;
 using System.Runtime.Intrinsics.X86;
 using System.Text;
 
 namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
 {
-    internal class StringUtilities
+    internal static class StringUtilities
     {
         [MethodImpl(MethodImplOptions.AggressiveOptimization)]
         public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
         {
-            // Calculate end position
+            Debug.Assert(input != null);
+            Debug.Assert(output != null);
+
             var end = input + count;
-            // Start as valid
-            var isValid = true;
 
-            do
+            Debug.Assert((long)end >= Vector256<sbyte>.Count);
+
+            if (Sse2.IsSupported)
             {
-                // If Vector not-accelerated or remaining less than vector size
-                if (!Vector.IsHardwareAccelerated || input > end - Vector<sbyte>.Count)
+                if (Avx2.IsSupported && input <= end - Vector256<sbyte>.Count)
                 {
-                    if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination
+                    Vector256<sbyte> zero = Vector256<sbyte>.Zero;
+
+                    do
                     {
-                        // 64-bit: Loop longs by default
-                        while (input <= end - sizeof(long))
+                        var vector = Avx.LoadVector256(input).AsSByte();
+                        if (!CheckBytesInAsciiRange(vector, zero))
                         {
-                            isValid &= CheckBytesInAsciiRange(((long*)input)[0]);
-
-                            output[0] = (char)input[0];
-                            output[1] = (char)input[1];
-                            output[2] = (char)input[2];
-                            output[3] = (char)input[3];
-                            output[4] = (char)input[4];
-                            output[5] = (char)input[5];
-                            output[6] = (char)input[6];
-                            output[7] = (char)input[7];
-
-                            input += sizeof(long);
-                            output += sizeof(long);
+                            return false;
                         }
-                        if (input <= end - sizeof(int))
-                        {
-                            isValid &= CheckBytesInAsciiRange(((int*)input)[0]);
 
-                            output[0] = (char)input[0];
-                            output[1] = (char)input[1];
-                            output[2] = (char)input[2];
-                            output[3] = (char)input[3];
+                        var tmp0 = Avx2.UnpackLow(vector, zero);
+                        var tmp1 = Avx2.UnpackHigh(vector, zero);
 
-                            input += sizeof(int);
-                            output += sizeof(int);
-                        }
+                        // Bring into the right order
+                        var out0 = Avx2.Permute2x128(tmp0, tmp1, 0x20);
+                        var out1 = Avx2.Permute2x128(tmp0, tmp1, 0x31);
+
+                        Avx.Store((ushort*)output, out0.AsUInt16());
+                        Avx.Store((ushort*)output + Vector256<ushort>.Count, out1.AsUInt16());
+
+                        input += Vector256<sbyte>.Count;
+                        output += Vector256<sbyte>.Count;
+                    } while (input <= end - Vector256<sbyte>.Count);
+
+                    if (input == end)
+                    {
+                        return true;
                     }
-                    else
+                }
+
+                if (input <= end - Vector128<sbyte>.Count)
+                {
+                    Vector128<sbyte> zero = Vector128<sbyte>.Zero;
+
+                    do
                     {
-                        // 32-bit: Loop ints by default
-                        while (input <= end - sizeof(int))
+                        var vector = Sse2.LoadVector128(input).AsSByte();
+                        if (!CheckBytesInAsciiRange(vector, zero))
                         {
-                            isValid &= CheckBytesInAsciiRange(((int*)input)[0]);
+                            return false;
+                        }
 
-                            output[0] = (char)input[0];
-                            output[1] = (char)input[1];
-                            output[2] = (char)input[2];
-                            output[3] = (char)input[3];
+                        var c0 = Sse2.UnpackLow(vector, zero).AsUInt16();
+                        var c1 = Sse2.UnpackHigh(vector, zero).AsUInt16();
 
-                            input += sizeof(int);
-                            output += sizeof(int);
-                        }
+                        Sse2.Store((ushort*)output, c0);
+                        Sse2.Store((ushort*)output + Vector128<ushort>.Count, c1);
+
+                        input += Vector128<sbyte>.Count;
+                        output += Vector128<sbyte>.Count;
+                    } while (input <= end - Vector128<sbyte>.Count);
+
+                    if (input == end)
+                    {
+                        return true;
+                    }
+                }
+            }
+            else if (Vector.IsHardwareAccelerated)
+            {
+                while (input <= end - Vector<sbyte>.Count)
+                {
+                    var vector = Unsafe.AsRef<Vector<sbyte>>(input);
+                    if (!CheckBytesInAsciiRange(vector))
+                    {
+                        return false;
                     }
-                    if (input <= end - sizeof(short))
+
+                    Vector.Widen(
+                        vector,
+                        out Unsafe.AsRef<Vector<short>>(output),
+                        out Unsafe.AsRef<Vector<short>>(output + Vector<short>.Count));
+
+                    input += Vector<sbyte>.Count;
+                    output += Vector<sbyte>.Count;
+                }
+
+                if (input == end)
+                {
+                    return true;
+                }
+            }
+
+            if (Environment.Is64BitProcess) // Use Intrinsic switch for branch elimination
+            {
+                // 64-bit: Loop longs by default
+                while (input <= end - sizeof(long))
+                {
+                    var value = *(long*)input;
+                    if (!CheckBytesInAsciiRange(value))
                     {
-                        isValid &= CheckBytesInAsciiRange(((short*)input)[0]);
+                        return false;
+                    }
 
+                    if (Bmi2.X64.IsSupported)
+                    {
+                        // BMI2 will work regardless of the processor's endianness.
+                        ((ulong*)output)[0] = Bmi2.X64.ParallelBitDeposit((ulong)value, 0x00FF00FF_00FF00FFul);
+                        ((ulong*)output)[1] = Bmi2.X64.ParallelBitDeposit((ulong)(value >> 32), 0x00FF00FF_00FF00FFul);
+                    }
+                    else
+                    {
                         output[0] = (char)input[0];
                         output[1] = (char)input[1];
+                        output[2] = (char)input[2];
+                        output[3] = (char)input[3];
+                        output[4] = (char)input[4];
+                        output[5] = (char)input[5];
+                        output[6] = (char)input[6];
+                        output[7] = (char)input[7];
+                    }
 
-                        input += sizeof(short);
-                        output += sizeof(short);
+                    input += sizeof(long);
+                    output += sizeof(long);
+                }
+
+                if (input <= end - sizeof(int))
+                {
+                    var value = *(int*)input;
+                    if (!CheckBytesInAsciiRange(value))
+                    {
+                        return false;
+                    }
+
+                    if (Bmi2.IsSupported)
+                    {
+                        // BMI2 will work regardless of the processor's endianness.
+                        ((uint*)output)[0] = Bmi2.ParallelBitDeposit((uint)value, 0x00FF00FFu);
+                        ((uint*)output)[1] = Bmi2.ParallelBitDeposit((uint)(value >> 16), 0x00FF00FFu);
+                    }
+                    else
+                    {
+                        output[0] = (char)input[0];
+                        output[1] = (char)input[1];
+                        output[2] = (char)input[2];
+                        output[3] = (char)input[3];
+                    }
+
+                    input += sizeof(int);
+                    output += sizeof(int);
+                }
+            }
+            else
+            {
+                // 32-bit: Loop ints by default
+                while (input <= end - sizeof(int))
+                {
+                    var value = *(int*)input;
+                    if (!CheckBytesInAsciiRange(value))
+                    {
+                        return false;
+                    }
+
+                    if (Bmi2.IsSupported)
+                    {
+                        // BMI2 will work regardless of the processor's endianness.
+                        ((uint*)output)[0] = Bmi2.ParallelBitDeposit((uint)value, 0x00FF00FFu);
+                        ((uint*)output)[1] = Bmi2.ParallelBitDeposit((uint)(value >> 16), 0x00FF00FFu);
                     }
-                    if (input < end)
+                    else
                     {
-                        isValid &= CheckBytesInAsciiRange(((sbyte*)input)[0]);
                         output[0] = (char)input[0];
+                        output[1] = (char)input[1];
+                        output[2] = (char)input[2];
+                        output[3] = (char)input[3];
                     }
 
-                    return isValid;
+                    input += sizeof(int);
+                    output += sizeof(int);
                 }
+            }
 
-                // do/while as entry condition already checked
-                do
+            if (input <= end - sizeof(short))
+            {
+                if (!CheckBytesInAsciiRange(((short*)input)[0]))
                 {
-                    var vector = Unsafe.AsRef<Vector<sbyte>>(input);
-                    isValid &= CheckBytesInAsciiRange(vector);
-                    Vector.Widen(
-                        vector,
-                        out Unsafe.AsRef<Vector<short>>(output),
-                        out Unsafe.AsRef<Vector<short>>(output + Vector<short>.Count));
+                    return false;
+                }
 
-                    input += Vector<sbyte>.Count;
-                    output += Vector<sbyte>.Count;
-                } while (input <= end - Vector<sbyte>.Count);
+                output[0] = (char)input[0];
+                output[1] = (char)input[1];
 
-                // Vector path done, loop back to do non-Vector
-                // If is a exact multiple of vector size, bail now
-            } while (input < end);
+                input += sizeof(short);
+                output += sizeof(short);
+            }
 
-            return isValid;
+            if (input < end)
+            {
+                if (!CheckBytesInAsciiRange(((sbyte*)input)[0]))
+                {
+                    return false;
+                }
+                output[0] = (char)input[0];
+            }
+
+            return true;
         }
 
         [MethodImpl(MethodImplOptions.AggressiveOptimization)]
@@ -365,7 +476,8 @@ private unsafe static bool IsValidHeaderString(string value)
                 new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true).GetByteCount(value);
                 return !value.Contains('\0');
             }
-            catch (DecoderFallbackException) {
+            catch (DecoderFallbackException)
+            {
                 return false;
             }
         }
@@ -418,6 +530,24 @@ private static bool CheckBytesInAsciiRange(Vector<sbyte> check)
             return Vector.GreaterThanAll(check, Vector<sbyte>.Zero);
         }
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static bool CheckBytesInAsciiRange(Vector256<sbyte> check, Vector256<sbyte> zero)
+        {
+            Debug.Assert(Avx2.IsSupported);
+
+            var mask = Avx2.CompareGreaterThan(check, zero);
+            return (uint)Avx2.MoveMask(mask) == 0xFFFF_FFFF;
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static bool CheckBytesInAsciiRange(Vector128<sbyte> check, Vector128<sbyte> zero)
+        {
+            Debug.Assert(Sse2.IsSupported);
+
+            var mask = Sse2.CompareGreaterThan(check, zero);
+            return Sse2.MoveMask(mask) == 0xFFFF;
+        }
+
         // Validate: bytes != 0 && bytes <= 127
         // Subtract 1 from all bytes to move 0 to high bits
         // bitwise or with self to catch all > 127 bytes