@@ -606,12 +606,12 @@ typedef struct {
606
606
static_assert (sizeof (block_q2_0 ) == sizeof (ggml_fp16_t ) + QK2_0 / 4 , "wrong q2_0 size/padding" );
607
607
608
608
#define QK3_0 16
// 3-bit quantization block: QK3_0 values share one fp16 scale.
// Total size: 2 (d) + 2 (qhi) + 4 (qlo) = 8 bytes, no padding.
typedef struct {
    ggml_fp16_t d;  // fp16 scale for the block
    // Instead of representing q3_0 as a packed format "...210210210210",
    // represent it as two planes: "...10101010" and "...2222"
    uint16_t qhi;   // The highest bit of each 3-bit number, packed together (bit l = quant l)
    uint32_t qlo;   // The low 2 bits of each 3-bit number, packed together (bits 2l..2l+1 = quant l)
} block_q3_0;
static_assert(sizeof(block_q3_0) == sizeof(ggml_fp16_t) + QK3_0 * 3 / 8, "wrong q3_0 size/padding");
617
617
@@ -691,17 +691,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
691
691
const float d = max / -4 ;
692
692
const float id = d ? 1.0f /d : 0.0f ;
693
693
694
- uint64_t qs = 0 ;
694
+ uint32_t lo = 0 ;
695
+ uint16_t hi = 0 ;
695
696
696
697
for (int l = 0 ; l < QK3_0 ; l ++ ) {
697
698
const float v = x [i * QK3_0 + l ]* id ;
698
699
const uint8_t vi = MIN (7 , (int8_t )roundf (v ) + 4 );
699
700
assert (vi < 8 );
700
- qs |= (uint64_t )vi << (l * 3 );
701
+ lo |= (vi & 3 ) << (l * 2 );
702
+ hi |= ((vi >> 2 ) & 1 ) << l ;
701
703
}
702
704
703
- y [i ].qs = qs ;
704
- y [i ].d = GGML_FP32_TO_FP16 (d ); // overwrite unused part of uint64_t qs
705
+ y [i ].d = GGML_FP32_TO_FP16 (d );
706
+ y [i ].qlo = lo ;
707
+ y [i ].qhi = hi ;
705
708
}
706
709
}
707
710
@@ -1335,13 +1338,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in
1335
1338
1336
1339
for (int i = 0 ; i < nb ; i ++ ) {
1337
1340
const float d = GGML_FP16_TO_FP32 (x [i ].d );
1338
- uint64_t qs = x [i ].qs ;
1341
+ uint_fast32_t lo = x [i ].qlo ;
1342
+ uint_fast32_t hi = x [i ].qhi << 2 ;
1339
1343
for (int l = 0 ; l < QK3_0 ; l ++ ) {
1340
- const int8_t vi = qs & 7 ;
1344
+ const int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
1341
1345
const float v = (vi - 4 )* d ;
1342
1346
y [i * QK3_0 + l ] = v ;
1343
1347
assert (!isnan (y [i * QK3_0 + l ]));
1344
- qs >>= 3 ;
1348
+ lo >>= 2 ;
1349
+ hi >>= 1 ;
1345
1350
}
1346
1351
}
1347
1352
}
@@ -2391,6 +2396,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2391
2396
* s = sumf ;
2392
2397
}
2393
2398
2399
#if defined(__AVX2__) || defined(__AVX512F__)
// Computes the dot product of signed 8-bit integers packed into 256-bit vectors,
// converting the result to 32-bit floats packed into a 256-bit vector.
static inline __m256 dotMul(__m256i bx, __m256i by) {
#if defined(__AVXVNNIINT8__)
    // Perform signed-by-signed multiplication and sum to 32-bit values.
    // NOTE: the accumulator (src) is the FIRST argument of VPDPBSSD;
    // passing (bx, by, zero) would multiply by zero and just return bx.
    const __m256i i32 = _mm256_dpbssd_epi32(_mm256_setzero_si256(), bx, by);
#else
    // _mm256_maddubs_epi16 wants unsigned * signed, so move the sign of x onto y:
    // Get absolute values of x vectors
    const __m256i ax = _mm256_sign_epi8(bx, bx);
    // Sign the values of the y vectors
    const __m256i sy = _mm256_sign_epi8(by, bx);
    // Perform multiplication and create 16-bit values
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);

    // Convert int16_t to int32_t by adding pairwise
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i i32  = _mm256_madd_epi16(ones, dot);
#endif
    // Convert int32_t to float
    return _mm256_cvtepi32_ps(i32);
}

