Skip to content

Removed unsafe string mutation #31850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,22 @@ public BCryptKeyHandle GenerateSymmetricKey(byte* pbSecret, uint cbSecret)
/// </summary>
public string GetAlgorithmName()
{
const int StackAllocCharSize = 128;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There can't be longer algorithm names?
I gues no, but ... want to double-check.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd be extremely surprised if that were the case. So what you had proposed should be good.

If you wanted to be defensive you could always string.Create instead of stackalloc, but overflowing this buffer is never going to happen in practice. (This method isn't even called outside unit tests TBH.)


// First, calculate how many characters are in the name.
uint byteLengthOfNameWithTerminatingNull = GetProperty(Constants.BCRYPT_ALGORITHM_NAME, null, 0);
CryptoUtil.Assert(byteLengthOfNameWithTerminatingNull % sizeof(char) == 0 && byteLengthOfNameWithTerminatingNull > sizeof(char), "byteLengthOfNameWithTerminatingNull % sizeof(char) == 0 && byteLengthOfNameWithTerminatingNull > sizeof(char)");
CryptoUtil.Assert(byteLengthOfNameWithTerminatingNull % sizeof(char) == 0 && byteLengthOfNameWithTerminatingNull > sizeof(char) && byteLengthOfNameWithTerminatingNull <= StackAllocCharSize * sizeof(char), "byteLengthOfNameWithTerminatingNull % sizeof(char) == 0 && byteLengthOfNameWithTerminatingNull > sizeof(char) && byteLengthOfNameWithTerminatingNull <= StackAllocCharSize * sizeof(char)");
uint numCharsWithoutNull = (byteLengthOfNameWithTerminatingNull - 1) / sizeof(char);

if (numCharsWithoutNull == 0)
{
return String.Empty; // degenerate case
return string.Empty; // degenerate case
}

// Allocate a string object and write directly into it (CLR team approves of this mechanism).
string retVal = new String((char)0, checked((int)numCharsWithoutNull));
uint numBytesCopied;
fixed (char* pRetVal = retVal)
{
numBytesCopied = GetProperty(Constants.BCRYPT_ALGORITHM_NAME, pRetVal, byteLengthOfNameWithTerminatingNull);
}
char* pBuffer = stackalloc char[StackAllocCharSize];
uint numBytesCopied = GetProperty(Constants.BCRYPT_ALGORITHM_NAME, pBuffer, byteLengthOfNameWithTerminatingNull);
CryptoUtil.Assert(numBytesCopied == byteLengthOfNameWithTerminatingNull, "numBytesCopied == byteLengthOfNameWithTerminatingNull");
return retVal;
return new string(pBuffer, 0, (int)numCharsWithoutNull);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ internal static partial class HttpUtilities
private const ulong _http11VersionLong = 3543824036068086856; // GetAsciiStringAsLong("HTTP/1.1"); const results in better codegen

private static readonly UTF8EncodingSealed DefaultRequestHeaderEncoding = new UTF8EncodingSealed();
private static readonly SpanAction<char, IntPtr> _getHeaderName = GetHeaderName;
private static readonly SpanAction<char, IntPtr> _getAsciiStringNonNullCharacters = GetAsciiStringNonNullCharacters;
private static readonly SpanAction<char, IntPtr> s_getHeaderName = GetHeaderName;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void SetKnownMethod(ulong mask, ulong knownMethodUlong, HttpMethod knownMethod, int length)
Expand Down Expand Up @@ -86,15 +85,15 @@ public static unsafe string GetHeaderName(this ReadOnlySpan<byte> span)

fixed (byte* source = &MemoryMarshal.GetReference(span))
{
return string.Create(span.Length, new IntPtr(source), _getHeaderName);
return string.Create(span.Length, new IntPtr(source), s_getHeaderName);
}
}

private static unsafe void GetHeaderName(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// This version of AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
Expand All @@ -104,38 +103,11 @@ private static unsafe void GetHeaderName(Span<char> buffer, IntPtr state)
}

public static string GetAsciiStringNonNullCharacters(this Span<byte> span)
=> GetAsciiStringNonNullCharacters((ReadOnlySpan<byte>)span);

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
{
return string.Empty;
}

fixed (byte* source = &MemoryMarshal.GetReference(span))
{
return string.Create(span.Length, new IntPtr(source), _getAsciiStringNonNullCharacters);
}
}
=> StringUtilities.GetAsciiStringNonNullCharacters(span);

public static string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlySpan<byte> span)
=> StringUtilities.GetAsciiOrUTF8StringNonNullCharacters(span, DefaultRequestHeaderEncoding);

private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// StringUtilities.TryGetAsciiString returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
throw new InvalidOperationException();
}
}
}

public static string GetRequestHeaderString(this ReadOnlySpan<byte> span, string name, Func<string, Encoding?> encodingSelector)
{
if (ReferenceEquals(KestrelServerOptions.DefaultRequestHeaderEncodingSelector, encodingSelector))
Expand Down
59 changes: 1 addition & 58 deletions src/Shared/Http2cat/Http2Utilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,64 +143,7 @@ public Http2Utilities(ConnectionContext clientConnectionContext, ILogger logger,

void IHttpHeadersHandler.OnHeader(ReadOnlySpan<byte> name, ReadOnlySpan<byte> value)
{
_decodedHeaders[GetAsciiStringNonNullCharacters(name)] = GetAsciiOrUTF8StringNonNullCharacters(value);
}

public unsafe string GetAsciiStringNonNullCharacters(ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
{
return string.Empty;
}

var asciiString = new string('\0', span.Length);

fixed (char* output = asciiString)
fixed (byte* buffer = span)
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
{
throw new InvalidOperationException();
}
}
return asciiString;
}

public unsafe string GetAsciiOrUTF8StringNonNullCharacters(ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
{
return string.Empty;
}

var resultString = new string('\0', span.Length);

fixed (char* output = resultString)
fixed (byte* buffer = span)
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
{
// null characters are considered invalid
if (span.IndexOf((byte)0) != -1)
{
throw new InvalidOperationException();
}

try
{
resultString = HeaderValueEncoding.GetString(buffer, span.Length);
}
catch (DecoderFallbackException)
{
throw new InvalidOperationException();
}
}
}
return resultString;
_decodedHeaders[name.GetAsciiStringNonNullCharacters()] = value.GetAsciiOrUTF8StringNonNullCharacters(HeaderValueEncoding);
}

