Skip to content

Commit eafd47f

Browse files
swggerganov
authored andcommitted
Q8: use int8_t, AVX/AVX2 optimizations
1 parent 3f93a00 commit eafd47f

File tree

1 file changed

+190
-26
lines changed

1 file changed

+190
-26
lines changed

ggml.c

Lines changed: 190 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 bloc
583583

584584
typedef struct {
585585
float d; // delta
586-
uint8_t qs[QK]; // nibbles / quants
586+
int8_t qs[QK]; // quants
587587
} block_q8_0;
588588
static_assert(sizeof(block_q8_0) == sizeof(float) + QK, "wrong q8_0 block size/padding");
589589

@@ -1060,9 +1060,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
10601060

10611061
for (int l = 0; l < QK; ++l) {
10621062
const float v = x[i*QK + l]*id;
1063-
const uint8_t vi = (int8_t)roundf(v) + 128;
1064-
1065-
y[i].qs[l] = vi;
1063+
y[i].qs[l] = roundf(v);
10661064
}
10671065
}
10681066
}
@@ -1095,15 +1093,99 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
10951093

10961094
for (int l = 0; l < 8; l++) {
10971095
const float32x4_t v = vmulq_n_f32(srcv[l], id);
1098-
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(128.5f));
1099-
const int32x4_t vi = vcvtq_s32_f32(vf);
1096+
//TODO: rounding
1097+
const int32x4_t vi = vcvtq_s32_f32(v);
11001098

11011099
y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0);
11021100
y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1);
11031101
y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2);
11041102
y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3);
11051103
}
11061104
}
1105+
#elif defined(__AVX2__) || defined(__AVX__)
1106+
for (int i = 0; i < nb; i++) {
1107+
// Load elements into 4 AVX vectors
1108+
__m256 v0 = _mm256_loadu_ps( x );
1109+
__m256 v1 = _mm256_loadu_ps( x + 8 );
1110+
__m256 v2 = _mm256_loadu_ps( x + 16 );
1111+
__m256 v3 = _mm256_loadu_ps( x + 24 );
1112+
x += 32;
1113+
1114+
// Compute max(abs(e)) for the block
1115+
const __m256 signBit = _mm256_set1_ps( -0.0f );
1116+
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
1117+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
1118+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
1119+
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
1120+
1121+
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
1122+
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
1123+
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
1124+
const float maxScalar = _mm_cvtss_f32( max4 );
1125+
1126+
// Quantize these floats
1127+
const float d = maxScalar / 127.f;
1128+
y[i].d = d;
1129+
const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
1130+
const __m256 mul = _mm256_set1_ps( id );
1131+
1132+
// Apply the multiplier
1133+
v0 = _mm256_mul_ps( v0, mul );
1134+
v1 = _mm256_mul_ps( v1, mul );
1135+
v2 = _mm256_mul_ps( v2, mul );
1136+
v3 = _mm256_mul_ps( v3, mul );
1137+
1138+
// Round to nearest integer
1139+
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
1140+
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
1141+
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
1142+
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
1143+
1144+
// Convert floats to integers
1145+
__m256i i0 = _mm256_cvtps_epi32( v0 );
1146+
__m256i i1 = _mm256_cvtps_epi32( v1 );
1147+
__m256i i2 = _mm256_cvtps_epi32( v2 );
1148+
__m256i i3 = _mm256_cvtps_epi32( v3 );
1149+
1150+
#if defined(__AVX2__)
1151+
// Convert int32 to int16
1152+
i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
1153+
i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
1154+
// Convert int16 to int8
1155+
i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
1156+
1157+
// We got our precious signed bytes, but the order is now wrong
1158+
// These AVX2 pack instructions process 16-byte pieces independently
1159+
// The following instruction is fixing the order
1160+
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
1161+
i0 = _mm256_permutevar8x32_epi32( i0, perm );
1162+
1163+
_mm256_storeu_si256((__m256i *)y[i].qs, i0);
1164+
#else
1165+
// Since we don't have in AVX some necessary functions,
1166+
// we split the registers in half and call AVX2 analogs from SSE
1167+
__m128i ni0 = _mm256_castsi256_si128( i0 );
1168+
__m128i ni1 = _mm256_extractf128_si256( i0, 1);
1169+
__m128i ni2 = _mm256_castsi256_si128( i1 );
1170+
__m128i ni3 = _mm256_extractf128_si256( i1, 1);
1171+
__m128i ni4 = _mm256_castsi256_si128( i2 );
1172+
__m128i ni5 = _mm256_extractf128_si256( i2, 1);
1173+
__m128i ni6 = _mm256_castsi256_si128( i3 );
1174+
__m128i ni7 = _mm256_extractf128_si256( i3, 1);
1175+
1176+
// Convert int32 to int16
1177+
ni0 = _mm_packs_epi32( ni0, ni1 );
1178+
ni2 = _mm_packs_epi32( ni2, ni3 );
1179+
ni4 = _mm_packs_epi32( ni4, ni5 );
1180+
ni6 = _mm_packs_epi32( ni6, ni7 );
1181+
// Convert int16 to int8
1182+
ni0 = _mm_packs_epi16( ni0, ni2 );
1183+
ni4 = _mm_packs_epi16( ni4, ni6 );
1184+
1185+
_mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0);
1186+
_mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4);
1187+
#endif
1188+
}
11071189
#else
11081190
// scalar
11091191
quantize_row_q8_0_reference(x, y, k);
@@ -2508,7 +2590,6 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25082590

