Skip to content

Commit e2f8c1e

Browse files
committed
More AVX2 optimizations
1 parent 440b8fb commit e2f8c1e

File tree

1 file changed

+78
-73
lines changed

1 file changed

+78
-73
lines changed

ggml.c

Lines changed: 78 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2405,19 +2405,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
24052405
// Initialize accumulator with zeros
24062406
__m256 acc = _mm256_setzero_ps();
24072407

2408-
for (int i = 0; i < nb; i += 2) {
2409-
__m256i bx = bytesFromCrumbs(x[i+1].qs, x[i].qs);
2408+
for (int i = 0; i < nb/2; i++) {
2409+
__m256i bx = bytesFromCrumbs(x[i*2+1].qs, x[i*2].qs);
24102410

24112411
// Compute combined scale for the block
2412-
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
2413-
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
2414-
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2412+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d));
2413+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d));
2414+
__m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2415+
scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d));
24152416

24162417
const __m256i off = _mm256_set1_epi8(2);
24172418
bx = _mm256_sub_epi8(bx, off);
24182419

24192420
// Load y vector
2420-
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
2421+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
24212422

24222423
// Get absolute values of x vectors
24232424
const __m256i ax = _mm256_sign_epi8(bx, bx);
@@ -2470,6 +2471,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
24702471
static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
24712472
assert(n % QK3_0 == 0);
24722473
const int nb = n / QK3_0;
2474+
assert(nb % 2 == 0);
24732475

24742476
const block_q3_0 * restrict x = vx;
24752477
const block_q8_0 * restrict y = vy;
@@ -2479,77 +2481,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
24792481
#if defined(__AVX2__)
24802482
// Initialize accumulator with zeros
24812483
__m128 acc = _mm_setzero_ps();
2482-
for (int i = 0; i < nb; i++) {
2483-
// Compute combined scale for the block
2484-
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
2485-
2486-
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2487-
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2488-
2489-
__m256i bxx = _mm256_set1_epi64x(x[i].qs);
2490-
2491-
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2492-
2493-
// shift the copies to be able to reach all values
2494-
// 255 192 128 64 0
2495-
// | | | |
2496-
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2497-
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2498-
// _______________________sssssfedcba98765432__________________________________________ shift right
2499-
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2500-
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2501-
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2502-
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2503-
2504-
// add to itself in masked places to shift some values left one bit
2505-
// 127 64 0
2506-
// | | | | | | | | | | | | | | | |
2507-
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2508-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2509-
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2510-
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2511-
//
2512-
// 255 192 128
2513-
// | | | | | | | | | | | | | | | |
2514-
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2515-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2516-
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2517-
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2518-
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2519-
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2520-
2521-
// collect 16 bytes from 256 into 128 bits
2522-
const __m256i shufmask = _mm256_set_epi8(
2523-
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2524-
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2525-
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2484+
for (int i = 0; i < nb/2; i++) {
2485+
const __m128 scale_y = _mm_set1_ps(y[i].d);
2486+
for (int u = 0; u < 2; u++) { // let the compiler unroll this
2487+
// Compute combined scale for the block
2488+
const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d));
2489+
const __m128 scale = _mm_mul_ps(scale_x, scale_y);
2490+
2491+
__m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs);
2492+
2493+
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2494+
2495+
// shift the copies to be able to reach all values
2496+
// 255 192 128 64 0
2497+
// | | | |
2498+
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2499+
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2500+
// _______________________sssssfedcba98765432__________________________________________ shift right
2501+
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2502+
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2503+
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2504+
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2505+
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2506+
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2507+
2508+
// add to itself in masked places to shift some values left one bit
2509+
// 127 64 0
2510+
// | | | | | | | | | | | | | | | |
2511+
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2512+
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2513+
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2514+
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2515+
//
2516+
// 255 192 128
2517+
// | | | | | | | | | | | | | | | |
2518+
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2519+
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2520+
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2521+
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2522+
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2523+
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2524+
2525+
// collect 16 bytes from 256 into 128 bits
2526+
const __m256i shufmask = _mm256_set_epi8(
2527+
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2528+
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2529+
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2530+
2531+
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2532+
2533+
const __m128i mask = _mm_set1_epi8(7);
2534+
bx = _mm_and_si128(mask, bx);
2535+
2536+
const __m128i off = _mm_set1_epi8(4);
2537+
bx = _mm_sub_epi8(bx, off);
2538+
2539+
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0));
25262540

2527-
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2528-
2529-
const __m128i mask = _mm_set1_epi8(7);
2530-
bx = _mm_and_si128(mask, bx);
2531-
2532-
const __m128i off = _mm_set1_epi8(4);
2533-
bx = _mm_sub_epi8(bx, off);
2534-
2535-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK3_0));
2536-
2537-
// Get absolute values of x vectors
2538-
const __m128i ax = _mm_sign_epi8(bx, bx);
2539-
// Sign the values of the y vectors
2540-
const __m128i sy = _mm_sign_epi8(by, bx);
2541-
// Perform multiplication and create 16-bit values
2542-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2541+
// Get absolute values of x vectors
2542+
const __m128i ax = _mm_sign_epi8(bx, bx);
2543+
// Sign the values of the y vectors
2544+
const __m128i sy = _mm_sign_epi8(by, bx);
2545+
// Perform multiplication and create 16-bit values
2546+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
25432547

2544-
// Convert int16_t to int32_t by adding pairwise
2545-
const __m128i ones = _mm_set1_epi16(1);
2546-
__m128i i32 = _mm_madd_epi16(dot, ones);
2548+
// Convert int16_t to int32_t by adding pairwise
2549+
const __m128i ones = _mm_set1_epi16(1);
2550+
__m128i i32 = _mm_madd_epi16(dot, ones);
25472551

2548-
// Convert int32_t to float
2549-
const __m128 p = _mm_cvtepi32_ps(i32);
2552+
// Convert int32_t to float
2553+
const __m128 p = _mm_cvtepi32_ps(i32);
25502554

2551-
// Apply the scale, and accumulate
2552-
acc = _mm_fmadd_ps(scale, p, acc);
2555+
// Apply the scale, and accumulate
2556+
acc = _mm_fmadd_ps(scale, p, acc);
2557+
}
25532558
}
25542559

25552560
// Return horizontal sum of the acc vector

0 commit comments

Comments
 (0)