Skip to content

Commit 80047a3

Browse files
committed
More AVX2 optimizations
1 parent 76a744e commit 80047a3

File tree

1 file changed

+78
-73
lines changed

1 file changed

+78
-73
lines changed

ggml.c

Lines changed: 78 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -2207,19 +2207,20 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
22072207
// Initialize accumulator with zeros
22082208
__m256 acc = _mm256_setzero_ps();
22092209

2210-
for (int i = 0; i < nb; i += 2) {
2211-
__m256i bx = bytesFromCrumbs(x[i+1].qs, x[i].qs);
2210+
for (int i = 0; i < nb/2; i++) {
2211+
__m256i bx = bytesFromCrumbs(x[i*2+1].qs, x[i*2].qs);
22122212

22132213
// Compute combined scale for the block
2214-
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
2215-
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
2216-
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2214+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+0].d));
2215+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+1].d));
2216+
__m256 scale = _mm256_set_m128(scale_hi, scale_lo);
2217+
scale = _mm256_mul_ps(scale, _mm256_broadcast_ss(&y[i].d));
22172218

22182219
const __m256i off = _mm256_set1_epi8(2);
22192220
bx = _mm256_sub_epi8(bx, off);
22202221

22212222
// Load y vector
2222-
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
2223+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
22232224

22242225
// Get absolute values of x vectors
22252226
const __m256i ax = _mm256_sign_epi8(bx, bx);
@@ -2272,6 +2273,7 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
22722273
static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
22732274
assert(n % QK3_0 == 0);
22742275
const int nb = n / QK3_0;
2276+
assert(nb % 2 == 0);
22752277

