@@ -662,12 +662,12 @@ typedef struct {
662
662
static_assert (sizeof (block_q2_0 ) == sizeof (ggml_fp16_t ) + QK2_0 / 4 , "wrong q2_0 size/padding" );
663
663
664
664
#define QK3_0 16
665
- typedef union {
666
- struct {
667
- uint16_t pad [ 3 ];
668
- ggml_fp16_t d ;
669
- };
670
- uint64_t qs ;
665
+ typedef struct {
666
+ ggml_fp16_t d ;
667
+ // Instead of representing q3_0 as a packed format "...210210210210",
668
+ // represent it as two planes: "...10101010" and "...2222"
669
+ uint16_t qhi ; // The highest bit of each 3-bit number, packed together
670
+ uint32_t qlo ; // The low 2-bits of each 3-bit number, packed together
671
671
} block_q3_0 ;
672
672
static_assert (sizeof (block_q3_0 ) == sizeof (ggml_fp16_t ) + QK3_0 * 3 / 8 , "wrong q3_0 size/padding" );
673
673
@@ -762,17 +762,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
762
762
const float d = max / -4 ;
763
763
const float id = d ? 1.0f /d : 0.0f ;
764
764
765
- uint64_t qs = 0 ;
765
+ uint32_t lo = 0 ;
766
+ uint16_t hi = 0 ;
766
767
767
768
for (int l = 0 ; l < QK3_0 ; l ++ ) {
768
769
const float v = x [i * QK3_0 + l ]* id ;
769
770
const uint8_t vi = MIN (7 , (int8_t )roundf (v ) + 4 );
770
771
assert (vi < 8 );
771
- qs |= (uint64_t )vi << (l * 3 );
772
+ lo |= (vi & 3 ) << (l * 2 );
773
+ hi |= ((vi >> 2 ) & 1 ) << l ;
772
774
}
773
775
774
- y [i ].qs = qs ;
775
- y [i ].d = GGML_FP32_TO_FP16 (d ); // overwrite unused part of uint64_t qs
776
+ y [i ].d = GGML_FP32_TO_FP16 (d );
777
+ y [i ].qlo = lo ;
778
+ y [i ].qhi = hi ;
776
779
}
777
780
}
778
781
@@ -1573,13 +1576,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in
1573
1576
1574
1577
for (int i = 0 ; i < nb ; i ++ ) {
1575
1578
const float d = GGML_FP16_TO_FP32 (x [i ].d );
1576
- uint64_t qs = x [i ].qs ;
1579
+ uint_fast32_t lo = x [i ].qlo ;
1580
+ uint_fast32_t hi = x [i ].qhi << 2 ;
1577
1581
for (int l = 0 ; l < QK3_0 ; l ++ ) {
1578
- const int8_t vi = qs & 7 ;
1582
+ const int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
1579
1583
const float v = (vi - 4 )* d ;
1580
1584
y [i * QK3_0 + l ] = v ;
1581
1585
assert (!isnan (y [i * QK3_0 + l ]));
1582
- qs >>= 3 ;
1586
+ lo >>= 2 ;
1587
+ hi >>= 1 ;
1583
1588
}
1584
1589
}
1585
1590
}
@@ -2525,6 +2530,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
2525
2530
* s = sumf ;
2526
2531
}
2527
2532
2533
#if defined(__AVX2__) || defined(__AVX512F__)
// Computes the dot product of signed 8-bit integers packed into 256-bit vectors:
// every group of four adjacent byte products is summed into one 32-bit lane,
// and the eight sums are returned as 32-bit floats packed into a 256-bit vector.
static inline __m256 dotMul(__m256i bx, __m256i by) {
#if defined(__AVXVNNIINT8__)
    // The accumulator is the FIRST operand of this intrinsic (dst = src + a*b),
    // so start from zero to get just the summed products.
    const __m256i i32 = _mm256_dpbssd_epi32(_mm256_setzero_si256(), bx, by);
#else
    // _mm256_maddubs_epi16 treats its first operand as unsigned, so move the
    // sign of bx over to by: |bx| * (by * sign(bx)) == bx * by.
    const __m256i ax = _mm256_sign_epi8(bx, bx);
    const __m256i sy = _mm256_sign_epi8(by, bx);
    // Multiply bytes and add adjacent pairs, creating 16-bit values
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);

    // Convert int16_t to int32_t by adding adjacent pairs
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i i32 = _mm256_madd_epi16(ones, dot);
#endif
    // Convert int32_t to float
    return _mm256_cvtepi32_ps(i32);
}

