Skip to content

Commit 04a6b36

Browse files
authored
ggml : AVX2 implementation of ggml_vec_dot_q4_1_q8_0 (#1051)
1 parent ed24225 commit 04a6b36

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed

ggml.c

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2518,6 +2518,62 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
25182518
}
25192519

25202520
sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
2521+
#elif defined(__AVX2__)
2522+
// Initialize accumulator with zeros
2523+
__m256 acc = _mm256_setzero_ps();
2524+
2525+
// Main loop
2526+
for (int i = 0; i < nb; ++i) {
2527+
const float * d0 = &x[i].d;
2528+
const float * d1 = &y[i].d;
2529+
const float * m0 = &x[i].m;
2530+
2531+
const __m256 d0v = _mm256_broadcast_ss( d0 );
2532+
const __m256 d1v = _mm256_broadcast_ss( d1 );
2533+
const __m256 m0v = _mm256_broadcast_ss( m0 );
2534+
2535+
// Compute combined scales
2536+
const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
2537+
const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
2538+
2539+
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
2540+
const __m256i bx = bytesFromNibbles( x[i].qs );
2541+
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
2542+
2543+
// Get absolute values of x vectors
2544+
const __m256i ax = _mm256_sign_epi8( bx, bx );
2545+
2546+
// Sign the values of the y vectors
2547+
const __m256i sy = _mm256_sign_epi8( by, bx );
2548+
2549+
// Perform multiplication and create 16-bit values
2550+
const __m256i dot = _mm256_maddubs_epi16( ax, sy );
2551+
const __m256i ones = _mm256_set1_epi16( 1 );
2552+
const __m256i xy_q = _mm256_madd_epi16( ones, dot );
2553+
2554+
// Convert the vector of 8 int32_t values to 8 floats
2555+
const __m256 xy = _mm256_cvtepi32_ps( xy_q );
2556+
2557+
// Accumulate d0*d1*x*y
2558+
acc = _mm256_fmadd_ps( d0d1, xy, acc );
2559+
2560+
// Compute sum of y values
2561+
const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
2562+
const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
2563+
const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
2564+
const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
2565+
2566+
// Accumulate d1*m0*y
2567+
acc = _mm256_fmadd_ps( d1m0, ysum, acc );
2568+
}
2569+
2570+
// Return horizontal sum of the acc vector
2571+
__m128 res = _mm256_extractf128_ps( acc, 1 );
2572+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2573+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2574+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2575+
2576+
sumf = _mm_cvtss_f32( res );
25212577
#else
25222578
// scalar
25232579
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)