Skip to content

Commit 904d2a8

Browse files
authored
Q4_1 quantization (#193)
* Add AVX2 version of ggml_vec_dot_q4_1 * Small optimisations to q4_1 dot product (@Const-me) * Rearrange Q4_1 quantization to work for multipart models. (Fix #152) * Fix ggml_vec_mad_q4_1 too * Fix non-vectorised q4_1 vec mul
1 parent 7213110 commit 904d2a8

File tree

2 files changed

+130
-39
lines changed

2 files changed

+130
-39
lines changed

ggml.c

Lines changed: 118 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
607607
assert(k % QK == 0);
608608

609609
const int nb = k / QK;
610+
const size_t bs = 2*sizeof(float) + QK/2;
610611

611-
float * restrict pm = (float *) (y);
612-
float * restrict pd = (float *) (pm + nb);
613-
uint8_t * restrict pb = (uint8_t *) (pd + nb);
612+
uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
613+
uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
614+
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
614615

615616
uint8_t pp[QK/2];
616617

@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
627628
const float d = (max - min) / ((1 << 4) - 1);
628629
const float id = d ? 1.0f/d : 0.0f;
629630

630-
pm[i] = min;
631-
pd[i] = d;
631+
*(float *)pm = min;
632+
*(float *)pd = d;
633+
pm += bs;
634+
pd += bs;
632635

633636
for (int l = 0; l < QK; l += 2) {
634637
const float v0 = (x[i*QK + l + 0] - min)*id;
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
643646
pp[l/2] = vi0 | (vi1 << 4);
644647
}
645648

646-
memcpy(pb + i*QK/2, pp, sizeof(pp));
649+
memcpy(pb, pp, sizeof(pp));
650+
pb += bs;
647651
}
648652
}
649653

@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
687691
assert(k % QK == 0);
688692

689693
const int nb = k / QK;
694+
const size_t bs = 2*sizeof(float) + QK/2;
690695

691-
const float * restrict pm = (const float *) (x);
692-
const float * restrict pd = (const float *) (pm + nb);
693-
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
696+
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
697+
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
698+
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
694699

695700
for (int i = 0; i < nb; i++) {
696-
const float m = pm[i];
697-
const float d = pd[i];
701+
const float d = *(const float *) (pd + i*bs);
702+
const float m = *(const float *) (pm + i*bs);
698703

699-
const uint8_t * restrict pp = pb + i*QK/2;
704+
const uint8_t * restrict pp = pb + i*bs;
700705

701706
for (int l = 0; l < QK; l += 2) {
702707
const uint8_t vi = pp[l/2];
@@ -1584,28 +1589,109 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
15841589
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
15851590
const int nb = n / QK;
15861591

1587-
const float * restrict pm0 = (const float *) x;
1588-
const float * restrict pm1 = (const float *) y;
1592+
const size_t bs = 2*sizeof(float) + QK/2;
15891593

1590-
const float * restrict pd0 = (const float *) (pm0 + nb);
1591-
const float * restrict pd1 = (const float *) (pm1 + nb);
1594+
const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
1595+
const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
1596+
1597+
const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
1598+
const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));
15921599

1593-
const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
1594-
const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
1600+
const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
1601+
const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));
15951602

15961603
float sumf = 0.0;
15971604

