Skip to content

Commit bced584

Browse files
authored
Light up Ascii.Equality.Equals and Ascii.Equality.EqualsIgnoreCase with Vector512 code path (#88650)
* merging with main Enabling AVX512 for ASCII.Equals * Correcting defects in the new Equals for AVX512 case * Correcting defects * Upgrading ASCII.Equality.EqualsIgnoreCase * Using intrinsics in AllCharsInVectorAreAscii * Using intrinsics in AllCharsInVectorAreAscii * Removing check for AVX512F and adding a check for Vector512 because the library is not using any functions from AVX512F * Removing check for CompExactlyDependsOn(AVX512F) from AllCharsInVectorAreAscii for Vector 512. Also checking for Vector512 support and not AVX512F in ASCIIEquality.Equals * Correcting the Tloader.Count512 for ushort * resolving merge errors * Adding TLoader method for Vector512 for EqualAndAscii * Updating Load512 for WideningLoader for performance increase * addressing review comments * Addressing review changes. Changing Widen to WidenLower for Load512
1 parent 420dd4e commit bced584

File tree

2 files changed

+173
-5
lines changed

2 files changed

+173
-5
lines changed

src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs

Lines changed: 157 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,36 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
6161
}
6262
}
6363
}
64+
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TLeft>.Count)
65+
{
66+
ref TLeft currentLeftSearchSpace = ref left;
67+
ref TRight currentRightSearchSpace = ref right;
68+
// Add Vector512<TLeft>.Count because TLeft == TRight
69+
// Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
70+
Debug.Assert(Vector512<TLeft>.Count == Vector512<TRight>.Count
71+
|| (typeof(TLoader) == typeof(WideningLoader) && Vector512<TLeft>.Count == Vector512<TRight>.Count * 2));
72+
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector512<TLeft>.Count);
73+
74+
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
75+
do
76+
{
77+
if (!TLoader.EqualAndAscii512(ref currentLeftSearchSpace, ref currentRightSearchSpace))
78+
{
79+
return false;
80+
}
81+
82+
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector512<TLeft>.Count);
83+
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector512<TLeft>.Count);
84+
}
85+
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));
86+
87+
// If any elements remain, process the last vector in the search space.
88+
if (length % (uint)Vector512<TLeft>.Count != 0)
89+
{
90+
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector512<TLeft>.Count);
91+
return TLoader.EqualAndAscii512(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
92+
}
93+
}
6494
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
6595
{
6696
ref TLeft currentLeftSearchSpace = ref left;
@@ -74,7 +104,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
74104
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
75105
do
76106
{
77-
if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace))
107+
if (!TLoader.EqualAndAscii256(ref currentLeftSearchSpace, ref currentRightSearchSpace))
78108
{
79109
return false;
80110
}
@@ -88,7 +118,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
88118
if (length % (uint)Vector256<TLeft>.Count != 0)
89119
{
90120
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
91-
return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
121+
return TLoader.EqualAndAscii256(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
92122
}
93123
}
94124
else
@@ -198,6 +228,77 @@ private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref
198228
}
199229
}
200230
}
231+
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TRight>.Count)
232+
{
233+
ref TLeft currentLeftSearchSpace = ref left;
234+
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count512);
235+
ref TRight currentRightSearchSpace = ref right;
236+
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector512<TRight>.Count);
237+
238+
Vector512<TRight> leftValues;
239+
Vector512<TRight> rightValues;
240+
241+
Vector512<TRight> loweringMask = Vector512.Create(TRight.CreateTruncating(0x20));
242+
Vector512<TRight> vecA = Vector512.Create(TRight.CreateTruncating('a'));
243+
Vector512<TRight> vecZMinusA = Vector512.Create(TRight.CreateTruncating(('z' - 'a')));
244+
245+
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
246+
do
247+
{
248+
leftValues = TLoader.Load512(ref currentLeftSearchSpace);
249+
rightValues = Vector512.LoadUnsafe(ref currentRightSearchSpace);
250+
if (!AllCharsInVectorAreAscii(leftValues | rightValues))
251+
{
252+
return false;
253+
}
254+
255+
Vector512<TRight> notEquals = ~Vector512.Equals(leftValues, rightValues);
256+
257+
if (notEquals != Vector512<TRight>.Zero)
258+
{
259+
// not exact match
260+
261+
leftValues |= loweringMask;
262+
rightValues |= loweringMask;
263+
264+
if (Vector512.GreaterThanAny((leftValues - vecA) & notEquals, vecZMinusA) || leftValues != rightValues)
265+
{
266+
return false; // first input isn't in [A-Za-z], and not exact match of lowered
267+
}
268+
}
269+
270+
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, (uint)Vector512<TRight>.Count);
271+
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count512);
272+
}
273+
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));
274+
275+
// If any elements remain, process the last vector in the search space.
276+
if (length % (uint)Vector512<TRight>.Count != 0)
277+
{
278+
leftValues = TLoader.Load512(ref oneVectorAwayFromLeftEnd);
279+
rightValues = Vector512.LoadUnsafe(ref oneVectorAwayFromRightEnd);
280+
281+
if (!AllCharsInVectorAreAscii(leftValues | rightValues))
282+
{
283+
return false;
284+
}
285+
286+
Vector512<TRight> notEquals = ~Vector512.Equals(leftValues, rightValues);
287+
288+
if (notEquals != Vector512<TRight>.Zero)
289+
{
290+
// not exact match
291+
292+
leftValues |= loweringMask;
293+
rightValues |= loweringMask;
294+
295+
if (Vector512.GreaterThanAny((leftValues - vecA) & notEquals, vecZMinusA) || leftValues != rightValues)
296+
{
297+
return false; // first input isn't in [A-Za-z], and not exact match of lowered
298+
}
299+
}
300+
}
301+
}
201302
else if (Avx.IsSupported && length >= (uint)Vector256<TRight>.Count)
202303
{
203304
ref TLeft currentLeftSearchSpace = ref left;
@@ -353,21 +454,26 @@ private interface ILoader<TLeft, TRight>
353454
{
354455
static abstract nuint Count128 { get; }
355456
static abstract nuint Count256 { get; }
457+
static abstract nuint Count512 { get; }
356458
static abstract Vector128<TRight> Load128(ref TLeft ptr);
357459
static abstract Vector256<TRight> Load256(ref TLeft ptr);
358-
static abstract bool EqualAndAscii(ref TLeft left, ref TRight right);
460+
static abstract Vector512<TRight> Load512(ref TLeft ptr);
461+
static abstract bool EqualAndAscii256(ref TLeft left, ref TRight right);
462+
static abstract bool EqualAndAscii512(ref TLeft left, ref TRight right);
359463
}
360464

361465
private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
362466
{
363467
public static nuint Count128 => (uint)Vector128<T>.Count;
364468
public static nuint Count256 => (uint)Vector256<T>.Count;
469+
public static nuint Count512 => (uint)Vector512<T>.Count;
365470
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
366471
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);
472+
public static Vector512<T> Load512(ref T ptr) => Vector512.LoadUnsafe(ref ptr);
367473

368474
[MethodImpl(MethodImplOptions.AggressiveInlining)]
369475
[CompExactlyDependsOn(typeof(Avx))]
370-
public static bool EqualAndAscii(ref T left, ref T right)
476+
public static bool EqualAndAscii256(ref T left, ref T right)
371477
{
372478
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
373479
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);
@@ -379,12 +485,27 @@ public static bool EqualAndAscii(ref T left, ref T right)
379485

380486
return true;
381487
}
488+
489+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
490+
public static bool EqualAndAscii512(ref T left, ref T right)
491+
{
492+
Vector512<T> leftValues = Vector512.LoadUnsafe(ref left);
493+
Vector512<T> rightValues = Vector512.LoadUnsafe(ref right);
494+
495+
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
496+
{
497+
return false;
498+
}
499+
500+
return true;
501+
}
382502
}
383503

