@@ -2207,19 +2207,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2207
2207
// Initialize accumulator with zeros
2208
2208
__m256 acc = _mm256_setzero_ps ();
2209
2209
2210
- for (int i = 0 ; i < nb ; i += 2 ) {
2211
- __m256i bx = bytesFromCrumbs (x [i + 1 ].qs , x [i ].qs );
2210
+ for (int i = 0 ; i < nb / 2 ; i ++ ) {
2211
+ __m256i bx = bytesFromCrumbs (x [i * 2 + 1 ].qs , x [i * 2 ].qs );
2212
2212
2213
2213
// Compute combined scale for the block
2214
- const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 0 ].d ) * y [i /2 ].d );
2215
- const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 1 ].d ) * y [i /2 ].d );
2216
- const __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2214
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2215
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2216
+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2217
+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
2217
2218
2218
2219
const __m256i off = _mm256_set1_epi8 (2 );
2219
2220
bx = _mm256_sub_epi8 (bx , off );
2220
2221
2221
2222
// Load y vector
2222
- const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i / 2 ].qs );
2223
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2223
2224
2224
2225
// Get absolute values of x vectors
2225
2226
const __m256i ax = _mm256_sign_epi8 (bx , bx );
@@ -2272,6 +2273,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2272
2273
static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2273
2274
assert (n % QK3_0 == 0 );
2274
2275
const int nb = n / QK3_0 ;
2276
+ assert (nb % 2 == 0 );
2275
2277
2276
2278
const block_q3_0 * restrict x = vx ;
2277
2279
const block_q8_0 * restrict y = vy ;
@@ -2281,77 +2283,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
2281
2283
#if defined(__AVX2__ )
2282
2284
// Initialize accumulator with zeros
2283
2285
__m128 acc = _mm_setzero_ps ();
2284
- for (int i = 0 ; i < nb ; i ++ ) {
2285
- // Compute combined scale for the block
2286
- const __m128 scale = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i ].d ) * y [i /2 ].d );
2287
-
2288
- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2289
- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2290
-
2291
- __m256i bxx = _mm256_set1_epi64x (x [i ].qs );
2292
-
2293
- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2294
-
2295
- // shift the copies to be able to reach all values
2296
- // 255 192 128 64 0
2297
- // | | | |
2298
- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2299
- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2300
- // _______________________sssssfedcba98765432__________________________________________ shift right
2301
- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2302
- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2303
- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2304
- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2305
-
2306
- // add to itself in masked places to shift some values left one bit
2307
- // 127 64 0
2308
- // | | | | | | | | | | | | | | | |
2309
- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2310
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2311
- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2312
- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2313
- //
2314
- // 255 192 128
2315
- // | | | | | | | | | | | | | | | |
2316
- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2317
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2318
- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2319
- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2320
- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2321
- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2322
-
2323
- // collect 16 bytes from 256 into 128 bits
2324
- const __m256i shufmask = _mm256_set_epi8 (
2325
- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2326
- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2327
- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2286
+ for (int i = 0 ; i < nb /2 ; i ++ ) {
2287
+ const __m128 scale_y = _mm_set1_ps (y [i ].d );
2288
+ for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2289
+ // Compute combined scale for the block
2290
+ const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2291
+ const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2292
+
2293
+ __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2294
+
2295
+ // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2296
+
2297
+ // shift the copies to be able to reach all values
2298
+ // 255 192 128 64 0
2299
+ // | | | |
2300
+ // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2301
+ // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2302
+ // _______________________sssssfedcba98765432__________________________________________ shift right
2303
+ // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2304
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2305
+ // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2306
+ const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2307
+ const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2308
+ bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2309
+
2310
+ // add to itself in masked places to shift some values left one bit
2311
+ // 127 64 0
2312
+ // | | | | | | | | | | | | | | | |
2313
+ // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2314
+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2315
+ // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2316
+ // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2317
+ //
2318
+ // 255 192 128
2319
+ // | | | | | | | | | | | | | | | |
2320
+ // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2321
+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2322
+ // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2323
+ // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2324
+ const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2325
+ bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2326
+
2327
+ // collect 16 bytes from 256 into 128 bits
2328
+ const __m256i shufmask = _mm256_set_epi8 (
2329
+ 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2330
+ -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2331
+ bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2332
+
2333
+ __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2334
+
2335
+ const __m128i mask = _mm_set1_epi8 (7 );
2336
+ bx = _mm_and_si128 (mask , bx );
2337
+
2338
+ const __m128i off = _mm_set1_epi8 (4 );
2339
+ bx = _mm_sub_epi8 (bx , off );
2340
+
2341
+ const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2328
2342
2329
- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2330
-
2331
- const __m128i mask = _mm_set1_epi8 (7 );
2332
- bx = _mm_and_si128 (mask , bx );
2333
-
2334
- const __m128i off = _mm_set1_epi8 (4 );
2335
- bx = _mm_sub_epi8 (bx , off );
2336
-
2337
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i /2 ].qs + (i %2 )* QK3_0 ));
2338
-
2339
- // Get absolute values of x vectors
2340
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2341
- // Sign the values of the y vectors
2342
- const __m128i sy = _mm_sign_epi8 (by , bx );
2343
- // Perform multiplication and create 16-bit values
2344
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2343
+ // Get absolute values of x vectors
2344
+ const __m128i ax = _mm_sign_epi8 (bx , bx );
2345
+ // Sign the values of the y vectors
2346
+ const __m128i sy = _mm_sign_epi8 (by , bx );
2347
+ // Perform multiplication and create 16-bit values
2348
+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2345
2349
2346
- // Convert int16_t to int32_t by adding pairwise
2347
- const __m128i ones = _mm_set1_epi16 (1 );
2348
- __m128i i32 = _mm_madd_epi16 (dot , ones );
2350
+ // Convert int16_t to int32_t by adding pairwise
2351
+ const __m128i ones = _mm_set1_epi16 (1 );
2352
+ __m128i i32 = _mm_madd_epi16 (dot , ones );
2349
2353
2350
- // Convert int32_t to float
2351
- const __m128 p = _mm_cvtepi32_ps (i32 );
2354
+ // Convert int32_t to float
2355
+ const __m128 p = _mm_cvtepi32_ps (i32 );
2352
2356
2353
- // Apply the scale, and accumulate
2354
- acc = _mm_fmadd_ps (scale , p , acc );
2357
+ // Apply the scale, and accumulate
2358
+ acc = _mm_fmadd_ps (scale , p , acc );
2359
+ }
2355
2360
}
2356
2361
2357
2362
// Return horizontal sum of the acc vector
0 commit comments