1598-
#if 1
1605+
#if defined(__AVX2__)
1606+
#if QK == 32
1607+
// Initialize accumulator with zeros
1608+
__m256 acc = _mm256_setzero_ps();
1609+
// Accumulator for constant offsets
1610+
float acc_offset = 0.0f;
1611+
1612+
// Main loop
1613+
for (int i = 0; i < nb; ++i) {
1614+
const float * m0 = (const float *) (pm0 + i*bs);
1615+
const float * m1 = (const float *) (pm1 + i*bs);
1616+
1617+
const float * d0 = (const float *) (pd0 + i*bs);
1618+
const float * d1 = (const float *) (pd1 + i*bs);
1619+
1620+
const uint8_t * restrict p0 = pb0 + i*bs;
1621+
const uint8_t * restrict p1 = pb1 + i*bs;
1622+
1623+
const __m256 d0v = _mm256_broadcast_ss( d0 );
1624+
const __m256 d1v = _mm256_broadcast_ss( d1 );
1625+
const __m256 m0v = _mm256_broadcast_ss( m0 );
1626+
const __m256 m1v = _mm256_broadcast_ss( m1 );
1627+
1628+
1629+
// Compute combined scale for the block
1630+
const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
1631+
1632+
// Compute cross scales for the block
1633+
const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
1634+
const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
1635+
const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
1636+
1637+
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1638+
__m256i bx = bytesFromNibbles( p0 );
1639+
__m256i by = bytesFromNibbles( p1 );
1640+
1641+
// Now we have a vector with bytes in [ 0 .. 15 ] interval.
1642+
1643+
// Sign-extend first 16 signed bytes into int16_t
1644+
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
1645+
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
1646+
// Compute products of int16_t integers, add pairwise
1647+
__m256i i32 = _mm256_madd_epi16( x16, y16 );
1648+
1649+
// Sign-extend last 16 signed bytes into int16_t vectors
1650+
__m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
1651+
__m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
1652+
// Accumulate products of int16_t integers
1653+
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
1654+
1655+
// compute sums of unsigned bytes in bx, by in blocks of 8.
1656+
// This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
1657+
// which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
1658+
// so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
1659+
__m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
1660+
__m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
1661+
__m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
1662+
__m256 sums = _mm256_cvtepi32_ps( sumsi );
1663+
1664+
// Convert int32_t to float
1665+
__m256 p = _mm256_cvtepi32_ps( i32 );
1666+
// Apply the scale, and accumulate
1667+
// acc += d0*d1*x*y + d0*m1*x + d1*m0*y
1668+
acc = _mm256_fmadd_ps( scale_01, p, acc );
1669+
acc = _mm256_fmadd_ps( cross_scales, sums, acc );
1670+
// acc_offset += m0*m1 (for each entry in the block)
1671+
acc_offset += (*m0)*(*m1);
1672+
}
1673+
1674+
// Return horizontal sum of the acc vector
1675+
__m128 res = _mm256_extractf128_ps( acc, 1 );
1676+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
1677+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
1678+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
1679+
1680+
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
1681+
#else
1682+
#error "not implemented for QK"
1683+
#endif
1684+
#else
15991685
// scalar
16001686
for (int i = 0; i < nb; i++) {
1601-
const float m0 = pm0[i];
1602-
const float m1 = pm1[i];
1687+
const float m0 = *(const float *) (pm0 + i*bs);
1688+
const float m1 = *(const float *) (pm1 + i*bs);
16031689

1604-
const float d0 = pd0[i];
1605-
const float d1 = pd1[i];
1690+
const float d0 = *(const float *) (pd0 + i*bs);
1691+
const float d1 = *(const float *) (pd1 + i*bs);
16061692

1607-
const uint8_t * restrict p0 = pb0 + i*QK/2;
1608-
const uint8_t * restrict p1 = pb1 + i*QK/2;
1693+
const uint8_t * restrict p0 = pb0 + i*bs;
1694+
const uint8_t * restrict p1 = pb1 + i*bs;
16091695

16101696
for (int j = 0; j < QK/2; j++) {
16111697
const uint8_t v0 = p0[j];
@@ -1839,16 +1925,17 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
18391925
assert(n % QK == 0);
18401926

18411927
const int nb = n / QK;
1928+
const size_t bs = 2*sizeof(float) + QK/2;
18421929

1843-
const float * restrict pm = (const float *) (x);
1844-
const float * restrict pd = (const float *) (pm + nb);
1845-
const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
1930+
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
1931+
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
1932+
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
18461933

18471934
for (int i = 0; i < nb; i++) {
1848-
const float m = pm[i];
1849-
const float d = pd[i];
1935+
const float d = *(const float *) (pd + i*bs);
1936+
const float m = *(const float *) (pm + i*bs);
18501937

1851-
const uint8_t * restrict pp = pb + i*QK/2;
1938+
const uint8_t * restrict pp = pb + i*bs;
18521939

18531940
for (int l = 0; l < QK; l += 2) {
18541941
const uint8_t vi = pp[l/2];

utils.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,8 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t
489489

490490
size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t * hist) {
491491
const int nb = k / qk;
492-
const size_t row_size = nb*(2*sizeof(float) + sizeof(uint8_t)*qk/2);
492+
const size_t bs = (2*sizeof(float) + sizeof(uint8_t)*qk/2);
493+
const size_t row_size = nb*bs;
493494

494495
assert(k % qk == 0);
495496

@@ -498,10 +499,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
498499

499500
char * pdst = (char *) dst;
500501

501-
for (int j = 0; j < n; j += k) {
502-
float * pm = (float *) (pdst + (j/k)*row_size);
503-
float * pd = (float *) (pm + nb);
504-
uint8_t * pb = (uint8_t *) (pd + nb);
502+
for (int j = 0; j < n; j += k) {
503+
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
504+
uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
505+
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
505506

506507
//printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
507508

@@ -519,8 +520,10 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
519520
const float d = (max - min) / ((1 << 4) - 1);
520521
const float id = d ? 1.0f/d : 0.0f;
521522

522-
pm[i] = min;
523-
pd[i] = d;
523+
*(float *) pd = d;
524+
*(float *) pm = min;
525+
pd += bs;
526+
pm += bs;
524527

525528
for (int l = 0; l < qk; l += 2) {
526529
const float v0 = (src[j + i*qk + l + 0] - min)*id;
@@ -538,7 +541,8 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t
538541
pp[l/2] = vi0 | (vi1 << 4);
539542
}
540543

541-
memcpy(pb + i*qk/2, pp, pp_size);
544+
memcpy(pb, pp, pp_size);
545+
pb += bs;
542546
}
543547
}
544548
}

0 commit comments

Comments
 (0)