// Returns the horizontal sum of the eight 32-bit floats packed into acc.
static inline float horizontalSum(__m256 acc) {
    // Fold the upper 128-bit half onto the lower half: 8 -> 4 partial sums
    __m128 res = _mm256_extractf128_ps(acc, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
    // 4 -> 2 partial sums
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    // 2 -> 1: the total ends up in the lowest lane
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}
#endif
2565
+
2528
2566
static void ggml_vec_dot_q2_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2529
2567
assert (n % QK2_0 == 0 );
2530
2568
const int nb = n / QK2_0 ;
@@ -2554,30 +2592,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2554
2592
// Load y vector
2555
2593
const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2556
2594
2557
- // Get absolute values of x vectors
2558
- const __m256i ax = _mm256_sign_epi8 (bx , bx );
2559
- // Sign the values of the y vectors
2560
- const __m256i sy = _mm256_sign_epi8 (by , bx );
2561
- // Perform multiplication and create 16-bit values
2562
- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
2563
-
2564
- // Convert int16_t to int32_t by adding pairwise
2565
- const __m256i ones = _mm256_set1_epi16 (1 );
2566
- __m256i i32 = _mm256_madd_epi16 (ones , dot );
2567
-
2568
- // Convert int32_t to float
2569
- __m256 p = _mm256_cvtepi32_ps (i32 );
2595
+ // Do the product:
2596
+ __m256 p = dotMul (bx , by );
2570
2597
2571
2598
// Apply the scale, and accumulate
2572
2599
acc = _mm256_fmadd_ps (scale , p , acc );
2573
2600
}
2574
2601
2575
2602
// Return horizontal sum of the acc vector
2576
- __m128 res = _mm256_extractf128_ps (acc , 1 );
2577
- res = _mm_add_ps (res , _mm256_castps256_ps128 (acc ));
2578
- res = _mm_add_ps (res , _mm_movehl_ps (res , res ));
2579
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2580
- sumf = _mm_cvtss_f32 (res );
2603
+ sumf = horizontalSum (acc );
2581
2604
#else
2582
2605
for (int i = 0 ; i < nb ; i ++ ) {
2583
2606
const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
@@ -2602,6 +2625,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
2602
2625
* s = sumf ;
2603
2626
}
2604
2627
2628
// Lookup table used to convert q3_0 to SIMD vectors.
// Expands the bits of an 8-bit index into a 64-bit result, turning each bit into a byte;
// bit 0 of the index lands in the least significant byte.
// A zero bit turns into 0xFC, while a one bit turns into 0x00.
//
// Each GGML_Q3_Bn level appends one more byte to the hex literal being built and
// emits the 0xFC variant before the 0x00 variant, so the 256 entries come out in
// index order. The helper macros are only needed to build the table and are
// undefined afterwards so that they do not leak into the rest of the file.
#define GGML_Q3_B0(n) 0x ## n
#define GGML_Q3_B1(n) GGML_Q3_B0(n ## FC), GGML_Q3_B0(n ## 00)
#define GGML_Q3_B2(n) GGML_Q3_B1(n ## FC), GGML_Q3_B1(n ## 00)
#define GGML_Q3_B3(n) GGML_Q3_B2(n ## FC), GGML_Q3_B2(n ## 00)
#define GGML_Q3_B4(n) GGML_Q3_B3(n ## FC), GGML_Q3_B3(n ## 00)
#define GGML_Q3_B5(n) GGML_Q3_B4(n ## FC), GGML_Q3_B4(n ## 00)
#define GGML_Q3_B6(n) GGML_Q3_B5(n ## FC), GGML_Q3_B5(n ## 00)
#define GGML_Q3_B7(n) GGML_Q3_B6(n ## FC), GGML_Q3_B6(n ## 00)
#define GGML_Q3_B8( ) GGML_Q3_B7(     FC), GGML_Q3_B7(     00)
static const uint64_t ggml_q3_table[256] = { GGML_Q3_B8() };
#undef GGML_Q3_B8
#undef GGML_Q3_B7
#undef GGML_Q3_B6
#undef GGML_Q3_B5
#undef GGML_Q3_B4
#undef GGML_Q3_B3
#undef GGML_Q3_B2
#undef GGML_Q3_B1
#undef GGML_Q3_B0
2641
+
2605
2642
static void ggml_vec_dot_q3_0_q8_0 (const int n , float * restrict s , const void * restrict vx , const void * restrict vy ) {
2606
2643
assert (n % QK3_0 == 0 );
2607
2644
const int nb = n / QK3_0 ;
@@ -2614,103 +2651,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
2614
2651
2615
2652
#if defined(__AVX2__ )
2616
2653
// Initialize accumulator with zeros
2617
- __m128 acc = _mm_setzero_ps ();
2654
+ __m256 acc = _mm256_setzero_ps ();
2655
+
2618
2656
for (int i = 0 ; i < nb /2 ; i ++ ) {
2619
- const __m128 scale_y = _mm_set1_ps (y [i ].d );
2620
- for (int u = 0 ; u < 2 ; u ++ ) { // let the compiler unroll this
2621
- // Compute combined scale for the block
2622
- const __m128 scale_x = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + u ].d ));
2623
- const __m128 scale = _mm_mul_ps (scale_x , scale_y );
2624
-
2625
- __m256i bxx = _mm256_set1_epi64x (x [i * 2 + u ].qs );
2626
-
2627
- // legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2628
-
2629
- // shift the copies to be able to reach all values
2630
- // 255 192 128 64 0
2631
- // | | | |
2632
- // sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2633
- // sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2634
- // _______________________sssssfedcba98765432__________________________________________ shift right
2635
- // sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2636
- // ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2637
- // e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2638
- const __m256i shift_l = _mm256_set_epi64x (2 * 3 , 64 , 4 * 3 , 0 );
2639
- const __m256i shift_r = _mm256_set_epi64x ( 64 , 2 * 3 , 64 , 64 );
2640
- bxx = _mm256_or_si256 (_mm256_sllv_epi64 (bxx , shift_l ), _mm256_srlv_epi64 (bxx , shift_r ));
2641
-
2642
- // add to itself in masked places to shift some values left one bit
2643
- // 127 64 0
2644
- // | | | | | | | | | | | | | | | |
2645
- // ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2646
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2647
- // _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2648
- // .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2649
- //
2650
- // 255 192 128
2651
- // | | | | | | | | | | | | | | | |
2652
- // ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2653
- // _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2654
- // _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2655
- // .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2656
- const __m256i doublemask = _mm256_set1_epi64x (0x078000078000 );
2657
- bxx = _mm256_add_epi64 (bxx , _mm256_and_si256 (doublemask , bxx ));
2658
-
2659
- // collect 16 bytes from 256 into 128 bits
2660
- const __m256i shufmask = _mm256_set_epi8 (
2661
- 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 ,-1 ,-1 ,
2662
- -1 ,-1 , 5 ,14 ,-1 ,-1 ,13 , 3 ,-1 ,-1 , 2 ,11 ,-1 ,-1 ,10 , 0 );
2663
- bxx = _mm256_shuffle_epi8 (bxx , shufmask );
2664
-
2665
- __m128i bx = _mm_or_si128 (_mm256_castsi256_si128 (bxx ), _mm256_extracti128_si256 (bxx , 1 ));
2666
-
2667
- const __m128i mask = _mm_set1_epi8 (7 );
2668
- bx = _mm_and_si128 (mask , bx );
2669
-
2670
- const __m128i off = _mm_set1_epi8 (4 );
2671
- bx = _mm_sub_epi8 (bx , off );
2672
-
2673
- const __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + u * QK3_0 ));
2657
+ __m256i bx = bytes_from_crumbs (x [i * 2 + 1 ].qlo , x [i * 2 ].qlo );
2674
2658
2675
- // Get absolute values of x vectors
2676
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2677
- // Sign the values of the y vectors
2678
- const __m128i sy = _mm_sign_epi8 (by , bx );
2679
- // Perform multiplication and create 16-bit values
2680
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2659
+ __m256i const bxhi = _mm256_set_epi64x (
2660
+ ggml_q3_table [x [i * 2 + 1 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 1 ].qhi & 0xFF ],
2661
+ ggml_q3_table [x [i * 2 + 0 ].qhi >> 8 ], ggml_q3_table [x [i * 2 + 0 ].qhi & 0xFF ]);
2681
2662
2682
- // Convert int16_t to int32_t by adding pairwise
2683
- const __m128i ones = _mm_set1_epi16 (1 );
2684
- __m128i i32 = _mm_madd_epi16 (dot , ones );
2663
+ // OR the high bits (which also handles the sign):
2664
+ bx = _mm256_or_si256 (bx , bxhi );
2685
2665
2686
- // Convert int32_t to float
2687
- const __m128 p = _mm_cvtepi32_ps (i32 );
2666
+ // Compute combined scale for the block
2667
+ const __m128 scale_lo = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 0 ].d ));
2668
+ const __m128 scale_hi = _mm_set1_ps (GGML_FP16_TO_FP32 (x [i * 2 + 1 ].d ));
2669
+ __m256 scale = _mm256_set_m128 (scale_hi , scale_lo );
2670
+ scale = _mm256_mul_ps (scale , _mm256_broadcast_ss (& y [i ].d ));
2688
2671
2689
- // Apply the scale, and accumulate
2690
- acc = _mm_fmadd_ps (scale , p , acc );
2691
- }
2672
+ // Load y vector
2673
+ const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2674
+
2675
+ // Do the product,
2676
+ __m256 p = dotMul (bx , by );
2677
+
2678
+ // Apply the scale, and accumulate
2679
+ acc = _mm256_fmadd_ps (scale , p , acc );
2692
2680
}
2693
2681
2694
2682
// Return horizontal sum of the acc vector
2695
- __m128 res = _mm_add_ps (acc , _mm_movehl_ps (acc , acc ));
2696
- res = _mm_add_ss (res , _mm_movehdup_ps (res ));
2697
- sumf = _mm_cvtss_f32 (res );
2683
+ sumf = horizontalSum (acc );
2698
2684
#else
2699
2685
for (int i = 0 ; i < nb ; i ++ ) {
2700
2686
const float d0 = GGML_FP16_TO_FP32 (x [i ].d );
2701
2687
const float d1 = y [i /2 ].d ;
2702
2688
2703
- uint64_t qs0 = x [i ].qs ;
2689
+ uint_fast32_t lo0 = x [i ].qlo ;
2690
+ uint_fast32_t hi0 = x [i ].qhi << 2 ;
2704
2691
const int8_t * restrict p1 = y [i /2 ].qs + (i %2 )* QK3_0 ;
2705
2692
2706
2693
int sumi = 0 ;
2707
- for (int j = 0 ; j < QK3_0 ; j ++ ) {
2708
- const int8_t i0 = (int8_t )(qs0 & 7 ) - 4 ;
2709
- const int_fast16_t i1 = p1 [j ];
2694
+ for (int l = 0 ; l < QK3_0 ; l ++ ) {
2695
+ const int8_t i0 = (int8_t )(( lo0 & 3 ) | (( hi0 & 4 ) - 4 )) ;
2696
+ const int_fast16_t i1 = p1 [l ];
2710
2697
2711
2698
sumi += i0 * i1 ;
2712
2699
2713
- qs0 >>= 3 ;
2700
+ lo0 >>= 2 ;
2701
+ hi0 >>= 1 ;
2714
2702
}
2715
2703
sumf += d0 * d1 * sumi ;
2716
2704
}
@@ -12497,11 +12485,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
12497
12485
quantize_row_q3_0 (src + j , y , k );
12498
12486
12499
12487
for (int i = 0 ; i < nb ; i ++ ) {
12500
- uint64_t qs = y [i ].qs ;
12488
+ uint_fast32_t lo = y [i ].qlo ;
12489
+ uint_fast32_t hi = y [i ].qhi << 2 ;
12501
12490
for (int l = 0 ; l < QK3_0 ; l ++ ) {
12502
- const int8_t vi = qs & 7 ;
12491
+ int8_t vi = ( lo & 3 ) | ( hi & 4 ) ;
12503
12492
hist [vi ]++ ;
12504
- qs >>= 3 ;
12493
+ lo >>= 2 ;
12494
+ hi >>= 1 ;
12505
12495
}
12506
12496
}
12507
12497
}
0 commit comments