Skip to content

Commit 6e01940

Browse files
committed
Adding TLoader method for Vector512 for EqualAndAscii
1 parent f1930dd commit 6e01940

File tree

1 file changed

+61
-30
lines changed

1 file changed

+61
-30
lines changed

src/libraries/System.Private.CoreLib/src/System/Text/Ascii.Equality.cs

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -61,42 +61,34 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
6161
}
6262
}
6363
}
64-
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TRight>.Count)
64+
else if (Vector512.IsHardwareAccelerated && length >= (uint)Vector512<TLeft>.Count)
6565
{
6666
ref TLeft currentLeftSearchSpace = ref left;
67-
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref currentLeftSearchSpace, length - TLoader.Count512);
6867
ref TRight currentRightSearchSpace = ref right;
69-
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector512<TRight>.Count);
70-
71-
Vector512<TRight> leftValues;
72-
Vector512<TRight> rightValues;
68+
// Add Vector512<TLeft>.Count because TLeft == TRight
69+
// Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
70+
Debug.Assert(Vector512<TLeft>.Count == Vector512<TRight>.Count
71+
|| (typeof(TLoader) == typeof(WideningLoader) && Vector512<TLeft>.Count == Vector512<TRight>.Count * 2));
72+
ref TRight oneVectorAwayFromRightEnd = ref Unsafe.Add(ref currentRightSearchSpace, length - (uint)Vector512<TLeft>.Count);
7373

7474
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
7575
do
7676
{
77-
leftValues = TLoader.Load512(ref currentLeftSearchSpace);
78-
rightValues = Vector512.LoadUnsafe(ref currentRightSearchSpace);
79-
80-
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
77+
if (!TLoader.EqualAndAscii512(ref currentLeftSearchSpace, ref currentRightSearchSpace))
8178
{
8279
return false;
8380
}
8481

85-
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector512<TRight>.Count);
86-
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, TLoader.Count512);
82+
currentRightSearchSpace = ref Unsafe.Add(ref currentRightSearchSpace, Vector512<TLeft>.Count);
83+
currentLeftSearchSpace = ref Unsafe.Add(ref currentLeftSearchSpace, Vector512<TLeft>.Count);
8784
}
8885
while (!Unsafe.IsAddressGreaterThan(ref currentRightSearchSpace, ref oneVectorAwayFromRightEnd));
8986