22762278
const block_q3_0 * restrict x = vx;
22772279
const block_q8_0 * restrict y = vy;
@@ -2281,77 +2283,80 @@ static void ggml_vec_dot_q3_0_q8_0(const int n, float * restrict s, const void *
22812283
#if defined(__AVX2__)
22822284
// Initialize accumulator with zeros
22832285
__m128 acc = _mm_setzero_ps();
2284-
for (int i = 0; i < nb; i++) {
2285-
// Compute combined scale for the block
2286-
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
2287-
2288-
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2289-
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2290-
2291-
__m256i bxx = _mm256_set1_epi64x(x[i].qs);
2292-
2293-
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2294-
2295-
// shift the copies to be able to reach all values
2296-
// 255 192 128 64 0
2297-
// | | | |
2298-
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2299-
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2300-
// _______________________sssssfedcba98765432__________________________________________ shift right
2301-
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2302-
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2303-
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2304-
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2305-
2306-
// add to itself in masked places to shift some values left one bit
2307-
// 127 64 0
2308-
// | | | | | | | | | | | | | | | |
2309-
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2310-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2311-
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2312-
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2313-
//
2314-
// 255 192 128
2315-
// | | | | | | | | | | | | | | | |
2316-
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2317-
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2318-
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2319-
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2320-
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2321-
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2322-
2323-
// collect 16 bytes from 256 into 128 bits
2324-
const __m256i shufmask = _mm256_set_epi8(
2325-
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2326-
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2327-
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2286+
for (int i = 0; i < nb/2; i++) {
2287+
const __m128 scale_y = _mm_set1_ps(y[i].d);
2288+
for (int u = 0; u < 2; u++) { // let the compiler unroll this
2289+
// Compute combined scale for the block
2290+
const __m128 scale_x = _mm_set1_ps(GGML_FP16_TO_FP32(x[i*2+u].d));
2291+
const __m128 scale = _mm_mul_ps(scale_x, scale_y);
2292+
2293+
__m256i bxx = _mm256_set1_epi64x(x[i*2+u].qs);
2294+
2295+
// legend: _=zero +=one .=don't care 0-f=3bit quantized values s=fp16 scale
2296+
2297+
// shift the copies to be able to reach all values
2298+
// 255 192 128 64 0
2299+
// | | | |
2300+
// sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210sssssfedcba9876543210 in
2301+
// sssfedcba9876543210_______________________sfedcba9876543210____sssssfedcba9876543210 shift left
2302+
// _______________________sssssfedcba98765432__________________________________________ shift right
2303+
// sssfedcba9876543210____sssssfedcba98765432sfedcba9876543210____sssssfedcba9876543210 out
2304+
// ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^
2305+
// e b 6 3 _ . f a 7 2 c 9 4 1 _ . d 8 5 0
2306+
const __m256i shift_l = _mm256_set_epi64x(2*3, 64, 4*3, 0);
2307+
const __m256i shift_r = _mm256_set_epi64x( 64, 2*3, 64, 64);
2308+
bxx = _mm256_or_si256(_mm256_sllv_epi64(bxx, shift_l), _mm256_srlv_epi64(bxx, shift_r));
2309+
2310+
// add to itself in masked places to shift some values left one bit
2311+
// 127 64 0
2312+
// | | | | | | | | | | | | | | | |
2313+
// ssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222111000 in
2314+
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2315+
// _____________________.999____________________.111____________________________________.ddd____________________.555_______________ masked
2316+
// .............ccc.....999.............444.....111....____________.....................ddd.............888.....555.............000 sum
2317+
//
2318+
// 255 192 128
2319+
// | | | | | | | | | | | | | | | |
2320+
// ssssssssssfffeeedddcccbbbaaa999888777666555444333222111000____________ssssssssssssssssfffeeedddcccbbbaaa999888777666555444333222 in
2321+
// _____________________++++____________________++++____________________________________++++____________________++++_______________ mask
2322+
// _____________________.bbb____________________.333____________________________________.fff____________________.777_______________ masked
2323+
// .............eee.....bbb.............666.....333..........____________...............fff.............aaa.....777.............222 sum
2324+
const __m256i doublemask = _mm256_set1_epi64x(0x078000078000);
2325+
bxx = _mm256_add_epi64(bxx, _mm256_and_si256(doublemask, bxx));
2326+
2327+
// collect 16 bytes from 256 into 128 bits
2328+
const __m256i shufmask = _mm256_set_epi8(
2329+
5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0,-1,-1,
2330+
-1,-1, 5,14,-1,-1,13, 3,-1,-1, 2,11,-1,-1,10, 0);
2331+
bxx = _mm256_shuffle_epi8(bxx, shufmask);
2332+
2333+
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2334+
2335+
const __m128i mask = _mm_set1_epi8(7);
2336+
bx = _mm_and_si128(mask, bx);
2337+
2338+
const __m128i off = _mm_set1_epi8(4);
2339+
bx = _mm_sub_epi8(bx, off);
2340+
2341+
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + u*QK3_0));
23282342

2329-
__m128i bx = _mm_or_si128(_mm256_castsi256_si128(bxx), _mm256_extracti128_si256(bxx, 1));
2330-
2331-
const __m128i mask = _mm_set1_epi8(7);
2332-
bx = _mm_and_si128(mask, bx);
2333-
2334-
const __m128i off = _mm_set1_epi8(4);
2335-
bx = _mm_sub_epi8(bx, off);
2336-
2337-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK3_0));
2338-
2339-
// Get absolute values of x vectors
2340-
const __m128i ax = _mm_sign_epi8(bx, bx);
2341-
// Sign the values of the y vectors
2342-
const __m128i sy = _mm_sign_epi8(by, bx);
2343-
// Perform multiplication and create 16-bit values
2344-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2343+
// Get absolute values of x vectors
2344+
const __m128i ax = _mm_sign_epi8(bx, bx);
2345+
// Sign the values of the y vectors
2346+
const __m128i sy = _mm_sign_epi8(by, bx);
2347+
// Perform multiplication and create 16-bit values
2348+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
23452349

2346-
// Convert int16_t to int32_t by adding pairwise
2347-
const __m128i ones = _mm_set1_epi16(1);
2348-
__m128i i32 = _mm_madd_epi16(dot, ones);
2350+
// Convert int16_t to int32_t by adding pairwise
2351+
const __m128i ones = _mm_set1_epi16(1);
2352+
__m128i i32 = _mm_madd_epi16(dot, ones);
23492353

2350-
// Convert int32_t to float
2351-
const __m128 p = _mm_cvtepi32_ps(i32);
2354+
// Convert int32_t to float
2355+
const __m128 p = _mm_cvtepi32_ps(i32);
23522356

2353-
// Apply the scale, and accumulate
2354-
acc = _mm_fmadd_ps(scale, p, acc);
2357+
// Apply the scale, and accumulate
2358+
acc = _mm_fmadd_ps(scale, p, acc);
2359+
}
23552360
}
23562361

23572362
// Return horizontal sum of the acc vector

0 commit comments

Comments
 (0)