@@ -61,6 +61,36 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
6161 }
6262 }
6363 }
64+ else if ( Vector512 . IsHardwareAccelerated && length >= ( uint ) Vector512 < TLeft > . Count )
65+ {
66+ ref TLeft currentLeftSearchSpace = ref left ;
67+ ref TRight currentRightSearchSpace = ref right ;
68+ // Add Vector512<TLeft>.Count because TLeft == TRight
69+ // Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
70+ Debug . Assert ( Vector512 < TLeft > . Count == Vector512 < TRight > . Count
71+ || ( typeof ( TLoader ) == typeof ( WideningLoader ) && Vector512 < TLeft > . Count == Vector512 < TRight > . Count * 2 ) ) ;
72+ ref TRight oneVectorAwayFromRightEnd = ref Unsafe . Add ( ref currentRightSearchSpace , length - ( uint ) Vector512 < TLeft > . Count ) ;
73+
74+ // Loop until either we've finished all elements or there's less than a vector's-worth remaining.
75+ do
76+ {
77+ if ( ! TLoader . EqualAndAscii512 ( ref currentLeftSearchSpace , ref currentRightSearchSpace ) )
78+ {
79+ return false ;
80+ }
81+
82+ currentRightSearchSpace = ref Unsafe . Add ( ref currentRightSearchSpace , Vector512 < TLeft > . Count ) ;
83+ currentLeftSearchSpace = ref Unsafe . Add ( ref currentLeftSearchSpace , Vector512 < TLeft > . Count ) ;
84+ }
85+ while ( ! Unsafe . IsAddressGreaterThan ( ref currentRightSearchSpace , ref oneVectorAwayFromRightEnd ) ) ;
86+
87+ // If any elements remain, process the last vector in the search space.
88+ if ( length % ( uint ) Vector512 < TLeft > . Count != 0 )
89+ {
90+ ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe . Add ( ref left , length - ( uint ) Vector512 < TLeft > . Count ) ;
91+ return TLoader . EqualAndAscii512 ( ref oneVectorAwayFromLeftEnd , ref oneVectorAwayFromRightEnd ) ;
92+ }
93+ }
6494 else if ( Avx . IsSupported && length >= ( uint ) Vector256 < TLeft > . Count )
6595 {
6696 ref TLeft currentLeftSearchSpace = ref left ;
@@ -74,7 +104,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
74104 // Loop until either we've finished all elements or there's less than a vector's-worth remaining.
75105 do
76106 {
77- if ( ! TLoader . EqualAndAscii ( ref currentLeftSearchSpace , ref currentRightSearchSpace ) )
107+ if ( ! TLoader . EqualAndAscii256 ( ref currentLeftSearchSpace , ref currentRightSearchSpace ) )
78108 {
79109 return false ;
80110 }
@@ -88,7 +118,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
88118 if ( length % ( uint ) Vector256 < TLeft > . Count != 0 )
89119 {
90120 ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe . Add ( ref left , length - ( uint ) Vector256 < TLeft > . Count ) ;
91- return TLoader . EqualAndAscii ( ref oneVectorAwayFromLeftEnd , ref oneVectorAwayFromRightEnd ) ;
121+ return TLoader . EqualAndAscii256 ( ref oneVectorAwayFromLeftEnd , ref oneVectorAwayFromRightEnd ) ;
92122 }
93123 }
94124 else
@@ -198,6 +228,77 @@ private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref
198228 }
199229 }
200230 }
231+ else if ( Vector512 . IsHardwareAccelerated && length >= ( uint ) Vector512 < TRight > . Count )
232+ {
233+ ref TLeft currentLeftSearchSpace = ref left ;
234+ ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe . Add ( ref currentLeftSearchSpace , length - TLoader . Count512 ) ;
235+ ref TRight currentRightSearchSpace = ref right ;
236+ ref TRight oneVectorAwayFromRightEnd = ref Unsafe . Add ( ref currentRightSearchSpace , length - ( uint ) Vector512 < TRight > . Count ) ;
237+
238+ Vector512 < TRight > leftValues ;
239+ Vector512 < TRight > rightValues ;
240+
241+ Vector512 < TRight > loweringMask = Vector512 . Create ( TRight . CreateTruncating ( 0x20 ) ) ;
242+ Vector512 < TRight > vecA = Vector512 . Create ( TRight . CreateTruncating ( 'a' ) ) ;
243+ Vector512 < TRight > vecZMinusA = Vector512 . Create ( TRight . CreateTruncating ( ( 'z' - 'a' ) ) ) ;
244+
245+ // Loop until either we've finished all elements or there's less than a vector's-worth remaining.
246+ do
247+ {
248+ leftValues = TLoader . Load512 ( ref currentLeftSearchSpace ) ;
249+ rightValues = Vector512 . LoadUnsafe ( ref currentRightSearchSpace ) ;
250+ if ( ! AllCharsInVectorAreAscii ( leftValues | rightValues ) )
251+ {
252+ return false ;
253+ }
254+
255+ Vector512 < TRight > notEquals = ~ Vector512 . Equals ( leftValues , rightValues ) ;
256+
257+ if ( notEquals != Vector512 < TRight > . Zero )
258+ {
259+ // not exact match
260+
261+ leftValues |= loweringMask ;
262+ rightValues |= loweringMask ;
263+
264+ if ( Vector512 . GreaterThanAny ( ( leftValues - vecA ) & notEquals , vecZMinusA ) || leftValues != rightValues )
265+ {
266+ return false ; // first input isn't in [A-Za-z], and not exact match of lowered
267+ }
268+ }
269+
270+ currentRightSearchSpace = ref Unsafe . Add ( ref currentRightSearchSpace , ( uint ) Vector512 < TRight > . Count ) ;
271+ currentLeftSearchSpace = ref Unsafe . Add ( ref currentLeftSearchSpace , TLoader . Count512 ) ;
272+ }
273+ while ( ! Unsafe . IsAddressGreaterThan ( ref currentRightSearchSpace , ref oneVectorAwayFromRightEnd ) ) ;
274+
275+ // If any elements remain, process the last vector in the search space.
276+ if ( length % ( uint ) Vector512 < TRight > . Count != 0 )
277+ {
278+ leftValues = TLoader . Load512 ( ref oneVectorAwayFromLeftEnd ) ;
279+ rightValues = Vector512 . LoadUnsafe ( ref oneVectorAwayFromRightEnd ) ;
280+
281+ if ( ! AllCharsInVectorAreAscii ( leftValues | rightValues ) )
282+ {
283+ return false ;
284+ }
285+
286+ Vector512 < TRight > notEquals = ~ Vector512 . Equals ( leftValues , rightValues ) ;
287+
288+ if ( notEquals != Vector512 < TRight > . Zero )
289+ {
290+ // not exact match
291+
292+ leftValues |= loweringMask ;
293+ rightValues |= loweringMask ;
294+
295+ if ( Vector512 . GreaterThanAny ( ( leftValues - vecA ) & notEquals , vecZMinusA ) || leftValues != rightValues )
296+ {
297+ return false ; // first input isn't in [A-Za-z], and not exact match of lowered
298+ }
299+ }
300+ }
301+ }
201302 else if ( Avx . IsSupported && length >= ( uint ) Vector256 < TRight > . Count )
202303 {
203304 ref TLeft currentLeftSearchSpace = ref left ;
@@ -353,21 +454,26 @@ private interface ILoader<TLeft, TRight>
353454 {
354455 static abstract nuint Count128 { get ; }
355456 static abstract nuint Count256 { get ; }
457+ static abstract nuint Count512 { get ; }
356458 static abstract Vector128 < TRight > Load128 ( ref TLeft ptr ) ;
357459 static abstract Vector256 < TRight > Load256 ( ref TLeft ptr ) ;
358- static abstract bool EqualAndAscii ( ref TLeft left , ref TRight right ) ;
460+ static abstract Vector512 < TRight > Load512 ( ref TLeft ptr ) ;
461+ static abstract bool EqualAndAscii256 ( ref TLeft left , ref TRight right ) ;
462+ static abstract bool EqualAndAscii512 ( ref TLeft left , ref TRight right ) ;
359463 }
360464
361465 private readonly struct PlainLoader < T > : ILoader < T , T > where T : unmanaged, INumberBase < T >
362466 {
363467 public static nuint Count128 => ( uint ) Vector128 < T > . Count ;
364468 public static nuint Count256 => ( uint ) Vector256 < T > . Count ;
469+ public static nuint Count512 => ( uint ) Vector512 < T > . Count ;
365470 public static Vector128 < T > Load128 ( ref T ptr ) => Vector128 . LoadUnsafe ( ref ptr ) ;
366471 public static Vector256 < T > Load256 ( ref T ptr ) => Vector256 . LoadUnsafe ( ref ptr ) ;
472+ public static Vector512 < T > Load512 ( ref T ptr ) => Vector512 . LoadUnsafe ( ref ptr ) ;
367473
368474 [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
369475 [ CompExactlyDependsOn ( typeof ( Avx ) ) ]
370- public static bool EqualAndAscii ( ref T left , ref T right )
476+ public static bool EqualAndAscii256 ( ref T left , ref T right )
371477 {
372478 Vector256 < T > leftValues = Vector256 . LoadUnsafe ( ref left ) ;
373479 Vector256 < T > rightValues = Vector256 . LoadUnsafe ( ref right ) ;
@@ -379,12 +485,27 @@ public static bool EqualAndAscii(ref T left, ref T right)
379485
380486 return true ;
381487 }
488+
// Compares one Vector512<T> worth of elements read from 'left' against the same
// number of elements read from 'right'.
// Returns true only when the two vectors are identical and the left vector passes
// AllCharsInVectorAreAscii; returns false otherwise.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii512(ref T left, ref T right)
{
    Vector512<T> lhs = Vector512.LoadUnsafe(ref left);
    Vector512<T> rhs = Vector512.LoadUnsafe(ref right);

    // The ASCII check is only evaluated once the vectors are known to be equal,
    // so inspecting the left vector alone covers the right one as well.
    return lhs == rhs && AllCharsInVectorAreAscii(lhs);
}
382502 }
383503
384504 private readonly struct WideningLoader : ILoader < byte , ushort >
385505 {
386506 public static nuint Count128 => sizeof ( long ) ;
387507 public static nuint Count256 => ( uint ) Vector128 < byte > . Count ;
508+ public static nuint Count512 => ( uint ) Vector256 < byte > . Count ;
388509
389510 [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
390511 public static Vector128 < ushort > Load128 ( ref byte ptr )
@@ -412,9 +533,15 @@ public static Vector256<ushort> Load256(ref byte ptr)
412533 return Vector256 . Create ( lower , upper ) ;
413534 }
414535
// Reads Vector256<byte>.Count bytes starting at 'ptr' and returns them widened to
// ushort, one element per input byte.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Vector512<ushort> Load512(ref byte ptr)
{
    // Load 256 bits of bytes, place them in the lower half of a 512-bit vector,
    // then widen just that lower half to 16-bit elements.
    Vector256<byte> narrow = Vector256.LoadUnsafe(ref ptr);
    Vector512<byte> padded = narrow.ToVector512();
    return Vector512.WidenLower(padded);
}
541+
415542 [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
416543 [ CompExactlyDependsOn ( typeof ( Avx ) ) ]
417- public static bool EqualAndAscii ( ref byte utf8 , ref ushort utf16 )
544+ public static bool EqualAndAscii256 ( ref byte utf8 , ref ushort utf16 )
418545 {
419546 // We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
420547 Debug . Assert ( Vector256 < byte > . Count == Vector256 < ushort > . Count * 2 ) ;
@@ -437,6 +564,31 @@ public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
437564
438565 return true ;
439566 }
567+
// Compares one Vector512<byte> worth of 'utf8' bytes against the matching
// 2 * Vector512<ushort>.Count elements of 'utf16'.
// Returns false when the bytes fail AllCharsInVectorAreAscii or when any widened
// byte differs from its utf16 counterpart; returns true otherwise.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool EqualAndAscii512(ref byte utf8, ref ushort utf16)
{
    // One byte vector widens into two ushort vectors, so a single call consumes
    // two vectors' worth of utf16 input.
    Debug.Assert(Vector512<byte>.Count == Vector512<ushort>.Count * 2);

    Vector512<byte> narrowBytes = Vector512.LoadUnsafe(ref utf8);
    if (!AllCharsInVectorAreAscii(narrowBytes))
    {
        return false;
    }

    (Vector512<ushort> widenedLow, Vector512<ushort> widenedHigh) = Vector512.Widen(narrowBytes);
    Vector512<ushort> charsLow = Vector512.LoadUnsafe(ref utf16);
    Vector512<ushort> charsHigh = Vector512.LoadUnsafe(ref utf16, (uint)Vector512<ushort>.Count);

    // Branchless form of "widenedLow == charsLow && widenedHigh == charsHigh":
    // XOR leaves a non-zero bit wherever the halves differ, OR merges both results.
    Vector512<ushort> differences = (widenedLow ^ charsLow) | (widenedHigh ^ charsHigh);
    return differences == Vector512<ushort>.Zero;
}
440592 }
441593 }
442594}
0 commit comments