9087
// If any elements remain, process the last vector in the search space.
91-
if (length % (uint)Vector512<TRight>.Count != 0)
88+
if (length % (uint)Vector512<TLeft>.Count != 0)
9289
{
93-
leftValues = TLoader.Load512(ref oneVectorAwayFromLeftEnd);
94-
rightValues = Vector512.LoadUnsafe(ref oneVectorAwayFromRightEnd);
95-
96-
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues | rightValues))
97-
{
98-
return false;
99-
}
90+
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector512<TLeft>.Count);
91+
return TLoader.EqualAndAscii512(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
10092
}
10193
}
10294
else if (Avx.IsSupported && length >= (uint)Vector256<TLeft>.Count)
@@ -112,7 +104,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
112104
// Loop until either we've finished all elements or there's less than a vector's-worth remaining.
113105
do
114106
{
115-
if (!TLoader.EqualAndAscii(ref currentLeftSearchSpace, ref currentRightSearchSpace))
107+
if (!TLoader.EqualAndAscii256(ref currentLeftSearchSpace, ref currentRightSearchSpace))
116108
{
117109
return false;
118110
}
@@ -126,7 +118,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
126118
if (length % (uint)Vector256<TLeft>.Count != 0)
127119
{
128120
ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe.Add(ref left, length - (uint)Vector256<TLeft>.Count);
129-
return TLoader.EqualAndAscii(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
121+
return TLoader.EqualAndAscii256(ref oneVectorAwayFromLeftEnd, ref oneVectorAwayFromRightEnd);
130122
}
131123
}
132124
else
@@ -255,7 +247,6 @@ private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref
255247
{
256248
leftValues = TLoader.Load512(ref currentLeftSearchSpace);
257249
rightValues = Vector512.LoadUnsafe(ref currentRightSearchSpace);
258-
259250
if (!AllCharsInVectorAreAscii(leftValues | rightValues))
260251
{
261252
return false;
@@ -467,7 +458,8 @@ private interface ILoader<TLeft, TRight>
467458
static abstract Vector128<TRight> Load128(ref TLeft ptr);
468459
static abstract Vector256<TRight> Load256(ref TLeft ptr);
469460
static abstract Vector512<TRight> Load512(ref TLeft ptr);
470-
static abstract bool EqualAndAscii(ref TLeft left, ref TRight right);
461+
static abstract bool EqualAndAscii256(ref TLeft left, ref TRight right);
462+
static abstract bool EqualAndAscii512(ref TLeft left, ref TRight right);
471463
}
472464

473465
private readonly struct PlainLoader<T> : ILoader<T, T> where T : unmanaged, INumberBase<T>
@@ -477,10 +469,11 @@ private interface ILoader<TLeft, TRight>
477469
public static nuint Count512 => (uint)Vector512<T>.Count;
478470
public static Vector128<T> Load128(ref T ptr) => Vector128.LoadUnsafe(ref ptr);
479471
public static Vector256<T> Load256(ref T ptr) => Vector256.LoadUnsafe(ref ptr);
472+
public static Vector512<T> Load512(ref T ptr) => Vector512.LoadUnsafe(ref ptr);
480473

481474
[MethodImpl(MethodImplOptions.AggressiveInlining)]
482475
[CompExactlyDependsOn(typeof(Avx))]
483-
public static bool EqualAndAscii(ref T left, ref T right)
476+
public static bool EqualAndAscii256(ref T left, ref T right)
484477
{
485478
Vector256<T> leftValues = Vector256.LoadUnsafe(ref left);
486479
Vector256<T> rightValues = Vector256.LoadUnsafe(ref right);
@@ -493,7 +486,19 @@ public static bool EqualAndAscii(ref T left, ref T right)
493486
return true;
494487
}
495488

496-
public static Vector512<T> Load512(ref T ptr) => Vector512.LoadUnsafe(ref ptr);
489+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
490+
public static bool EqualAndAscii512(ref T left, ref T right)
491+
{
492+
Vector512<T> leftValues = Vector512.LoadUnsafe(ref left);
493+
Vector512<T> rightValues = Vector512.LoadUnsafe(ref right);
494+
495+
if (leftValues != rightValues || !AllCharsInVectorAreAscii(leftValues))
496+
{
497+
return false;
498+
}
499+
500+
return true;
501+
}
497502
}
498503

499504
private readonly struct WideningLoader : ILoader<byte, ushort>
@@ -528,9 +533,16 @@ public static Vector256<ushort> Load256(ref byte ptr)
528533
return Vector256.Create(lower, upper);
529534
}
530535

536+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
537+
public static Vector512<ushort> Load512(ref byte ptr)
538+
{
539+
(Vector256<ushort> lower, Vector256<ushort> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref ptr));
540+
return Vector512.Create(lower, upper);
541+
}
542+
531543
[MethodImpl(MethodImplOptions.AggressiveInlining)]
532544
[CompExactlyDependsOn(typeof(Avx))]
533-
public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
545+
public static bool EqualAndAscii256(ref byte utf8, ref ushort utf16)
534546
{
535547
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
536548
Debug.Assert(Vector256<byte>.Count == Vector256<ushort>.Count * 2);
@@ -554,10 +566,29 @@ public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
554566
return true;
555567
}
556568

557-
public static Vector512<ushort> Load512(ref byte ptr)
569+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
570+
public static bool EqualAndAscii512(ref byte utf8, ref ushort utf16)
558571
{
559-
(Vector256<ushort> lower, Vector256<ushort> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref ptr));
560-
return Vector512.Create(lower, upper);
572+
// We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
573+
Debug.Assert(Vector512<byte>.Count == Vector512<ushort>.Count * 2);
574+
575+
Vector512<byte> leftNotWidened = Vector512.LoadUnsafe(ref utf8);
576+
if (!AllCharsInVectorAreAscii(leftNotWidened))
577+
{
578+
return false;
579+
}
580+
581+
(Vector512<ushort> leftLower, Vector512<ushort> leftUpper) = Vector512.Widen(leftNotWidened);
582+
Vector512<ushort> right = Vector512.LoadUnsafe(ref utf16);
583+
Vector512<ushort> rightNext = Vector512.LoadUnsafe(ref utf16, (uint)Vector512<ushort>.Count);
584+
585+
// A branchless version of "leftLower != right || leftUpper != rightNext"
586+
if (((leftLower ^ right) | (leftUpper ^ rightNext)) != Vector512<ushort>.Zero)
587+
{
588+
return false;
589+
}
590+
591+
return true;
561592
}
562593
}
563594
}

0 commit comments

Comments
 (0)