@@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);

// AVX routines provided by GH user Const-me
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
-#if __AVX2__
+#if __AVX2__ || __AVX512F__
// Unpack 32 4-bit fields into 32 bytes
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
}
#endif

-
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
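For reference, here is a scalar sketch of the q4_0 block layout these routines operate on, as described by the comment above: one float delta followed by QK/2 bytes, each packing two 4-bit values. The helper name and the low-nibble-first ordering are assumptions for illustration (they mirror what bytesFromNibbles is expected to produce) and are not part of the patch.

// Illustrative scalar reference, not part of the patch: dequantize one q4_0
// block laid out as [ float delta | QK/2 bytes of packed nibbles ].
// Assumes the low nibble of each byte holds the even element and the high
// nibble the odd element, each offset by 8 into [ -8 .. +7 ].
static inline void dequantize_block_q4_0_sketch(const uint8_t * restrict block, float * restrict out) {
    const float d = *(const float *) block;        // per-block scale (delta)
    const uint8_t * qs = block + sizeof(float);    // QK/2 packed nibbles

    for (int j = 0; j < QK/2; ++j) {
        const int x0 = (qs[j] & 0x0F) - 8;         // low nibble
        const int x1 = (qs[j] >>   4) - 8;         // high nibble
        out[2*j + 0] = x0*d;
        out[2*j + 1] = x1*d;
    }
}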
@@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
    *s = sumf;
}

+#if __AVX512F__ && QK == 32
+static inline __m512 dot_q4_0_oneblock_avx512(
+    __m512 acc,
+    const uint8_t * pd0,
+    const uint8_t * pd1,
+    const uint8_t * pb0,
+    const uint8_t * pb1,
+    size_t bs,
+    int i
+) {
+    const float * d0_0 = (const float *) (pd0 + i*bs);
+    const float * d1_0 = (const float *) (pd1 + i*bs);
+
+    const uint8_t * restrict p0 = pb0 + (i+0)*bs;
+    const uint8_t * restrict p1 = pb1 + (i+0)*bs;
+
+    // Compute combined scale for the block
+    float scaleScalar = d0_0[0] * d1_0[0];
+    __m512 scale = _mm512_set1_ps( scaleScalar );
+
+    __m256i bx = bytesFromNibbles( p0 );
+    __m256i by = bytesFromNibbles( p1 );
+
+    // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
+    const __m256i off = _mm256_set1_epi8( 8 );
+    bx = _mm256_sub_epi8( bx, off );
+    by = _mm256_sub_epi8( by, off );
+
+    // Sign-extend 32 signed bytes into int16_t
+    __m512i x32 = _mm512_cvtepi8_epi16( bx );
+    __m512i y32 = _mm512_cvtepi8_epi16( by );
+    // Compute products of int16_t integers, add pairwise
+    __m512i i64 = _mm512_madd_epi16( x32, y32 );
+
+    // Convert int32_t to float
+    __m512 p = _mm512_cvtepi32_ps( i64 );
+    // Apply the scale, and accumulate
+    return _mm512_fmadd_ps( scale, p, acc );
+}
+#endif
+
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
    ggml_float sumf = 0.0;

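To make the SIMD steps above easier to follow, here is a scalar sketch of the per-block contribution that dot_q4_0_oneblock_avx512 accumulates, under the same assumptions (QK == 32, pd0/pd1 pointing at the block deltas and pb0/pb1 at the packed nibbles, both with stride bs). The helper below is hypothetical and only for illustration.

// Illustrative scalar equivalent, not part of the patch: the product of the
// two block scales times the integer dot product of the offset nibbles.
static inline float dot_q4_0_oneblock_sketch(
    const uint8_t * pd0, const uint8_t * pd1,
    const uint8_t * pb0, const uint8_t * pb1,
    size_t bs, int i)
{
    const float d0 = *(const float *)(pd0 + i*bs);
    const float d1 = *(const float *)(pd1 + i*bs);

    const uint8_t * p0 = pb0 + i*bs;
    const uint8_t * p1 = pb1 + i*bs;

    int isum = 0;
    for (int j = 0; j < QK/2; ++j) {
        const int x0 = (p0[j] & 0x0F) - 8, x1 = (p0[j] >> 4) - 8;
        const int y0 = (p1[j] & 0x0F) - 8, y1 = (p1[j] >> 4) - 8;
        isum += x0*y0 + x1*y1;
    }
    return d0*d1*(float)isum;
}

The AVX512 routine computes the same quantity in vector form: _mm512_madd_epi16 multiplies the 32 pairs of sign-extended nibbles and sums adjacent products into 16 int32 lanes, which are then converted to float, scaled by d0*d1, and added to the running accumulator with a single FMA.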
@@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
#else
#error "not implemented for QK"
#endif
+#elif defined(__AVX512F__)
+
+#if QK == 32
+    // Initialize accumulator with zeros
+    __m512 acc0 = _mm512_setzero_ps();
+    __m512 acc1 = _mm512_setzero_ps();
+
+    const int superblock_size = 8;
+    const int superblock_count = nb / superblock_size;
+    const int remainder = nb % superblock_size;
+
+    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
+        int i = superblock_ix * superblock_size;
+
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
+        acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
+    }
+
+    // Remainders
+    for (int i = superblock_count * superblock_size; i < nb; ++i) {
+        acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
+    }
+
+    // Horizontal sum of all lanes of the accumulator
+    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
+#else
+#error "not implemented for QK"
+#endif
#elif defined(__AVX2__)
#if QK == 32
    const size_t countBlocks = nb;
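The superblock loop above alternates between two independent accumulators so that the dependent FMA chains can overlap instead of serializing on a single register. It computes the same dot product (up to floating-point summation order) as this straightforward single-accumulator sketch, shown only to make the unrolling intent explicit and not part of the patch:

// Illustrative single-accumulator formulation, not part of the patch:
// every FMA depends on the previous one, so it exposes less
// instruction-level parallelism than the two-accumulator loop above.
__m512 acc = _mm512_setzero_ps();
for (int i = 0; i < nb; ++i) {
    acc = dot_q4_0_oneblock_avx512( acc, pd0, pd1, pb0, pb1, bs, i );
}
sumf = _mm512_reduce_add_ps( acc );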
@@ -1928,7 +2002,7 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
    const size_t bs = 2*sizeof(float) + QK/2;

    const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
-    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
+    const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
    const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));

    for (int i = 0; i < nb; i++) {