Commit 1a6a669

ggml : fix bug in Q4_1 x Q8_1 I8MM kernel
ggml-ci
Parent: 2e752c4

3 files changed: +38 / -25 lines

ggml/src/ggml-cpu/ggml-cpu-quants.c

Lines changed: 34 additions & 23 deletions
@@ -1791,11 +1791,12 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-            float32_t _scale[4] = { GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                              GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                              GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                              GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
-
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -1811,7 +1812,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
         }

         float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
@@ -2347,10 +2348,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const block_q8_1 * restrict b_y0 = &vy0[i];
             const block_q8_1 * restrict b_y1 = &vy1[i];

-            float32_t summs_t[4] = {GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
-                              GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
-                              GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
-                              GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)};
+            float32_t summs_t[4] = {
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s),
+                GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s),
+                GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s)
+            };
             summs0 = vaddq_f32(summs0, vld1q_f32(summs_t));

             const uint8x16_t m4b = vdupq_n_u8(0x0F);
@@ -2371,10 +2374,12 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

             // mmla into int32x4_t
-            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*b_y0->d,
-                              GGML_FP16_TO_FP32(b_x0->d)*b_y1->d,
-                              GGML_FP16_TO_FP32(b_x1->d)*b_y0->d,
-                              GGML_FP16_TO_FP32(b_x1->d)*b_y1->d};
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -2389,15 +2394,17 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));
             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                                l1, r1)), l2, r2)), l3, r3))), scale);
         }

-        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);
+
         sumv2 = vaddq_f32(sumv2, summs0);

         vst1_f32(s, vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
+
         return;
     }
 #endif
@@ -3374,10 +3381,12 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             const int8x16_t y1_l = vld1q_s8(b_y1->qs);
             const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16);

-            float32_t _scale[4] = {GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
-                              GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
-                              GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
-                              GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)};
+            float32_t _scale[4] = {
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d),
+                GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d)
+            };
             float32x4_t scale = vld1q_f32(_scale);

             int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l)));
@@ -3393,13 +3402,15 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
             int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h)));

             sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)),
-                                                                                l1, r1)), l2, r2)), l3, r3))), scale);
+                                        l1, r1)), l2, r2)), l3, r3))), scale);
         }
-        float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2);
+
+        float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2);
         float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1);

-        vst1_f32(s, vget_low_f32(sumv2));
+        vst1_f32(s, vget_low_f32 (sumv2));
         vst1_f32(s + bs, vget_high_f32(sumv2));
+
         return;
     }
 #endif
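
The substantive change in this file is in ggml_vec_dot_q4_1_q8_1: the old scale computation multiplied the FP32-converted b_x->d by the raw b_y->d field, while the fixed version converts both factors with GGML_FP16_TO_FP32 before multiplying (the other hunks only reformat already-correct code). A minimal sketch of why the raw multiply goes wrong, assuming a build where the FP16 scales are stored as raw 16-bit patterns; the decoder below is an illustrative stand-in, not ggml's implementation, and the values are made up:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

// Illustrative IEEE-754 binary16 -> float decoder (stand-in for GGML_FP16_TO_FP32).
static float half_to_float(uint16_t h) {
    const uint32_t sign = (h >> 15) & 1;
    const uint32_t exp  = (h >> 10) & 0x1F;
    const uint32_t mant =  h        & 0x3FF;
    float val;
    if (exp == 0) {
        val = ldexpf((float) mant, -24);                      // zero / subnormal
    } else if (exp == 31) {
        val = mant ? NAN : INFINITY;                          // inf / nan
    } else {
        val = ldexpf((float)(mant | 0x400), (int) exp - 25);  // normal
    }
    return sign ? -val : val;
}

int main(void) {
    // Hypothetical per-block scales, both encoding roughly 0.05 as binary16.
    const uint16_t d_x = 0x2A66;
    const uint16_t d_y = 0x2A66;

    // Buggy form: the y-scale enters the product as its raw bit pattern (10854), not as ~0.05.
    const float scale_buggy = half_to_float(d_x) * (float) d_y;
    // Fixed form: both scales are decoded before the multiply.
    const float scale_fixed = half_to_float(d_x) * half_to_float(d_y);

    printf("buggy: %g  fixed: %g\n", scale_buggy, scale_fixed);
    return 0;
}

On toolchains where ggml_fp16_t maps to a native __fp16 type, the implicit promotion would happen to produce the right value, which may be why the old code worked in some configurations; that is an assumption, not something the commit states.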

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 2 additions & 2 deletions
@@ -7625,8 +7625,8 @@ UseGgmlGemm2:;
     // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
     int64_t num_rows_per_vec_dot = vec_dot_num_rows;

-    // TODO: currently the mmla kernels support only even numbered rows/cols.
-    // this check can be removed once they are extended to support odd numbered rows/cols too
+    // these checks are needed to avoid crossing dim1 boundaries
+    // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity
     if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
         num_rows_per_vec_dot = 1;
     }
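
For context on the reworded comment: the mmla (I8MM) kernels consume two rows and two columns per call, so the loop can only step by two when every dimension and every per-thread row chunk is even; otherwise a two-row step would run past the end of its chunk, i.e. cross a dim1 boundary. A small self-contained sketch of that rule, with the condition copied from the hunk above and a hypothetical function name and example values:

#include <stdint.h>
#include <stdio.h>

// Sketch of the fallback rule, assuming an mmla kernel that always consumes
// 2 rows and 2 cols per call (names mirror the surrounding code; values are made up).
static int64_t pick_rows_per_vec_dot(int64_t nr0, int64_t ne11,
                                     int64_t ir0_start, int64_t ir0_end,
                                     int64_t ir1_start, int64_t ir1_end,
                                     int64_t vec_dot_num_rows) {
    int64_t num_rows_per_vec_dot = vec_dot_num_rows; // 2 when an mmla kernel is available
    // If any dimension or per-thread chunk is odd, a 2-row step would cross a boundary,
    // so fall back to 1 row per vec_dot call.
    if ((nr0 % 2 != 0) || (ne11 % 2 != 0) ||
        ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
        num_rows_per_vec_dot = 1;
    }
    return num_rows_per_vec_dot;
}

int main(void) {
    // Example: a thread that owns rows [0, 7) has an odd-sized chunk, so it gets 1 row per call.
    printf("%lld\n", (long long) pick_rows_per_vec_dot(32, 8, 0, 7, 0, 8, 2));
    return 0;
}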

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
@@ -3334,7 +3334,9 @@ static const ggml_type all_types[] = {

 static const ggml_type base_types[] = {
     GGML_TYPE_F32, GGML_TYPE_F16,
+    GGML_TYPE_Q8_0, // for I8MM tests
     GGML_TYPE_Q4_0,
+    GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
     GGML_TYPE_IQ2_XXS
 };
