@@ -61,42 +61,34 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
6161 }
6262 }
6363 }
64- else if ( Vector512 . IsHardwareAccelerated && length >= ( uint ) Vector512 < TRight > . Count )
64+ else if ( Vector512 . IsHardwareAccelerated && length >= ( uint ) Vector512 < TLeft > . Count )
6565 {
6666 ref TLeft currentLeftSearchSpace = ref left ;
67- ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe . Add ( ref currentLeftSearchSpace , length - TLoader . Count512 ) ;
6867 ref TRight currentRightSearchSpace = ref right ;
69- ref TRight oneVectorAwayFromRightEnd = ref Unsafe . Add ( ref currentRightSearchSpace , length - ( uint ) Vector512 < TRight > . Count ) ;
70-
71- Vector512 < TRight > leftValues ;
72- Vector512 < TRight > rightValues ;
68+ // Add Vector512<TLeft>.Count because TLeft == TRight
69+ // Or we are in the Widen case where we iterate 2 * TRight.Count which is the same as TLeft.Count
70+ Debug . Assert ( Vector512 < TLeft > . Count == Vector512 < TRight > . Count
71+ || ( typeof ( TLoader ) == typeof ( WideningLoader ) && Vector512 < TLeft > . Count == Vector512 < TRight > . Count * 2 ) ) ;
72+ ref TRight oneVectorAwayFromRightEnd = ref Unsafe . Add ( ref currentRightSearchSpace , length - ( uint ) Vector512 < TLeft > . Count ) ;
7373
7474 // Loop until either we've finished all elements or there's less than a vector's-worth remaining.
7575 do
7676 {
77- leftValues = TLoader . Load512 ( ref currentLeftSearchSpace ) ;
78- rightValues = Vector512 . LoadUnsafe ( ref currentRightSearchSpace ) ;
79-
80- if ( leftValues != rightValues || ! AllCharsInVectorAreAscii ( leftValues | rightValues ) )
77+ if ( ! TLoader . EqualAndAscii512 ( ref currentLeftSearchSpace , ref currentRightSearchSpace ) )
8178 {
8279 return false ;
8380 }
8481
85- currentRightSearchSpace = ref Unsafe . Add ( ref currentRightSearchSpace , Vector512 < TRight > . Count ) ;
86- currentLeftSearchSpace = ref Unsafe . Add ( ref currentLeftSearchSpace , TLoader . Count512 ) ;
82+ currentRightSearchSpace = ref Unsafe . Add ( ref currentRightSearchSpace , Vector512 < TLeft > . Count ) ;
83+ currentLeftSearchSpace = ref Unsafe . Add ( ref currentLeftSearchSpace , Vector512 < TLeft > . Count ) ;
8784 }
8885 while ( ! Unsafe . IsAddressGreaterThan ( ref currentRightSearchSpace , ref oneVectorAwayFromRightEnd ) ) ;
8986
9087 // If any elements remain, process the last vector in the search space.
91- if ( length % ( uint ) Vector512 < TRight > . Count != 0 )
88+ if ( length % ( uint ) Vector512 < TLeft > . Count != 0 )
9289 {
93- leftValues = TLoader . Load512 ( ref oneVectorAwayFromLeftEnd ) ;
94- rightValues = Vector512 . LoadUnsafe ( ref oneVectorAwayFromRightEnd ) ;
95-
96- if ( leftValues != rightValues || ! AllCharsInVectorAreAscii ( leftValues | rightValues ) )
97- {
98- return false ;
99- }
90+ ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe . Add ( ref left , length - ( uint ) Vector512 < TLeft > . Count ) ;
91+ return TLoader . EqualAndAscii512 ( ref oneVectorAwayFromLeftEnd , ref oneVectorAwayFromRightEnd ) ;
10092 }
10193 }
10294 else if ( Avx . IsSupported && length >= ( uint ) Vector256 < TLeft > . Count )
@@ -112,7 +104,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
112104 // Loop until either we've finished all elements or there's less than a vector's-worth remaining.
113105 do
114106 {
115- if ( ! TLoader . EqualAndAscii ( ref currentLeftSearchSpace , ref currentRightSearchSpace ) )
107+ if ( ! TLoader . EqualAndAscii256 ( ref currentLeftSearchSpace , ref currentRightSearchSpace ) )
116108 {
117109 return false ;
118110 }
@@ -126,7 +118,7 @@ private static bool Equals<TLeft, TRight, TLoader>(ref TLeft left, ref TRight ri
126118 if ( length % ( uint ) Vector256 < TLeft > . Count != 0 )
127119 {
128120 ref TLeft oneVectorAwayFromLeftEnd = ref Unsafe . Add ( ref left , length - ( uint ) Vector256 < TLeft > . Count ) ;
129- return TLoader . EqualAndAscii ( ref oneVectorAwayFromLeftEnd , ref oneVectorAwayFromRightEnd ) ;
121+ return TLoader . EqualAndAscii256 ( ref oneVectorAwayFromLeftEnd , ref oneVectorAwayFromRightEnd ) ;
130122 }
131123 }
132124 else
@@ -255,7 +247,6 @@ private static bool EqualsIgnoreCase<TLeft, TRight, TLoader>(ref TLeft left, ref
255247 {
256248 leftValues = TLoader . Load512 ( ref currentLeftSearchSpace ) ;
257249 rightValues = Vector512 . LoadUnsafe ( ref currentRightSearchSpace ) ;
258-
259250 if ( ! AllCharsInVectorAreAscii ( leftValues | rightValues ) )
260251 {
261252 return false ;
@@ -467,7 +458,8 @@ private interface ILoader<TLeft, TRight>
467458 static abstract Vector128 < TRight > Load128 ( ref TLeft ptr ) ;
468459 static abstract Vector256 < TRight > Load256 ( ref TLeft ptr ) ;
469460 static abstract Vector512 < TRight > Load512 ( ref TLeft ptr ) ;
470- static abstract bool EqualAndAscii ( ref TLeft left , ref TRight right ) ;
461+ static abstract bool EqualAndAscii256 ( ref TLeft left , ref TRight right ) ;
462+ static abstract bool EqualAndAscii512 ( ref TLeft left , ref TRight right ) ;
471463 }
472464
473465 private readonly struct PlainLoader < T > : ILoader < T , T > where T : unmanaged, INumberBase < T >
@@ -477,10 +469,11 @@ private interface ILoader<TLeft, TRight>
477469 public static nuint Count512 => ( uint ) Vector512 < T > . Count ;
478470 public static Vector128 < T > Load128 ( ref T ptr ) => Vector128 . LoadUnsafe ( ref ptr ) ;
479471 public static Vector256 < T > Load256 ( ref T ptr ) => Vector256 . LoadUnsafe ( ref ptr ) ;
472+ public static Vector512 < T > Load512 ( ref T ptr ) => Vector512 . LoadUnsafe ( ref ptr ) ;
480473
481474 [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
482475 [ CompExactlyDependsOn ( typeof ( Avx ) ) ]
483- public static bool EqualAndAscii ( ref T left , ref T right )
476+ public static bool EqualAndAscii256 ( ref T left , ref T right )
484477 {
485478 Vector256 < T > leftValues = Vector256 . LoadUnsafe ( ref left ) ;
486479 Vector256 < T > rightValues = Vector256 . LoadUnsafe ( ref right ) ;
@@ -493,7 +486,19 @@ public static bool EqualAndAscii(ref T left, ref T right)
493486 return true ;
494487 }
495488
496- public static Vector512 < T > Load512 ( ref T ptr ) => Vector512 . LoadUnsafe ( ref ptr ) ;
489+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
490+ public static bool EqualAndAscii512 ( ref T left , ref T right )
491+ {
492+ Vector512 < T > leftValues = Vector512 . LoadUnsafe ( ref left ) ;
493+ Vector512 < T > rightValues = Vector512 . LoadUnsafe ( ref right ) ;
494+
495+ if ( leftValues != rightValues || ! AllCharsInVectorAreAscii ( leftValues ) )
496+ {
497+ return false ;
498+ }
499+
500+ return true ;
501+ }
497502 }
498503
499504 private readonly struct WideningLoader : ILoader < byte , ushort >
@@ -528,9 +533,16 @@ public static Vector256<ushort> Load256(ref byte ptr)
528533 return Vector256 . Create ( lower , upper ) ;
529534 }
530535
536+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
537+ public static Vector512 < ushort > Load512 ( ref byte ptr )
538+ {
539+ ( Vector256 < ushort > lower , Vector256 < ushort > upper ) = Vector256 . Widen ( Vector256 . LoadUnsafe ( ref ptr ) ) ;
540+ return Vector512 . Create ( lower , upper ) ;
541+ }
542+
531543 [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
532544 [ CompExactlyDependsOn ( typeof ( Avx ) ) ]
533- public static bool EqualAndAscii ( ref byte utf8 , ref ushort utf16 )
545+ public static bool EqualAndAscii256 ( ref byte utf8 , ref ushort utf16 )
534546 {
535547 // We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
536548 Debug . Assert ( Vector256 < byte > . Count == Vector256 < ushort > . Count * 2 ) ;
@@ -554,10 +566,29 @@ public static bool EqualAndAscii(ref byte utf8, ref ushort utf16)
554566 return true ;
555567 }
556568
557- public static Vector512 < ushort > Load512 ( ref byte ptr )
569+ [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
570+ public static bool EqualAndAscii512 ( ref byte utf8 , ref ushort utf16 )
558571 {
559- ( Vector256 < ushort > lower , Vector256 < ushort > upper ) = Vector256 . Widen ( Vector256 . LoadUnsafe ( ref ptr ) ) ;
560- return Vector512 . Create ( lower , upper ) ;
572+ // We widen the utf8 param so we can compare it to utf16, this doubles how much of the utf16 vector we search
573+ Debug . Assert ( Vector512 < byte > . Count == Vector512 < ushort > . Count * 2 ) ;
574+
575+ Vector512 < byte > leftNotWidened = Vector512 . LoadUnsafe ( ref utf8 ) ;
576+ if ( ! AllCharsInVectorAreAscii ( leftNotWidened ) )
577+ {
578+ return false ;
579+ }
580+
581+ ( Vector512 < ushort > leftLower , Vector512 < ushort > leftUpper ) = Vector512 . Widen ( leftNotWidened ) ;
582+ Vector512 < ushort > right = Vector512 . LoadUnsafe ( ref utf16 ) ;
583+ Vector512 < ushort > rightNext = Vector512 . LoadUnsafe ( ref utf16 , ( uint ) Vector512 < ushort > . Count ) ;
584+
585+ // A branchless version of "leftLower != right || leftUpper != rightNext"
586+ if ( ( ( leftLower ^ right ) | ( leftUpper ^ rightNext ) ) != Vector512 < ushort > . Zero )
587+ {
588+ return false ;
589+ }
590+
591+ return true ;
561592 }
562593 }
563594 }
0 commit comments