Skip to content

Commit 76a744e

Browse files
committed
Q2 AVX2: do two blocks at a time, by @slaren
1 parent 15aee10 commit 76a744e

File tree

1 file changed

+50
-28
lines changed

1 file changed

+50
-28
lines changed

ggml.c

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,35 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
427427
// quantization
428428
//
429429

430-
// AVX routines provided by GH user Const-me
431-
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
432430
#if __AVX2__ || __AVX512F__
431+
// Unpack 32 2-bit fields into 32 bytes
432+
// The output vector contains 32 bytes, each one in [ 0 .. 3 ] interval
433+
static inline __m256i bytesFromCrumbs(uint32_t packed_hi, uint32_t packed_lo) {
434+
__m128i bx_hi = _mm_set1_epi32(packed_hi);
435+
__m128i bx_lo = _mm_set1_epi32(packed_lo);
436+
__m256i bx = _mm256_set_m128i(bx_hi, bx_lo);
437+
438+
// shift counts to get all bit pairs in lowest position of each byte
439+
const __m256i shift256 = _mm256_set_epi32(6, 4, 2, 0,
440+
6, 4, 2, 0);
441+
bx = _mm256_srlv_epi32(bx, shift256);
442+
443+
const __m256i shufmask = _mm256_set_epi8(15,11, 7, 3,
444+
14,10, 6, 2,
445+
13, 9, 5, 1,
446+
12, 8, 4, 0,
447+
15,11, 7, 3,
448+
14,10, 6, 2,
449+
13, 9, 5, 1,
450+
12, 8, 4, 0);
451+
bx = _mm256_shuffle_epi8(bx, shufmask);
452+
453+
const __m256i mask = _mm256_set1_epi8(3);
454+
bx = _mm256_and_si256(mask, bx);
455+
456+
return bx;
457+
}
458+
433459
// Unpack 32 4-bit fields into 32 bytes
434460
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
435461
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -2170,6 +2196,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
21702196
static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
21712197
assert(n % QK2_0 == 0);
21722198
const int nb = n / QK2_0;
2199+
assert(nb % 2 == 0);
21732200

21742201
const block_q2_0 * restrict x = vx;
21752202
const block_q8_0 * restrict y = vy;
@@ -2178,49 +2205,44 @@ static void ggml_vec_dot_q2_0_q8_0(const int n, float * restrict s, const void *
21782205

21792206
#if defined(__AVX2__)
21802207
// Initialize accumulator with zeros
2181-
__m128 acc = _mm_setzero_ps();
2182-
2183-
for (int i = 0; i < nb; i++) {
2184-
// Compute combined scale for the block
2185-
const __m128 scale = _mm_set1_ps(GGML_FP16_TO_FP32(x[i].d) * y[i/2].d);
2186-
2187-
__m128i bx = _mm_set1_epi32(x[i].qs);
2208+
__m256 acc = _mm256_setzero_ps();
21882209

2189-
// shift counts to get all bit pairs in lowest position of each byte
2190-
const __m128i shift128 = _mm_set_epi32(6, 4, 2, 0);
2191-
bx = _mm_srlv_epi32(bx, shift128);
2210+
for (int i = 0; i < nb; i += 2) {
2211+
__m256i bx = bytesFromCrumbs(x[i+1].qs, x[i].qs);
21922212

2193-
const __m128i shufmask = _mm_set_epi8(15,11,7,3,14,10,6,2,13,9,5,1,12,8,4,0);
2194-
bx = _mm_shuffle_epi8(bx, shufmask);
2213+
// Compute combined scale for the block
2214+
const __m128 scale_lo = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+0].d) * y[i/2].d);
2215+
const __m128 scale_hi = _mm_set1_ps(GGML_FP16_TO_FP32(x[i+1].d) * y[i/2].d);
2216+
const __m256 scale = _mm256_set_m128(scale_hi, scale_lo);
21952217

2196-
const __m128i mask = _mm_set1_epi8(3);
2197-
bx = _mm_and_si128(mask, bx);
2218+
const __m256i off = _mm256_set1_epi8(2);
2219+
bx = _mm256_sub_epi8(bx, off);
21982220

2199-
const __m128i off = _mm_set1_epi8(2);
2200-
bx = _mm_sub_epi8(bx, off);
2201-
2202-
const __m128i by = _mm_loadu_si128((const __m128i *)(y[i/2].qs + (i%2)*QK2_0));
2221+
// Load y vector
2222+
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i/2].qs);
22032223

22042224
// Get absolute values of x vectors
2205-
const __m128i ax = _mm_sign_epi8(bx, bx);
2225+
const __m256i ax = _mm256_sign_epi8(bx, bx);
22062226
// Sign the values of the y vectors
2207-
const __m128i sy = _mm_sign_epi8(by, bx);
2227+
const __m256i sy = _mm256_sign_epi8(by, bx);
22082228
// Perform multiplication and create 16-bit values
2209-
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2229+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
22102230

22112231
// Convert int16_t to int32_t by adding pairwise
2212-
const __m128i ones = _mm_set1_epi16(1);
2213-
__m128i i32 = _mm_madd_epi16(dot, ones);
2232+
const __m256i ones = _mm256_set1_epi16(1);
2233+
__m256i i32 = _mm256_madd_epi16(ones, dot);
22142234

22152235
// Convert int32_t to float
2216-
const __m128 p = _mm_cvtepi32_ps(i32);
2236+
__m256 p = _mm256_cvtepi32_ps(i32);
22172237

22182238
// Apply the scale, and accumulate
2219-
acc = _mm_fmadd_ps(scale, p, acc);
2239+
acc = _mm256_fmadd_ps(scale, p, acc);
22202240
}
22212241

22222242
// Return horizontal sum of the acc vector
2223-
__m128 res = _mm_add_ps(acc, _mm_movehl_ps(acc, acc));
2243+
__m128 res = _mm256_extractf128_ps(acc, 1);
2244+
res = _mm_add_ps(res, _mm256_castps256_ps128(acc));
2245+
res = _mm_add_ps(res, _mm_movehl_ps(res, res));
22242246
res = _mm_add_ss(res, _mm_movehdup_ps(res));
22252247
sumf = _mm_cvtss_f32(res);
22262248
#else

0 commit comments

Comments
 (0)