Skip to content

Commit 7aa501c

Browse files
pubbysw
authored and committed
Faster q3_0 implementation, using two planes, by @pubby
1 parent 8c90a86 commit 7aa501c

File tree

1 file changed

+102
-112
lines changed

1 file changed

+102
-112
lines changed

ggml.c

Lines changed: 102 additions & 112 deletions
Original file line number · Diff line number · Diff line change
@@ -662,12 +662,12 @@ typedef struct {
662662
static_assert(sizeof(block_q2_0) == sizeof(ggml_fp16_t) + QK2_0 / 4, "wrong q2_0 size/padding");
663663

664664
#define QK3_0 16
665-
typedef union {
666-
struct {
667-
uint16_t pad[3];
668-
ggml_fp16_t d;
669-
};
670-
uint64_t qs;
665+
typedef struct {
666+
ggml_fp16_t d;
667+
// Instead of representing q3_0 as a packed format "...210210210210",
668+
// represent it as two planes: "...10101010" and "...2222"
669+
uint16_t qhi; // The highest bit of each 3-bit number, packed together
670+
uint32_t qlo; // The low 2-bits of each 3-bit number, packed together
671671
} block_q3_0;
672672
static_assert(sizeof(block_q3_0) == sizeof(ggml_fp16_t) + QK3_0 * 3 / 8, "wrong q3_0 size/padding");
673673

@@ -762,17 +762,20 @@ static void quantize_row_q3_0(const float * restrict x, block_q3_0 * restrict y,
762762
const float d = max / -4;
763763
const float id = d ? 1.0f/d : 0.0f;
764764

765-
uint64_t qs = 0;
765+
uint32_t lo = 0;
766+
uint16_t hi = 0;
766767

767768
for (int l = 0; l < QK3_0; l++) {
768769
const float v = x[i*QK3_0 + l]*id;
769770
const uint8_t vi = MIN(7, (int8_t)roundf(v) + 4);
770771
assert(vi < 8);
771-
qs |= (uint64_t)vi << (l*3);
772+
lo |= (vi & 3) << (l * 2);
773+
hi |= ((vi >> 2) & 1) << l;
772774
}
773775

774-
y[i].qs = qs;
775-
y[i].d = GGML_FP32_TO_FP16(d); // overwrite unused part of uint64_t qs
776+
y[i].d = GGML_FP32_TO_FP16(d);
777+
y[i].qlo = lo;
778+
y[i].qhi = hi;
776779
}
777780
}
778781

@@ -1573,13 +1576,15 @@ static void dequantize_row_q3_0(const void * restrict vx, float * restrict y, in
15731576

15741577
for (int i = 0; i < nb; i++) {
15751578
const float d = GGML_FP16_TO_FP32(x[i].d);
1576-
uint64_t qs = x[i].qs;
1579+
uint_fast32_t lo = x[i].qlo;
1580+
uint_fast32_t hi = x[i].qhi << 2;
15771581
for (int l = 0; l < QK3_0; l++) {
1578-
const int8_t vi = qs & 7;
1582+
const int8_t vi = (lo & 3) | (hi & 4);
15791583
const float v = (vi - 4)*d;
15801584
y[i*QK3_0 + l] = v;
15811585
assert(!isnan(y[i*QK3_0 + l]));
1582-
qs >>= 3;
1586+
lo >>= 2;
1587+
hi >>= 1;
15831588
}
15841589
}
15851590
}
@@ -2525,6 +2530,39 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
25252530
*s = sumf;
25262531
}
25272532

2533+
#if __AVX2__ || __AVX512F__
2534+
// Computes the dot product of signed 8-bit integers packed into 256-bit vectors,
2535+
// converting the result to 32-bit floats packed into a 256-bit vector.
2536+
static inline __m256 dotMul(__m256i bx, __m256i by) {
2537+
# if __AVXVNNIINT8__
2538+
// Perform multiplication and sum to 32-bit values
2539+
const __m256i i32 = _mm256_dpbssd_epi32(bx, by, _mm256_setzero_si256());
2540+
# else
2541+
// Get absolute values of x vectors
2542+
const __m256i ax = _mm256_sign_epi8(bx, bx);
2543+
// Sign the values of the y vectors
2544+
const __m256i sy = _mm256_sign_epi8(by, bx);
2545+
// Perform multiplication and create 16-bit values
2546+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2547+
2548+
// Convert int16_t to int32_t by adding pairwise
2549+
const __m256i ones = _mm256_set1_epi16(1);
2550+
const __m256i i32 = _mm256_madd_epi16(ones, dot);
2551+
# endif
2552+
// Convert int32_t to float
2553+
return _mm256_cvtepi32_ps(i32);
2554+
}
2555+
2556+
// Return horizontal sum of 32-bit floats packed into a 256-bit vector.
2557+
static inline float horizontalSum(__m256 acc) {
2558+
__m128 res = _mm256_extractf128_ps(acc, 1);
2559+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2560+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2561+
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2562+
return _mm_cvtss_f32(res);
2563+
}
2564+
#endif
2565+
25282566
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
25292567
assert(n % QK2_0 == 0);
25302568
const int nb = n / QK2_0;
@@ -2554,30 +2592,15 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
25542592
// Load y vector
25552593
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
25562594