384504
private readonly struct WideningLoader : ILoader<byte, ushort>
385505
{
386506
public static nuint Count128 => sizeof(long);
387507
public static nuint Count256 => (uint)Vector128<byte>.Count;
508+
public static nuint Count512 => (uint)Vector256<byte>.Count;
388509

389510
[MethodImpl(MethodImplOptions.AggressiveInlining)]
390511
public static Vector128<ushort> Load128(ref byte ptr)
@@ -412,9 +533,15 @@ public static Vector256<ushort> Load256(ref byte ptr)
412533
return Vector256.Create(lower, upper);
413534
}
414535

536+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
537+
public static Vector512<ushort> Load512(ref byte ptr)
538+
{
539+
return Vector512.WidenLower(Vector256.LoadUnsafe(ref ptr).ToVector512());
540+
}
541+
415542
[MethodImpl(MethodImplOptions.AggressiveInlining)]
416543
[CompExactlyDependsOn(typeof(Avx))]
417-
public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
544+
public static bool EqualAndAscii256(ref byte utf8, ref ushort utf16)
418545
{
419546
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
420547
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);
@@ -437,6 +564,31 @@ public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
437564

438565
return true;
439566
}
567+
568+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
569+
public static bool EqualAndAscii512(ref byte utf8, ref ushort utf16)
570+
{
571+
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
572+
Debug.Assert(Vector512<byte>.Count == Vector512<ushort>.Count * 2);
573+
574+
Vector512<byte> leftNotWidened = Vector512.LoadUnsafe(ref utf8);
575+
if (!AllCharsInVectorAreAscii(leftNotWidened))
576+
{
577+
return false;
578+
}
579+
580+
(Vector512<ushort> leftLower, Vector512<ushort> leftUpper) = Vector512.Widen(leftNotWidened);
581+
Vector512<ushort> right = Vector512.LoadUnsafe(ref utf16);
582+
Vector512<ushort> rightNext = Vector512.LoadUnsafe(ref utf16, (uint)Vector512<ushort>.Count);
583+
584+
// A branchless version of "leftLower != right || leftUpper != rightNext"
585+
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector512<ushort>.Zero)
586+
{
587+
return false;
588+
}
589+
590+
return true;
591+
}
440592
}
441593
}
442594
}

src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Utility.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1972,6 +1972,22 @@ private static bool AllCharsInVectorAreAscii<T>(Vector256<T> vector)
19721972
}
19731973
}
19741974

1975+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
1976+
private static bool AllCharsInVectorAreAscii<T>(Vector512<T> vector)
1977+
where T : unmanaged
1978+
{
1979+
Debug.Assert(typeof(T) == typeof(byte) || typeof(T) == typeof(ushort));
1980+
1981+
if (typeof(T) == typeof(byte))
1982+
{
1983+
return vector.AsByte().ExtractMostSignificantBits() == 0;
1984+
}
1985+
else
1986+
{
1987+
return (vector.AsUInt16() & Vector512.Create((ushort)0xFF80)) == Vector512<ushort>.Zero;
1988+
}
1989+
}
1990+
19751991
[MethodImpl(MethodImplOptions.AggressiveInlining)]
19761992
private static Vector128<byte> ExtractAsciiVector(Vector128<ushort> vectorFirst, Vector128<ushort> vectorSecond)
19771993
{

0 commit comments

Comments
 (0)