@@ -473,16 +473,23 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
473
473
//
474
474
475
475
#if __AVX__ || __AVX2__ || __AVX512F__
476
- // multiply int8_t, add results pairwise twice
477
- static inline __m128i mul_sum_i8_pairs (const __m128i x , const __m128i y ) {
478
- // Get absolute values of x vectors
479
- const __m128i ax = _mm_sign_epi8 (x , x );
480
- // Sign the values of the y vectors
481
- const __m128i sy = _mm_sign_epi8 (y , x );
482
- // Perform multiplication and create 16-bit values
483
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
484
- const __m128i ones = _mm_set1_epi16 (1 );
485
- return _mm_madd_epi16 (ones , dot );
476
+ // Unpack 16 4-bit fields, packed two per byte in the 8 bytes at rsi, into 16 separate bytes
477
+ // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval; output bytes 2*k and 2*k+1 hold the low and the high nibble of input byte k
478
+ static inline __m128i bytes_from_nibbles_16 (const uint8_t * rsi )
479
+ {
480
+ // Load 8 bytes from memory into the low half of the register; the upper half is zeroed
481
+ __m128i tmp = _mm_loadl_epi64 ( ( const __m128i * )rsi );
482
+
483
+ // Zero-extend each of the 8 loaded bytes into its own uint16_t lane, so every lane looks like 0x00XY
484
+ __m128i bytes = _mm_cvtepu8_epi16 ( tmp );
485
+
486
+ // Split every lane into its two nibbles: the low nibble stays in the lane's low byte, the high nibble is shifted up by 4 bits so it lands in the low nibble of the lane's high byte (0x00XY -> 0x0X0Y)
487
+ const __m128i lowMask = _mm_set1_epi8 ( 0xF );
488
+ __m128i high = _mm_andnot_si128 ( lowMask , bytes ); // ~lowMask & bytes: only the high-nibble bits survive
489
+ __m128i low = _mm_and_si128 ( lowMask , bytes ); // only the low-nibble bits survive
490
+ high = _mm_slli_epi16 ( high , 4 );
491
+ bytes = _mm_or_si128 ( low , high );
492
+ return bytes ;
486
493
}
487
494
488
495
// horizontally add 8 floats
@@ -529,10 +536,19 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
529
536
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
530
537
static inline __m256i bytes_from_nibbles_32 (const uint8_t * rsi )
531
538
{
532
- const __m128i tmp = _mm_loadu_si128 ((const __m128i * )rsi );
533
- const __m256i bytes = _mm256_set_m128i (_mm_srli_epi16 (tmp , 4 ), tmp );
539
+ // Load 16 bytes from memory with an unaligned load
540
+ __m128i tmp = _mm_loadu_si128 ( ( const __m128i * )rsi );
541
+
542
+ // Zero-extend each of the 16 loaded bytes into its own uint16_t lane, so every lane looks like 0x00XY
543
+ __m256i bytes = _mm256_cvtepu8_epi16 ( tmp );
544
+
545
+ // Split every lane into its two nibbles (0x00XY -> 0x0X0Y), so output bytes 2*k and 2*k+1 hold the low and the high nibble of input byte k. NOTE(review): the removed lines returned the low nibbles in bytes 0..15 and the high nibbles in bytes 16..31 - confirm every caller expects the interleaved order
534
546
const __m256i lowMask = _mm256_set1_epi8 ( 0xF );
535
- return _mm256_and_si256 (lowMask , bytes );
547
+ __m256i high = _mm256_andnot_si256 ( lowMask , bytes ); // ~lowMask & bytes: only the high-nibble bits survive
548
+ __m256i low = _mm256_and_si256 ( lowMask , bytes ); // only the low-nibble bits survive
549
+ high = _mm256_slli_epi16 ( high , 4 );
550
+ bytes = _mm256_or_si256 ( low , high );
551
+ return bytes ;
536
552
}
537
553
538
554
// add int16_t pairwise and return as float vector
@@ -2109,23 +2125,31 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2109
2125
// Compute combined scale for the block
2110
2126
const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
2111
2127
2112
- const __m128i lowMask = _mm_set1_epi8 (0xF );
2113
- const __m128i off = _mm_set1_epi8 (8 );
2128
+ __m128i i32 [2 ];
2129
+ for (int j = 0 ; j < 2 ; ++ j ) {
2130
+ // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2131
+ __m128i bx = bytes_from_nibbles_16 (x [i ].qs + 8 * j );
2132
+ __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 * j ));
2133
+
2134
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2135
+ const __m128i off = _mm_set1_epi8 ( 8 );
2136
+ bx = _mm_sub_epi8 ( bx , off );
2114
2137
2115
- const __m128i tmp = _mm_loadu_si128 ((const __m128i * )x [i ].qs );
2138
+ // Get absolute values of x vectors
2139
+ const __m128i ax = _mm_sign_epi8 (bx , bx );
2116
2140
2117
- __m128i bx = _mm_and_si128 (lowMask , tmp );
2118
- __m128i by = _mm_loadu_si128 ((const __m128i * )y [i ].qs );
2119
- bx = _mm_sub_epi8 (bx , off );
2120
- const __m128i i32_0 = mul_sum_i8_pairs (bx , by );
2141
+ // Sign the values of the y vectors
2142
+ const __m128i sy = _mm_sign_epi8 (by , bx );
2121
2143
2122
- bx = _mm_and_si128 (lowMask , _mm_srli_epi64 (tmp , 4 ));
2123
- by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 ));
2124
- bx = _mm_sub_epi8 (bx , off );
2125
- const __m128i i32_1 = mul_sum_i8_pairs (bx , by );
2144
+ // Perform multiplication and create 16-bit values
2145
+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2146
+
2147
+ const __m128i ones = _mm_set1_epi16 (1 );
2148
+ i32 [j ] = _mm_madd_epi16 (ones , dot );
2149
+ }
2126
2150
2127
2151
// Convert int32_t to float
2128
- __m256 p = _mm256_cvtepi32_ps (_mm256_set_m128i (i32_0 , i32_1 ));
2152
+ __m256 p = _mm256_cvtepi32_ps ( _mm256_set_m128i ( i32 [ 0 ], i32 [ 1 ] ));
2129
2153
// Apply the scale, and accumulate
2130
2154
acc = _mm256_add_ps (_mm256_mul_ps ( d , p ), acc );
2131
2155
}
@@ -2472,8 +2496,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
2472
2496
int sumi = 0 ;
2473
2497
2474
2498
for (int j = 0 ; j < qk /2 ; ++ j ) {
2475
- const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
2476
- const uint8_t xh_1 = ((qh >> ( j + 12 )) ) & 0x10 ;
2499
+ const uint8_t xh_0 = ((qh & ( 1u << ( j + 0 ))) >> (j + 0 )) << 4 ;
2500
+ const uint8_t xh_1 = ((qh & ( 1u << ( j + 16 ))) >> ( j + 12 )) ;
2477
2501
2478
2502
const int32_t x0 = ((x [i ].qs [j ] & 0x0F ) | xh_0 ) - 16 ;
2479
2503
const int32_t x1 = ((x [i ].qs [j ] >> 4 ) | xh_1 ) - 16 ;
@@ -2698,8 +2722,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
2698
2722
int sumi = 0 ;
2699
2723
2700
2724
for (int j = 0 ; j < qk /2 ; ++ j ) {
2701
- const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
2702
- const uint8_t xh_1 = ((qh >> ( j + 12 )) ) & 0x10 ;
2725
+ const uint8_t xh_0 = ((qh & ( 1u << ( j + 0 ))) >> (j + 0 )) << 4 ;
2726
+ const uint8_t xh_1 = ((qh & ( 1u << ( j + 16 ))) >> ( j + 12 )) ;
2703
2727
2704
2728
const int32_t x0 = (x [i ].qs [j ] & 0xF ) | xh_0 ;
2705
2729
const int32_t x1 = (x [i ].qs [j ] >> 4 ) | xh_1 ;
0 commit comments