2557-
// Get absolute values of x vectors
2558-
const __m256i ax = _mm256_sign_epi8(bx, bx);
2559-
// Sign the values of the y vectors
2560-
const __m256i sy = _mm256_sign_epi8(by, bx);
2561-
// Perform multiplication and create 16-bit values
2562-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2563-
2564-
// Convert int16_t to int32_t by adding pairwise
2565-
const __m256i ones = _mm256_set1_epi16(1);
2566-
__m256i i32 = _mm256_madd_epi16(ones, dot);
2567-
2568-
// Convert int32_t to float
2569-
__m256 p = _mm256_cvtepi32_ps(i32);
2595+
// Do the product:
2596+
__m256 p = dotMul(bx, by);
25702597

25712598
// Apply the scale, and accumulate
25722599
acc = _mm256_fmadd_ps(scale, p, acc);
25732600
}
25742601

25752602
// Return horizontal sum of the acc vector
2576-
__m128 res = _mm256_extractf128_ps(acc, 1);
2577-
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2578-
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
2579-
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2580-
sumf = _mm_cvtss_f32(res);
2603+
sumf = horizontalSum(acc);
25812604
#else
25822605
for (int i = 0; i < nb; i++) {
25832606
const float d0 = GGML_FP16_TO_FP32(x[i].d);
@@ -2602,6 +2625,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
26022625
*s = sumf;
26032626
}
26042627

2628+
// Lookup table used to convert q3_0 to SIMD vectors.
2629+
// Expands the bits of an 8-bit value into a 64 bit result, turning each bit into a byte.
2630+
// A zero bit turns into 0xFC, while a one bit turns into 0x00.
2631+
#define B0(n) 0x ## n
2632+
#define B1(n) B0(n ## FC), B0(n ## 00)
2633+
#define B2(n) B1(n ## FC), B1(n ## 00)
2634+
#define B3(n) B2(n ## FC), B2(n ## 00)
2635+
#define B4(n) B3(n ## FC), B3(n ## 00)
2636+
#define B5(n) B4(n ## FC), B4(n ## 00)
2637+
#define B6(n) B5(n ## FC), B5(n ## 00)
2638+
#define B7(n) B6(n ## FC), B6(n ## 00)
2639+
#define B8( ) B7( FC), B7( 00)
2640+
static const uint64_t ggml_q3_table[256] = { B8() };
2641+
26052642
static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
26062643
assert(n % QK3_0 == 0);
26072644
const int nb = n / QK3_0;
@@ -2614,103 +2651,54 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
26142651

26152652
#if defined(__AVX2__)
26162653
// Initialize accumulator with zeros
2617-
__m128 acc = _mm_setzero_ps();
2654+
__m256 acc = _mm256_setzero_ps();
2655+
26182656
for (int i = 0; i < nb/2; i++) {
2619-
const __m128 scale_y = _mm_set1_ps(y[i].d);
2620-
for (int u = 0; u < 2; u++) { // let the compiler unroll this
2621-
// Compute combined scale for the block
2622-
const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d));
2623-
const __m128 scale = _mm_mul_ps(scale_x, scale_y);
2624-
2625-
__m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs);
2626-
2627-
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2628-
2629-
// shift the copies to be able to reach all values
2630-
// 255 192 128 64 0
2631-
// | | | |
2632-
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2633-
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2634-
// _______________________sssssfedcba98765432__________________________________________ shift right
2635-
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2636-
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2637-
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2638-
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2639-
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2640-
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2641-
2642-
// add to itself in masked places to shift some values left one bit
2643-
// 127 64 0
2644-
// | | | | | | | | | | | | | | | |
2645-
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2646-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2647-
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2648-
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2649-
//
2650-
// 255 192 128
2651-
// | | | | | | | | | | | | | | | |
2652-
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2653-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2654-
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2655-
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2656-
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2657-
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2658-
2659-
// collect 16 bytes from 256 into 128 bits
2660-
const __m256i shufmask = _mm256_set_epi8(
2661-
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2662-
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2663-
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2664-
2665-
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2666-
2667-
const __m128i mask = _mm_set1_epi8(7);
2668-
bx = _mm_and_si128(mask, bx);
2669-
2670-
const __m128i off = _mm_set1_epi8(4);
2671-
bx = _mm_sub_epi8(bx, off);
2672-
2673-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0));
2657+
__m256i bx = bytes_from_crumbs(x[i*2+1].qlo, x[i*2].qlo);
26742658