25092591
const uint8x16_t m4b = vdupq_n_u8(0xf);
25102592
const int8x16_t s8b = vdupq_n_s8(0x8);
2511-
const uint8x16_t u128b = vdupq_n_u8(128);
25122593

25132594
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
25142595
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
@@ -2526,21 +2607,16 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25262607
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
25272608

25282609
// load y
2529-
const uint8x16_t v1_0l = vld1q_u8(y0->qs);
2530-
const uint8x16_t v1_0h = vld1q_u8(y0->qs + 16);
2531-
const uint8x16_t v1_1l = vld1q_u8(y1->qs);
2532-
const uint8x16_t v1_1h = vld1q_u8(y1->qs + 16);
2610+
const int8x16_t v1_0l = vld1q_s8(y0->qs);
2611+
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
2612+
const int8x16_t v1_1l = vld1q_s8(y1->qs);
2613+
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
25332614

25342615
// interleave
2535-
const uint8x16_t v1_0lz = vuzp1q_u8(v1_0l, v1_0h);
2536-
const uint8x16_t v1_0hz = vuzp2q_u8(v1_0l, v1_0h);
2537-
const uint8x16_t v1_1lz = vuzp1q_u8(v1_1l, v1_1h);
2538-
const uint8x16_t v1_1hz = vuzp2q_u8(v1_1l, v1_1h);
2539-
2540-
const int8x16_t v1_0ls = vreinterpretq_s8_u8(vsubq_u8(v1_0lz, u128b));
2541-
const int8x16_t v1_0hs = vreinterpretq_s8_u8(vsubq_u8(v1_0hz, u128b));
2542-
const int8x16_t v1_1ls = vreinterpretq_s8_u8(vsubq_u8(v1_1lz, u128b));
2543-
const int8x16_t v1_1hs = vreinterpretq_s8_u8(vsubq_u8(v1_1hz, u128b));
2616+
const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
2617+
const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
2618+
const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
2619+
const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
25442620

25452621
#if defined(__ARM_FEATURE_DOTPROD)
25462622
// dot product into int32x4_t
@@ -2578,14 +2654,102 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25782654
}
25792655

