@@ -2518,6 +2518,62 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void *
2518
2518
}
2519
2519
2520
2520
sumf = vaddvq_f32 (sumv0 ) + vaddvq_f32 (sumv1 );
2521
+ #elif defined(__AVX2__ )
2522
+ // Initialize accumulator with zeros
2523
+ __m256 acc = _mm256_setzero_ps ();
2524
+
2525
+ // Main loop
2526
+ for (int i = 0 ; i < nb ; ++ i ) {
2527
+ const float * d0 = & x [i ].d ;
2528
+ const float * d1 = & y [i ].d ;
2529
+ const float * m0 = & x [i ].m ;
2530
+
2531
+ const __m256 d0v = _mm256_broadcast_ss ( d0 );
2532
+ const __m256 d1v = _mm256_broadcast_ss ( d1 );
2533
+ const __m256 m0v = _mm256_broadcast_ss ( m0 );
2534
+
2535
+ // Compute combined scales
2536
+ const __m256 d0d1 = _mm256_mul_ps ( d0v , d1v );
2537
+ const __m256 d1m0 = _mm256_mul_ps ( d1v , m0v );
2538
+
2539
+ // Load 16 bytes and unpack the 4-bit fields into bytes, making 32 bytes
2540
+ const __m256i bx = bytesFromNibbles ( x [i ].qs );
2541
+ const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
2542
+
2543
+ // Get absolute values of x vectors
2544
+ const __m256i ax = _mm256_sign_epi8 ( bx , bx );
2545
+
2546
+ // Sign the values of the y vectors
2547
+ const __m256i sy = _mm256_sign_epi8 ( by , bx );
2548
+
2549
+ // Perform multiplication and create 16-bit values
2550
+ const __m256i dot = _mm256_maddubs_epi16 ( ax , sy );
2551
+ const __m256i ones = _mm256_set1_epi16 ( 1 );
2552
+ const __m256i xy_q = _mm256_madd_epi16 ( ones , dot );
2553
+
2554
+ // Convert the vector of 8 int32_t values to 8 floats
2555
+ const __m256 xy = _mm256_cvtepi32_ps ( xy_q );
2556
+
2557
+ // Accumulate d0*d1*x*y
2558
+ acc = _mm256_fmadd_ps ( d0d1 , xy , acc );
2559
+
2560
+ // Compute sum of y values
2561
+ const __m256i y16_l = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
2562
+ const __m256i y16_h = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
2563
+ const __m256i ysumi = _mm256_madd_epi16 ( _mm256_add_epi16 (y16_l , y16_h ), ones );
2564
+ const __m256 ysum = _mm256_cvtepi32_ps ( ysumi );
2565
+
2566
+ // Accumulate d1*m0*y
2567
+ acc = _mm256_fmadd_ps ( d1m0 , ysum , acc );
2568
+ }
2569
+
2570
+ // Return horizontal sum of the acc vector
2571
+ __m128 res = _mm256_extractf128_ps ( acc , 1 );
2572
+ res = _mm_add_ps ( res , _mm256_castps256_ps128 ( acc ) );
2573
+ res = _mm_add_ps ( res , _mm_movehl_ps ( res , res ) );
2574
+ res = _mm_add_ss ( res , _mm_movehdup_ps ( res ) );
2575
+
2576
+ sumf = _mm_cvtss_f32 ( res );
2521
2577
#else
2522
2578
// scalar
2523
2579
for (int i = 0 ; i < nb ; i ++ ) {
0 commit comments