Skip to content

Commit 2e664f1

Browse files
authored
Add initial AVX512 support for dot product on Linux (#320)
* Update Makefile to detect AVX512 support and add compiler flags if it's available * Based on existing AVX2 implementation, dot product on one 32-value block of 4-bit quantized ints at a time * Perform 8 bit -> 16 bit sign extension and multiply+add on 32 values at time instead of 16 * Use built-in AVX512 horizontal reduce add to get sum at the end * Manual unrolling on inner dot product loop to reduce loop counter overhead
1 parent 8cf9f34 commit 2e664f1

File tree

2 files changed

+109
-3
lines changed

2 files changed

+109
-3
lines changed

Makefile

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,38 @@ ifeq ($(UNAME_M),$(filter $(UNAME_M),x86_64 i686))
9595
ifneq (,$(findstring sse3,$(SSE3_M)))
9696
CFLAGS += -msse3
9797
endif
98+
AVX512F_M := $(shell grep "avx512f " /proc/cpuinfo)
99+
ifneq (,$(findstring avx512f,$(AVX512F_M)))
100+
CFLAGS += -mavx512f
101+
endif
102+
AVX512BW_M := $(shell grep "avx512bw " /proc/cpuinfo)
103+
ifneq (,$(findstring avx512bw,$(AVX512BW_M)))
104+
CFLAGS += -mavx512bw
105+
endif
106+
AVX512DQ_M := $(shell grep "avx512dq " /proc/cpuinfo)
107+
ifneq (,$(findstring avx512dq,$(AVX512DQ_M)))
108+
CFLAGS += -mavx512dq
109+
endif
110+
AVX512VL_M := $(shell grep "avx512vl " /proc/cpuinfo)
111+
ifneq (,$(findstring avx512vl,$(AVX512VL_M)))
112+
CFLAGS += -mavx512vl
113+
endif
114+
AVX512CD_M := $(shell grep "avx512cd " /proc/cpuinfo)
115+
ifneq (,$(findstring avx512cd,$(AVX512CD_M)))
116+
CFLAGS += -mavx512cd
117+
endif
118+
AVX512ER_M := $(shell grep "avx512er " /proc/cpuinfo)
119+
ifneq (,$(findstring avx512er,$(AVX512ER_M)))
120+
CFLAGS += -mavx512er
121+
endif
122+
AVX512IFMA_M := $(shell grep "avx512ifma " /proc/cpuinfo)
123+
ifneq (,$(findstring avx512ifma,$(AVX512IFMA_M)))
124+
CFLAGS += -mavx512ifma
125+
endif
126+
AVX512PF_M := $(shell grep "avx512pf " /proc/cpuinfo)
127+
ifneq (,$(findstring avx512pf,$(AVX512PF_M)))
128+
CFLAGS += -mavx512pf
129+
endif
98130
else ifeq ($(UNAME_S),Haiku)
99131
AVX1_M := $(shell sysinfo -cpu | grep "AVX ")
100132
ifneq (,$(findstring avx,$(AVX1_M)))

ggml.c

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
361361

362362
// AVX routines provided by GH user Const-me
363363
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
364-
#if __AVX2__
364+
#if __AVX2__ || __AVX512F__
365365
// Unpack 32 4-bit fields into 32 bytes
366366
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
367367
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
@@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
397397
}
398398
#endif
399399

400-
401400
// method 5
402401
// blocks of QK elements
403402
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
@@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
12621261
*s = sumf;
12631262
}
12641263

1264+
#if __AVX512F__ && QK == 32
1265+
static inline __m512 dot_q4_0_oneblock_avx512(
1266+
__m512 acc,
1267+
const uint8_t * pd0,
1268+
const uint8_t * pd1,
1269+
const uint8_t * pb0,
1270+
const uint8_t * pb1,
1271+
size_t bs,
1272+
int i
1273+
) {
1274+
const float * d0_0 = (const float *) (pd0 + i*bs);
1275+
const float * d1_0 = (const float *) (pd1 + i*bs);
1276+
1277+
const uint8_t * restrict p0 = pb0 + (i+0)*bs;
1278+
const uint8_t * restrict p1 = pb1 + (i+0)*bs;
1279+
1280+
// Compute combined scale for the block
1281+
float scaleScalar = d0_0[0] * d1_0[0];
1282+
__m512 scale = _mm512_set1_ps( scaleScalar );
1283+
1284+
__m256i bx = bytesFromNibbles( p0 );
1285+
__m256i by = bytesFromNibbles( p1 );
1286+
1287+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1288+
const __m256i off = _mm256_set1_epi8( 8 );
1289+
bx = _mm256_sub_epi8( bx, off );
1290+
by = _mm256_sub_epi8( by, off );
1291+
1292+
// Sign-extend 16 signed bytes into int16_t
1293+
__m512i x32 = _mm512_cvtepi8_epi16( bx );
1294+
__m512i y32 = _mm512_cvtepi8_epi16( by );
1295+
// Compute products of int16_t integers, add pairwise
1296+
__m512i i64 = _mm512_madd_epi16( x32, y32 );
1297+
1298+
// Convert int32_t to float
1299+
__m512 p = _mm512_cvtepi32_ps( i64 );
1300+
// Apply the scale, and accumulate
1301+
return _mm512_fmadd_ps( scale, p, acc );
1302+
}
1303+
#endif
1304+
12651305
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
12661306
ggml_float sumf = 0.0;
12671307

@@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
14171457
#else
14181458
#error "not implemented for QK"
14191459
#endif
1460+
#elif defined(__AVX512F__)
1461+
1462+
#if QK == 32
1463+
// Initialize accumulator with zeros
1464+
__m512 acc0 = _mm512_setzero_ps();
1465+
__m512 acc1 = _mm512_setzero_ps();
1466+
1467+
const int superblock_size = 8;
1468+
const int superblock_count = nb / superblock_size;
1469+
const int remainder = nb % superblock_size;
1470+
1471+
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
1472+
int i = superblock_ix * superblock_size;
1473+
1474+
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
1475+
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
1476+
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
1477+
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
1478+
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
1479+
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
1480+
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
1481+
acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
1482+
}
1483+
1484+
// Remainders
1485+
for (int i = superblock_count * superblock_size; i < nb; ++i) {
1486+
acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
1487+
}
1488+
1489+
// Horizontal sum of all lanes of the accumulator
1490+
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
1491+
#else
1492+
#error "not implemented for QK"
1493+
#endif
14201494
#elif defined(__AVX2__)
14211495
#if QK == 32
14221496
const size_t countBlocks = nb;
@@ -1928,7 +2002,7 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
19282002
const size_t bs = 2*sizeof(float) + QK/2;
19292003

19302004
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
1931-
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
2005+
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
19322006
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
19332007

19342008
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)