25802656
sumf = sum0 + sum1;
2657+
#elif defined(__AVX2__)
2658+
// Initialize accumulator with zeros
2659+
__m256 acc = _mm256_setzero_ps();
2660+
2661+
// Main loop
2662+
for (int i = 0; i < nb; ++i) {
2663+
/* Compute combined scale for the block */
2664+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2665+
2666+
__m256i bx = bytesFromNibbles(x[i].qs);
2667+
2668+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2669+
const __m256i off = _mm256_set1_epi8( 8 );
2670+
bx = _mm256_sub_epi8( bx, off );
2671+
2672+
__m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2673+
2674+
// Get absolute values of x vectors
2675+
const __m256i ax = _mm256_sign_epi8(bx, bx);
2676+
2677+
// Sign the values of the y vectors
2678+
const __m256i sy = _mm256_sign_epi8(by, bx);
2679+
2680+
// Perform multiplication and create 16-bit values
2681+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
2682+
2683+
const __m256i ones = _mm256_set1_epi16(1);
2684+
__m256i xy_q = _mm256_madd_epi16(ones, dot);
2685+
2686+
/* Convert to vectore of 8 int32_t to 8 floats */
2687+
__m256 q = _mm256_cvtepi32_ps( xy_q );
2688+
2689+
/* Multiply q with scale and accumulate */
2690+
acc = _mm256_fmadd_ps( d, q, acc );
2691+
}
2692+
2693+
// Return horizontal sum of the acc vector
2694+
__m128 res = _mm256_extractf128_ps( acc, 1 );
2695+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2696+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2697+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2698+
2699+
sumf = _mm_cvtss_f32( res );
2700+
#elif defined(__AVX__)
2701+
// Initialize accumulator with zeros
2702+
__m256 acc = _mm256_setzero_ps();
2703+
2704+
// Main loop
2705+
for (int i = 0; i < nb; ++i) {
2706+
// Compute combined scale for the block
2707+
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
2708+
2709+
__m128i i32[2];
2710+
for (int j = 0; j < 2; ++j) {
2711+
// Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2712+
__m128i bx = bytesFromNibbles( x[i].qs + 8*j );
2713+
__m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
2714+
2715+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2716+
const __m128i off = _mm_set1_epi8( 8 );
2717+
bx = _mm_sub_epi8( bx, off );
2718+
2719+
// Get absolute values of x vectors
2720+
const __m128i ax = _mm_sign_epi8(bx, bx);
2721+
2722+
// Sign the values of the y vectors
2723+
const __m128i sy = _mm_sign_epi8(by, bx);
2724+
2725+
// Perform multiplication and create 16-bit values
2726+
const __m128i dot = _mm_maddubs_epi16(ax, sy);
2727+
2728+
const __m128i ones = _mm_set1_epi16(1);
2729+
i32[j] = _mm_madd_epi16(ones, dot);
2730+
}
2731+
2732+
// Convert int32_t to float
2733+
__m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
2734+
// Apply the scale, and accumulate
2735+
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
2736+
}
2737+
2738+
// Return horizontal sum of the acc vector
2739+
__m128 res = _mm256_extractf128_ps( acc, 1 );
2740+
res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
2741+
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
2742+
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
2743+
2744+
sumf = _mm_cvtss_f32( res );
25812745
#else
25822746
// scalar
25832747
for (int i = 0; i < nb; i++) {
25842748
const float d0 = x[i].d;
25852749
const float d1 = y[i].d;
25862750

25872751
const uint8_t * restrict p0 = x[i].qs;
2588-
const uint8_t * restrict p1 = y[i].qs;
2752+
const int8_t * restrict p1 = y[i].qs;
25892753

25902754
int sumi = 0;
25912755
for (int j = 0; j < QK/2; j++) {
@@ -2594,10 +2758,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
25942758
const int i0 = (int8_t) (v0 & 0xf) - 8;
25952759
const int i1 = (int8_t) (v0 >> 4) - 8;
25962760

2597-
const int i2 = (int) p1[2*j + 0] - 128;
2598-
const int i3 = (int) p1[2*j + 1] - 128;
2599-
2600-
/*printf("dot product: i0=%4d i1=%4d i2=%4d i3=%4d\n", i0, i1, i2, i3);*/
2761+
const int i2 = p1[2*j + 0];
2762+
const int i3 = p1[2*j + 1];
26012763

26022764
sumi += i0*i2 + i1*i3;
26032765
}
@@ -9923,7 +10085,9 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
992310085
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
992410086
} else
992510087
#endif
9926-
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
10088+
{
10089+
cur = GGML_TYPE_SIZE[GGML_TYPE_Q8_0]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[GGML_TYPE_Q8_0];
10090+
}
992710091
} else {
992810092
GGML_ASSERT(false);
992910093
}

0 commit comments

Comments
 (0)