// Return horizontal sum of 32-bit floats packed into a 256-bit vector.
static inline float horizontalSum(__m256 acc) {
    // Fold 256 -> 128 bits, then 128 -> 64 -> 32.
    __m128 res = _mm256_extractf128_ps(acc, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}
#endif
2431
+
2394
2432
static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2395
2433
assert (n % QK2_0 == 0 );
2396
2434
const int nb = n / QK2_0 ;
@@ -2420,30 +2458,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2420
2458
// Load y vector
2421
2459
const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2422
2460
2423
- // Get absolute values of x vectors
2424
- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2425
- // Sign the values of the y vectors
2426
- const __m256i sy = _mm256_sign_epi8 (by , bx );
2427
- // Perform multiplication and create 16-bit values
2428
- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2429
-
2430
- // Convert int16_t to int32_t by adding pairwise
2431
- const __m256i ones = _mm256_set1_epi16 (1 );
2432
- __m256i i32 = _mm256_madd_epi16 (ones , dot );
2433
-
2434
- // Convert int32_t to float
2435
- __m256 p = _mm256_cvtepi32_ps (i32 );
2461
+ // Do the product:
2462
+ __m256 p = dotMul (bx , by );
2436
2463
2437
2464
// Apply the scale, and accumulate
2438
2465
acc = _mm256_fmadd_ps (scale , p , acc );
2439
2466
}
2440
2467
2441
2468
// Return horizontal sum of the acc vector
2442
- __m128 res = _mm256_extractf128_ps (acc , 1 );
2443
- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2444
- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2445
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2446
- sumf = _mm_cvtss_f32 (res );
2469
+ sumf = horizontalSum (acc );
2447
2470
#else
2448
2471
for (int i = 0 ; i < nb ; i ++ ) {
2449
2472
const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
@@ -2468,6 +2491,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2468
2491
* s = sumf ;
2469
2492
}
2470
2493
2494
// Lookup table used to convert q3_0 to SIMD vectors.
// Expands the bits of an 8-bit value into a 64-bit result, turning each bit into a byte:
// bit l of the index becomes byte l of the table entry.
// A zero bit turns into 0xFC (-4 as a signed byte), while a one bit turns into 0x00,
// so OR-ing an entry with the 2 low bits of each quant yields the signed value (quant - 4).
#define B0(n) 0x##n
#define B1(n) B0(n##FC), B0(n##00)
#define B2(n) B1(n##FC), B1(n##00)
#define B3(n) B2(n##FC), B2(n##00)
#define B4(n) B3(n##FC), B3(n##00)
#define B5(n) B4(n##FC), B4(n##00)
#define B6(n) B5(n##FC), B5(n##00)
#define B7(n) B6(n##FC), B6(n##00)
#define B8( )  B7(  FC), B7(  00)
static const uint64_t ggml_q3_table[256] = { B8() };
// The B* helpers exist only to build the table above; don't leak such
// generic names into the rest of the translation unit.
#undef B0
#undef B1
#undef B2
#undef B3
#undef B4
#undef B5
#undef B6
#undef B7
#undef B8
2507
+
2471
2508
static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2472
2509
assert (n % QK3_0 == 0 );
2473
2510
const int nb = n / QK3_0 ;
@@ -2480,103 +2517,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
2480
2517
2481
2518
#if defined(__AVX2__ )
2482
2519
// Initialize accumulator with zeros
2483
- __m128 acc = _mm_setzero_ps ();
2520
+ __m256 acc = _mm256_setzero_ps ();
2521
+
2484
2522
for (int i = 0 ; i < nb /2 ; i ++ ) {
2485
- const __m128 scale_y = _mm_set1_ps (y [i ].d );
2486
- for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2487
- // Compute combined scale for the block
2488
- const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2489
- const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2490
-
2491
- __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2492
-
2493
- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2494
-
2495
- // shift the copies to be able to reach all values
2496
- // 255 192 128 64 0
2497
- // | | | |
2498
- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2499
- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2500
- // _______________________sssssfedcba98765432__________________________________________ shift right
2501
- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2502
- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2503
- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2504
- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2505
- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2506
- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2507
-
2508
- // add to itself in masked places to shift some values left one bit
2509
- // 127 64 0
2510
- // | | | | | | | | | | | | | | | |
2511
- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2512
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2513
- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2514
- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2515
- //
2516
- // 255 192 128
2517
- // | | | | | | | | | | | | | | | |
2518
- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2519
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2520
- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2521
- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2522
- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2523
- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2524
-
2525
- // collect 16 bytes from 256 into 128 bits
2526
- const __m256i shufmask = _mm256_set_epi8 (
2527
- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2528
- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2529
- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2530
-
2531
- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2532
-
2533
- const __m128i mask = _mm_set1_epi8 (7 );
2534
- bx = _mm_and_si128 (mask , bx );
2535
-
2536
- const __m128i off = _mm_set1_epi8 (4 );
2537
- bx = _mm_sub_epi8 (bx , off );
2538
-
2539
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2523
+ __m256i bx = bytesFromCrumbs (x [i * 2 + 1 ].qlo , x [i * 2 ].qlo );
2540
2524
2541
- // Get absolute values of x vectors
2542
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2543
- // Sign the values of the y vectors
2544
- const __m128i sy = _mm_sign_epi8 (by , bx );
2545
- // Perform multiplication and create 16-bit values
2546
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2525
+ __m256i const bxhi = _mm256_set_epi64x (
2526
+ ggml_q3_table [x [i * 2 + 1 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 1 ].qhi & 0xFF ],
2527
+ ggml_q3_table [x [i * 2 + 0 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 0 ].qhi & 0xFF ]);
2547
2528
2548
- // Convert int16_t to int32_t by adding pairwise
2549
- const __m128i ones = _mm_set1_epi16 (1 );
2550
- __m128i i32 = _mm_madd_epi16 (dot , ones );
2529
+ // OR the high bits (which also handles the sign):
2530
+ bx = _mm256_or_si256 (bx , bxhi );
2531
+
2532
+ // Compute combined scale for the block
2533
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2534
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2535
+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2536
+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
2551
2537
2552
- // Convert int32_t to float
2553
- const __m128 p = _mm_cvtepi32_ps ( i32 );
2538
+ // Load y vector
2539
+ const __m256i by = _mm256_loadu_si256 (( const __m256i * ) y [ i ]. qs );
2554
2540
2555
- // Apply the scale, and accumulate
2556
- acc = _mm_fmadd_ps (scale , p , acc );
2557
- }
2541
+ // Do the product,
2542
+ __m256 p = dotMul (bx , by );
2543
+
2544
+ // Apply the scale, and accumulate
2545
+ acc = _mm256_fmadd_ps (scale , p , acc );
2558
2546
}
2559
2547
2560
2548
// Return horizontal sum of the acc vector
2561
- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2562
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2563
- sumf = _mm_cvtss_f32 (res );
2549
+ sumf = horizontalSum (acc );
2564
2550
#else
2565
2551
for (int i = 0 ; i < nb ; i ++ ) {
2566
2552
const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
2567
2553
const float d1 = y [i /2 ].d ;
2568
2554
2569
- uint64_t qs0 = x [i ].qs ;
2555
+ uint_fast32_t lo0 = x [i ].qlo ;
2556
+ uint_fast32_t hi0 = x [i ].qhi << 2 ;
2570
2557
const int8_t * restrict p1 = y [i /2 ].qs + (i %2 )* QK3_0 ;
2571
2558
2572
2559
int sumi = 0 ;
2573
- for (int j = 0 ; j < QK3_0 ; j ++ ) {
2574
- const int8_t i0 = (int8_t )(qs0 & 7 ) - 4 ;
2575
- const int_fast16_t i1 = p1 [j ];
2560
+ for (int l = 0 ; l < QK3_0 ; l ++ ) {
2561
+ const int8_t i0 = (int8_t )(( lo0 & 3 ) | (( hi0 & 4 ) - 4 )) ;
2562
+ const int_fast16_t i1 = p1 [l ];
2576
2563
2577
2564
sumi += i0 * i1 ;
2578
2565
2579
- qs0 >>= 3 ;
2566
+ lo0 >>= 2 ;
2567
+ hi0 >>= 1 ;
2580
2568
}
2581
2569
sumf += d0 * d1 * sumi ;
2582
2570
}
@@ -12064,11 +12052,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
12064
12052
quantize_row_q3_0 (src + j , y , k );
12065
12053
12066
12054
for (int i = 0 ; i < nb ; i ++ ) {
12067
- uint64_t qs = y [i ].qs ;
12055
+ uint_fast32_t lo = y [i ].qlo ;
12056
+ uint_fast32_t hi = y [i ].qhi << 2 ;
12068
12057
for (int l = 0 ; l < QK3_0 ; l ++ ) {
12069
- const int8_t vi = qs & 7 ;
12058
+ int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
12070
12059
hist [vi ]++ ;
12071
- qs >>= 3 ;
12060
+ lo >>= 2 ;
12061
+ hi >>= 1 ;
12072
12062
}
12073
12063
}
12074
12064
}
0 commit comments