@@ -2405,19 +2405,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2405
2405
// Initialize accumulator with zeros
2406
2406
__m256 acc = _mm256_setzero_ps ();
2407
2407
2408
- for (int i = 0 ; i < nb ; i += 2 ) {
2409
- __m256i bx = bytesFromCrumbs (x [i + 1 ].qs , x [i ].qs );
2408
+ for (int i = 0 ; i < nb / 2 ; i ++ ) {
2409
+ __m256i bx = bytesFromCrumbs (x [i * 2 + 1 ].qs , x [i * 2 ].qs );
2410
2410
2411
2411
// Compute combined scale for the block
2412
- const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 0 ].d ) * y [i /2 ].d );
2413
- const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i + 1 ].d ) * y [i /2 ].d );
2414
- const __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2412
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2413
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2414
+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2415
+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
2415
2416
2416
2417
const __m256i off = _mm256_set1_epi8 (2 );
2417
2418
bx = _mm256_sub_epi8 (bx , off );
2418
2419
2419
2420
// Load y vector
2420
- const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i / 2 ].qs );
2421
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2421
2422
2422
2423
// Get absolute values of x vectors
2423
2424
const __m256i ax = _mm256_sign_epi8 (bx , bx );
@@ -2470,6 +2471,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2470
2471
static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2471
2472
assert (n % QK3_0 == 0 );
2472
2473
const int nb = n / QK3_0 ;
2474
+ assert (nb % 2 == 0 );
2473
2475
2474
2476
const block_q3_0 * restrict x = vx ;
2475
2477
const block_q8_0 * restrict y = vy ;
@@ -2479,77 +2481,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
2479
2481
#if defined(__AVX2__ )
2480
2482
// Initialize accumulator with zeros
2481
2483
__m128 acc = _mm_setzero_ps ();
2482
- for (int i = 0 ; i < nb ; i ++ ) {
2483
- // Compute combined scale for the block
2484
- const __m128 scale = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i ].d ) * y [i /2 ].d );
2485
-
2486
- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2487
- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2488
-
2489
- __m256i bxx = _mm256_set1_epi64x (x [i ].qs );
2490
-
2491
- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2492
-
2493
- // shift the copies to be able to reach all values
2494
- // 255 192 128 64 0
2495
- // | | | |
2496
- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2497
- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2498
- // _______________________sssssfedcba98765432__________________________________________ shift right
2499
- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2500
- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2501
- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2502
- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2503
-
2504
- // add to itself in masked places to shift some values left one bit
2505
- // 127 64 0
2506
- // | | | | | | | | | | | | | | | |
2507
- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2508
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2509
- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2510
- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2511
- //
2512
- // 255 192 128
2513
- // | | | | | | | | | | | | | | | |
2514
- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2515
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2516
- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2517
- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2518
- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2519
- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2520
-
2521
- // collect 16 bytes from 256 into 128 bits
2522
- const __m256i shufmask = _mm256_set_epi8 (
2523
- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2524
- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2525
- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2484
+ for (int i = 0 ; i < nb /2 ; i ++ ) {
2485
+ const __m128 scale_y = _mm_set1_ps (y [i ].d );
2486
+ for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2487
+ // Compute combined scale for the block
2488
+ const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2489
+ const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2490
+
2491
+ __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2492
+
2493
+ // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2494
+
2495
+ // shift the copies to be able to reach all values
2496
+ // 255 192 128 64 0
2497
+ // | | | |
2498
+ // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2499
+ // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2500
+ // _______________________sssssfedcba98765432__________________________________________ shift right
2501
+ // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2502
+ // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2503
+ // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2504
+ const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2505
+ const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2506
+ bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2507
+
2508
+ // add to itself in masked places to shift some values left one bit
2509
+ // 127 64 0
2510
+ // | | | | | | | | | | | | | | | |
2511
+ // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2512
+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2513
+ // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2514
+ // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2515
+ //
2516
+ // 255 192 128
2517
+ // | | | | | | | | | | | | | | | |
2518
+ // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2519
+ // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2520
+ // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2521
+ // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2522
+ const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2523
+ bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2524
+
2525
+ // collect 16 bytes from 256 into 128 bits
2526
+ const __m256i shufmask = _mm256_set_epi8 (
2527
+ 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2528
+ -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2529
+ bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2530
+
2531
+ __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2532
+
2533
+ const __m128i mask = _mm_set1_epi8 (7 );
2534
+ bx = _mm_and_si128 (mask , bx );
2535
+
2536
+ const __m128i off = _mm_set1_epi8 (4 );
2537
+ bx = _mm_sub_epi8 (bx , off );
2538
+
2539
+ const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2526
2540
2527
- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2528
-
2529
- const __m128i mask = _mm_set1_epi8 (7 );
2530
- bx = _mm_and_si128 (mask , bx );
2531
-
2532
- const __m128i off = _mm_set1_epi8 (4 );
2533
- bx = _mm_sub_epi8 (bx , off );
2534
-
2535
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i /2 ].qs + (i %2 )* QK3_0 ));
2536
-
2537
- // Get absolute values of x vectors
2538
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2539
- // Sign the values of the y vectors
2540
- const __m128i sy = _mm_sign_epi8 (by , bx );
2541
- // Perform multiplication and create 16-bit values
2542
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2541
+ // Get absolute values of x vectors
2542
+ const __m128i ax = _mm_sign_epi8 (bx , bx );
2543
+ // Sign the values of the y vectors
2544
+ const __m128i sy = _mm_sign_epi8 (by , bx );
2545
+ // Perform multiplication and create 16-bit values
2546
+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2543
2547
2544
- // Convert int16_t to int32_t by adding pairwise
2545
- const __m128i ones = _mm_set1_epi16 (1 );
2546
- __m128i i32 = _mm_madd_epi16 (dot , ones );
2548
+ // Convert int16_t to int32_t by adding pairwise
2549
+ const __m128i ones = _mm_set1_epi16 (1 );
2550
+ __m128i i32 = _mm_madd_epi16 (dot , ones );
2547
2551
2548
- // Convert int32_t to float
2549
- const __m128 p = _mm_cvtepi32_ps (i32 );
2552
+ // Convert int32_t to float
2553
+ const __m128 p = _mm_cvtepi32_ps (i32 );
2550
2554
2551
- // Apply the scale, and accumulate
2552
- acc = _mm_fmadd_ps (scale , p , acc );
2555
+ // Apply the scale, and accumulate
2556
+ acc = _mm_fmadd_ps (scale , p , acc );
2557
+ }
2553
2558
}
2554
2559
2555
2560
// Return horizontal sum of the acc vector
0 commit comments