void IHttpHeadersHandler.OnHeadersComplete(bool endStream) { }
Expand Down
63 changes: 43 additions & 20 deletions src/Shared/ServerInfrastructure/StringUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
{
internal static class StringUtilities
{
private static readonly SpanAction<char, IntPtr> s_getAsciiOrUtf8StringNonNullCharacters = GetAsciiStringNonNullCharacters;

private static string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span, Encoding defaultEncoding)
=> GetAsciiOrUTF8StringNonNullCharacters((ReadOnlySpan<byte>)span, defaultEncoding);
private static readonly SpanAction<char, IntPtr> s_getAsciiOrUTF8StringNonNullCharacters = GetAsciiStringNonNullCharactersWithMarker;
private static readonly SpanAction<char, IntPtr> s_getAsciiStringNonNullCharacters = GetAsciiStringNonNullCharacters;
private static readonly SpanAction<char, IntPtr> s_getLatin1StringNonNullCharacters = GetLatin1StringNonNullCharacters;
private static readonly SpanAction<char, (string? str, char separator, uint number)> s_populateSpanWithHexSuffix = PopulateSpanWithHexSuffix;

public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlySpan<byte> span, Encoding defaultEncoding)
{
Expand All @@ -31,7 +31,7 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlyS

fixed (byte* source = &MemoryMarshal.GetReference(span))
{
var resultString = string.Create(span.Length, new IntPtr(source), s_getAsciiOrUtf8StringNonNullCharacters);
var resultString = string.Create(span.Length, (IntPtr)source, s_getAsciiOrUTF8StringNonNullCharacters);

// If resultString is marked, perform UTF-8 encoding
if (resultString[0] == '\0')
Expand All @@ -56,11 +56,11 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this ReadOnlyS
}
}

private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
private static unsafe void GetAsciiStringNonNullCharactersWithMarker(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This version if AsciiUtilities returns false if there are any null ('\0') or non-Ascii
// This version of AsciiUtilities returns false if there are any null ('\0') or non-Ascii
// character (> 127) in the string.
if (!TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
Expand All @@ -70,27 +70,55 @@ private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, In
}
}

public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
{
return string.Empty;
}

fixed (byte* source = &MemoryMarshal.GetReference(span))
{
return string.Create(span.Length, (IntPtr)source, s_getAsciiStringNonNullCharacters);
}
}

private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This version of AsciiUtilities returns false if there are any null ('\0') or non-Ascii
// character (> 127) in the string.
if (!TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
throw new InvalidOperationException();
}
}
}

public static unsafe string GetLatin1StringNonNullCharacters(this ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
{
return string.Empty;
}

var resultString = new string('\0', span.Length);
fixed (byte* source = &MemoryMarshal.GetReference(span))
{
return string.Create(span.Length, (IntPtr)source, s_getLatin1StringNonNullCharacters);
}
}

fixed (char* output = resultString)
fixed (byte* buffer = span)
private static unsafe void GetLatin1StringNonNullCharacters(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This returns false if there are any null (0 byte) characters in the string.
if (!TryGetLatin1String(buffer, output, span.Length))
if (!TryGetLatin1String((byte*)state.ToPointer(), output, buffer.Length))
{
// null characters are considered invalid
throw new InvalidOperationException();
}
}

return resultString;
}

[MethodImpl(MethodImplOptions.AggressiveOptimization)]
Expand Down Expand Up @@ -299,7 +327,7 @@ public static unsafe bool TryGetLatin1String(byte* input, char* output, int coun
// If Vector not-accelerated or remaining less than vector size
if (!Vector.IsHardwareAccelerated || input > end - Vector<sbyte>.Count)
{
if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination
if (Environment.Is64BitProcess) // Use Intrinsic switch for branch elimination
{
// 64-bit: Loop longs by default
while (input <= end - sizeof(long))
Expand Down Expand Up @@ -403,10 +431,6 @@ public static bool BytesOrdinalEqualsStringAndAscii(string previousValue, ReadOn
goto NotEqual;
}

// Use IntPtr values rather than int, to avoid unnecessary 32 -> 64 movs on 64-bit.
// Unfortunately this means we also need to cast to byte* for comparisons as IntPtr doesn't
// support operator comparisons (e.g. <=, >, etc).
//
// Note: Pointer comparison is unsigned, so we use the compare pattern (offset + length <= count)
// rather than (offset <= count - length) which we'd do with signed comparison to avoid overflow.
// This isn't problematic as we know the maximum length is max string length (from test above)
Expand Down Expand Up @@ -666,7 +690,6 @@ private static bool IsValidHeaderString(string value)
return false;
}
}
private static readonly SpanAction<char, (string? str, char separator, uint number)> s_populateSpanWithHexSuffix = PopulateSpanWithHexSuffix;

/// <summary>
/// A faster version of String.Concat(<paramref name="str"/>, <paramref name="separator"/>, <paramref name="number"/>.ToString("X8"))
Expand Down