@@ -427,9 +427,35 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
427
427
// quantization
428
428
//
429
429
430
- // AVX routines provided by GH user Const-me
431
- // ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
432
430
#if __AVX2__ || __AVX512F__
431
// Unpack 32 2-bit fields ("crumbs") into 32 bytes.
// The result holds 32 bytes, each one a value in the [0 .. 3] interval.
static inline __m256i bytesFromCrumbs(uint32_t packed_hi, uint32_t packed_lo) {
    // Replicate each 32-bit word across one 128-bit lane:
    // low lane = 4x packed_lo, high lane = 4x packed_hi.
    __m256i q = _mm256_set_epi32(packed_hi, packed_hi, packed_hi, packed_hi,
                                 packed_lo, packed_lo, packed_lo, packed_lo);

    // Per-element right shifts of 0/2/4/6 bits: afterwards every crumb of the
    // word sits in the low 2 bits of some byte of some 32-bit element.
    const __m256i shifts = _mm256_set_epi32(6, 4, 2, 0,
                                            6, 4, 2, 0);
    q = _mm256_srlv_epi32(q, shifts);

    // Gather those bytes so the 16 crumbs of each word come out in bit order.
    const __m256i perm = _mm256_set_epi8(15, 11,  7,  3,
                                         14, 10,  6,  2,
                                         13,  9,  5,  1,
                                         12,  8,  4,  0,
                                         15, 11,  7,  3,
                                         14, 10,  6,  2,
                                         13,  9,  5,  1,
                                         12,  8,  4,  0);
    q = _mm256_shuffle_epi8(q, perm);

    // Keep only the low 2 bits of every byte.
    const __m256i low2 = _mm256_set1_epi8(3);
    return _mm256_and_si256(q, low2);
}
458
+
433
459
// Unpack 32 4-bit fields into 32 bytes
434
460
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
435
461
static inline __m256i bytesFromNibbles ( const uint8_t * rsi )
@@ -2170,6 +2196,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2170
2196
static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2171
2197
assert (n % QK2_0 == 0 );
2172
2198
const int nb = n / QK2_0 ;
2199
+ assert (nb % 2 == 0 );
2173
2200
2174
2201
const block_q2_0 * restrict x = vx ;
2175
2202
const block_q8_0 * restrict y = vy ;
@@ -2178,49 +2205,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2178
2205
2179
2206
#if defined(__AVX2__ )
2180
2207
// Initialize accumulator with zeros
2181
- __m128 acc = _mm_setzero_ps ();
2182
-
2183
- for (int i = 0 ; i < nb ; i ++ ) {
2184
- // Compute combined scale for the block
2185
- const __m128 scale = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i ].d ) * y [i /2 ].d );
2186
-
2187
- __m128i bx = _mm_set1_epi32 (x [i ].qs );
2208
+ __m256 acc = _mm256_setzero_ps ();
2188
2209
2189
- // shift counts to get all bit pairs in lowest position of each byte
2190
- const __m128i shift128 = _mm_set_epi32 (6 , 4 , 2 , 0 );
2191
- bx = _mm_srlv_epi32 (bx , shift128 );
2210
+ for (int i = 0 ; i < nb ; i += 2 ) {
2211
+ __m256i bx = bytesFromCrumbs (x [i + 1 ].qs , x [i ].qs );
2192
2212
2193
- const __m128i shufmask = _mm_set_epi8 (15 ,11 ,7 ,3 ,14 ,10 ,6 ,2 ,13 ,9 ,5 ,1 ,12 ,8 ,4 ,0 );
2194
- bx = _mm_shuffle_epi8 (bx , shufmask );
2213
+ // Compute combined scale for the block
2214
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 0 ].d ) * y [i /2 ].d );
2215
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 1 ].d ) * y [i /2 ].d );
2216
+ const __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2195
2217
2196
- const __m128i mask = _mm_set1_epi8 ( 3 );
2197
- bx = _mm_and_si128 ( mask , bx );
2218
+ const __m256i off = _mm256_set1_epi8 ( 2 );
2219
+ bx = _mm256_sub_epi8 ( bx , off );
2198
2220
2199
- const __m128i off = _mm_set1_epi8 (2 );
2200
- bx = _mm_sub_epi8 (bx , off );
2201
-
2202
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i /2 ].qs + (i %2 )* QK2_0 ));
2221
+ // Load y vector
2222
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i /2 ].qs );
2203
2223
2204
2224
// Get absolute values of x vectors
2205
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2225
+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
2206
2226
// Sign the values of the y vectors
2207
- const __m128i sy = _mm_sign_epi8 (by , bx );
2227
+ const __m256i sy = _mm256_sign_epi8 (by , bx );
2208
2228
// Perform multiplication and create 16-bit values
2209
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2229
+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2210
2230
2211
2231
// Convert int16_t to int32_t by adding pairwise
2212
- const __m128i ones = _mm_set1_epi16 (1 );
2213
- __m128i i32 = _mm_madd_epi16 ( dot , ones );
2232
+ const __m256i ones = _mm256_set1_epi16 (1 );
2233
+ __m256i i32 = _mm256_madd_epi16 ( ones , dot );
2214
2234
2215
2235
// Convert int32_t to float
2216
- const __m128 p = _mm_cvtepi32_ps (i32 );
2236
+ __m256 p = _mm256_cvtepi32_ps (i32 );
2217
2237
2218
2238
// Apply the scale, and accumulate
2219
- acc = _mm_fmadd_ps (scale , p , acc );
2239
+ acc = _mm256_fmadd_ps (scale , p , acc );
2220
2240
}
2221
2241
2222
2242
// Return horizontal sum of the acc vector
2223
- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2243
+ __m128 res = _mm256_extractf128_ps (acc , 1 );
2244
+ res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2245
+ res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2224
2246
res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2225
2247
sumf = _mm_cvtss_f32 (res );
2226
2248
#else
0 commit comments