2675-
// Get absolute values of x vectors
2676-
const __m128i ax = _mm_sign_epi8(bx, bx);
2677-
// Sign the values of the y vectors
2678-
const __m128i sy = _mm_sign_epi8(by, bx);
2679-
// Perform multiplication and create 16-bit values
2680-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2659+
__m256i const bxhi = _mm256_set_epi64x(
2660+
ggml_q3_table[x[i*2+1].qhi >> 8], ggml_q3_table[x[i*2+1].qhi & 0xFF],
2661+
ggml_q3_table[x[i*2+0].qhi >> 8], ggml_q3_table[x[i*2+0].qhi & 0xFF]);
26812662

2682-
// Convert int16_t to int32_t by adding pairwise
2683-
const __m128i ones = _mm_set1_epi16(1);
2684-
__m128i i32 = _mm_madd_epi16(dot, ones);
2663+
// OR the high bits (which also handles the sign):
2664+
bx = _mm256_or_si256(bx, bxhi);
26852665

2686-
// Convert int32_t to float
2687-
const __m128 p = _mm_cvtepi32_ps(i32);
2666+
// Compute combined scale for the block
2667+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d));
2668+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d));
2669+
__m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2670+
scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d));
26882671

2689-
// Apply the scale, and accumulate
2690-
acc = _mm_fmadd_ps(scale, p, acc);
2691-
}
2672+
// Load y vector
2673+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2674+
2675+
// Do the product,
2676+
__m256 p = dotMul(bx, by);
2677+
2678+
// Apply the scale, and accumulate
2679+
acc = _mm256_fmadd_ps(scale, p, acc);
26922680
}
26932681

26942682
// Return horizontal sum of the acc vector
2695-
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
2696-
res = _mm_add_ss(res, _mm_movehdup_ps(res));
2697-
sumf = _mm_cvtss_f32(res);
2683+
sumf = horizontalSum(acc);
26982684
#else
26992685
for (int i = 0; i < nb; i++) {
27002686
const float d0 = GGML_FP16_TO_FP32(x[i].d);
27012687
const float d1 = y[i/2].d;
27022688

2703-
uint64_t qs0 = x[i].qs;
2689+
uint_fast32_t lo0 = x[i].qlo;
2690+
uint_fast32_t hi0 = x[i].qhi << 2;
27042691
const int8_t * restrict p1 = y[i/2].qs + (i%2)*QK3_0;
27052692

27062693
int sumi = 0;
2707-
for (int j = 0; j < QK3_0; j++) {
2708-
const int8_t i0 = (int8_t)(qs0 & 7) - 4;
2709-
const int_fast16_t i1 = p1[j];
2694+
for (int l = 0; l < QK3_0; l++) {
2695+
const int8_t i0 = (int8_t)((lo0 & 3) | ((hi0 & 4) - 4));
2696+
const int_fast16_t i1 = p1[l];
27102697

27112698
sumi += i0 * i1;
27122699

2713-
qs0 >>= 3;
2700+
lo0 >>= 2;
2701+
hi0 >>= 1;
27142702
}
27152703
sumf += d0 * d1 * sumi;
27162704
}
@@ -12497,11 +12485,13 @@ size_t ggml_quantize_q3_0(const float * src, void * dst, int n, int k, int64_t h
1249712485
quantize_row_q3_0(src + j, y, k);
1249812486

1249912487
for (int i = 0; i < nb; i++) {
12500-
uint64_t qs = y[i].qs;
12488+
uint_fast32_t lo = y[i].qlo;
12489+
uint_fast32_t hi = y[i].qhi << 2;
1250112490
for (int l = 0; l < QK3_0; l++) {
12502-
const int8_t vi = qs & 7;
12491+
int8_t vi = (lo & 3) | (hi & 4);
1250312492
hist[vi]++;
12504-
qs >>= 3;
12493+
lo >>= 2;
12494+
hi >>= 1;
1250512495
}
1250612496
}
1250712497
}

0 commit comments

Comments (0)