Skip to content

Commit a527ecc

Browse files
committed
metal: fix bugs for GQA and perplexity test.
I mixed up ne02 and nb02 in previous commit.
1 parent bfa455d commit a527ecc

File tree

1 file changed

+39
-30
lines changed

1 file changed

+39
-30
lines changed

ggml-metal.metal

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -343,16 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
343343
// N_DST, so this is another explicit assumption of the implementation.
344344
template<typename block_q_type, int nr, int nsg, int nw>
345345
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
346-
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
346+
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
347347
uint3 tgpig, uint tiisg, uint sgitg) {
348348
const int nb = ne00/QK4_0;
349349
const int r0 = tgpig.x;
350350
const int r1 = tgpig.y;
351351
const int im = tgpig.z;
352352
const int first_row = (r0 * nsg + sgitg) * nr;
353-
const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
353+
const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
354354
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
355-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
355+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
356356
float yl[16]; // src1 vector cache
357357
float sumf[nr]={0.f};
358358

@@ -383,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
383383
for (int row = 0; row < nr; ++row) {
384384
const float tot = simd_sum(sumf[row]);
385385
if (tiisg == 0 && first_row + row < ne01) {
386-
dst[r1*ne0 + im*ne12 + first_row + row] = tot;
386+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
387387
}
388388
}
389389
}
@@ -398,11 +398,12 @@ kernel void kernel_mul_mat_q4_0_f32(
398398
constant int64_t & ne10[[buffer(9)]],
399399
constant int64_t & ne12[[buffer(11)]],
400400
constant int64_t & ne0[[buffer(15)]],
401+
constant int64_t & ne1[[buffer(16)]],
401402
constant uint & gqa[[buffer(17)]],
402403
uint3 tgpig[[threadgroup_position_in_grid]],
403404
uint tiisg[[thread_index_in_simdgroup]],
404405
uint sgitg[[simdgroup_index_in_threadgroup]]) {
405-
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
406+
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
406407
}
407408

408409
kernel void kernel_mul_mat_q4_1_f32(
@@ -415,11 +416,12 @@ kernel void kernel_mul_mat_q4_1_f32(
415416
constant int64_t & ne10[[buffer(9)]],
416417
constant int64_t & ne12[[buffer(11)]],
417418
constant int64_t & ne0[[buffer(15)]],
419+
constant int64_t & ne1[[buffer(16)]],
418420
constant uint & gqa[[buffer(17)]],
419421
uint3 tgpig[[threadgroup_position_in_grid]],
420422
uint tiisg[[thread_index_in_simdgroup]],
421423
uint sgitg[[simdgroup_index_in_threadgroup]]) {
422-
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
424+
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
423425
}
424426

425427
kernel void kernel_mul_mat_f16_f32(
@@ -800,6 +802,7 @@ kernel void kernel_mul_mat_q2_K_f32(
800802
constant int64_t & ne10[[buffer(9)]],
801803
constant int64_t & ne12[[buffer(11)]],
802804
constant int64_t & ne0[[buffer(15)]],
805+
constant int64_t & ne1[[buffer(16)]],
803806
constant uint & gqa[[buffer(17)]],
804807
uint3 tgpig[[threadgroup_position_in_grid]],
805808
uint tiisg[[thread_index_in_simdgroup]],
@@ -812,9 +815,9 @@ kernel void kernel_mul_mat_q2_K_f32(
812815

813816
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
814817
const int ib_row = first_row * nb;
815-
const uint offset0 = r2/gqa*(ne02/QK_K);
818+
const uint offset0 = r2/gqa*(nb*ne0);
816819
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
817-
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
820+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
818821
float yl[32];
819822
float sumf[N_DST]={0.f}, all_sum;
820823

@@ -927,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
927930
for (int row = 0; row < N_DST; ++row) {
928931
all_sum = simd_sum(sumf[row]);
929932
if (tiisg == 0) {
930-
dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
933+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
931934
}
932935
}
933936
}
@@ -943,6 +946,7 @@ kernel void kernel_mul_mat_q3_K_f32(
943946
constant int64_t & ne10[[buffer(9)]],
944947
constant int64_t & ne12[[buffer(11)]],
945948
constant int64_t & ne0[[buffer(15)]],
949+
constant int64_t & ne1[[buffer(16)]],
946950
constant uint & gqa[[buffer(17)]],
947951
uint3 tgpig[[threadgroup_position_in_grid]],
948952
uint tiisg[[thread_index_in_simdgroup]],
@@ -955,9 +959,9 @@ kernel void kernel_mul_mat_q3_K_f32(
955959
const int64_t r2 = tgpig.z;
956960

957961
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
958-
const uint offset0 = r2/gqa*(ne02/QK_K);
962+
const uint offset0 = r2/gqa*(nb*ne0);
959963
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
960-
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
964+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
961965

962966
float yl[16];
963967

@@ -1045,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
10451049
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
10461050
const float tot = simd_sum(sumf);
10471051
if (tiisg == 0) {
1048-
dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
1052+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
10491053
}
10501054
}
10511055
}
@@ -1060,6 +1064,7 @@ kernel void kernel_mul_mat_q3_K_f32(
10601064
constant int64_t & ne10[[buffer(9)]],
10611065
constant int64_t & ne12[[buffer(11)]],
10621066
constant int64_t & ne0[[buffer(15)]],
1067+
constant int64_t & ne1[[buffer(16)]],
10631068
constant uint & gqa[[buffer(17)]],
10641069
uint3 tgpig[[threadgroup_position_in_grid]],
10651070
uint tiisg[[thread_index_in_simdgroup]],
@@ -1072,9 +1077,9 @@ kernel void kernel_mul_mat_q3_K_f32(
10721077
const int64_t r2 = tgpig.z;
10731078

10741079
const int row = 2 * r0 + sgitg;
1075-
const uint offset0 = r2/gqa*(ne02/QK_K);
1080+
const uint offset0 = r2/gqa*(nb*ne0);
10761081
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
1077-
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
1082+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
10781083
const int ix = tiisg/4;
10791084
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
10801085
const int im = il/8; // 0, 0, 1, 1
@@ -1113,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(
11131118

11141119
const float tot = simd_sum(sumf);
11151120
if (tiisg == 0) {
1116-
dst[r1*ne0 + r2*ne12 + row] = tot;
1121+
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
11171122
}
11181123

11191124
}
@@ -1130,6 +1135,7 @@ kernel void kernel_mul_mat_q4_K_f32(
11301135
constant int64_t & ne10[[buffer(9)]],
11311136
constant int64_t & ne12[[buffer(11)]],
11321137
constant int64_t & ne0[[buffer(15)]],
1138+
constant int64_t & ne1[[buffer(16)]],
11331139
constant uint & gqa[[buffer(17)]],
11341140
uint3 tgpig[[threadgroup_position_in_grid]],
11351141
uint tiisg[[thread_index_in_simdgroup]],
@@ -1150,9 +1156,9 @@ kernel void kernel_mul_mat_q4_K_f32(
11501156
const int r2 = tgpig.z;
11511157
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
11521158
const int ib_row = first_row * nb;
1153-
const uint offset0 = r2/gqa*(ne02/QK_K);
1159+
const uint offset0 = r2/gqa*(nb*ne0);
11541160
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1155-
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
1161+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
11561162
float yl[16];
11571163
float yh[16];
11581164
float sumf[N_DST]={0.f}, all_sum;
@@ -1219,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
12191225
for (int row = 0; row < N_DST; ++row) {
12201226
all_sum = simd_sum(sumf[row]);
12211227
if (tiisg == 0) {
1222-
dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
1228+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
12231229
}
12241230
}
12251231
}
@@ -1234,6 +1240,7 @@ kernel void kernel_mul_mat_q4_K_f32(
12341240
constant int64_t & ne10[[buffer(9)]],
12351241
constant int64_t & ne12[[buffer(11)]],
12361242
constant int64_t & ne0[[buffer(15)]],
1243+
constant int64_t & ne1[[buffer(16)]],
12371244
constant uint & gqa[[buffer(17)]],
12381245
uint3 tgpig[[threadgroup_position_in_grid]],
12391246
uint tiisg[[thread_index_in_simdgroup]],
@@ -1248,9 +1255,9 @@ kernel void kernel_mul_mat_q4_K_f32(
12481255
const int r2 = tgpig.z;
12491256
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
12501257
const int ib_row = first_row * nb;
1251-
const uint offset0 = r2/gqa*(ne02/QK_K);
1258+
const uint offset0 = r2/gqa*(nb*ne0);
12521259
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
1253-
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
1260+
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
12541261
float yl[8];
12551262
float yh[8];
12561263
float sumf[N_DST]={0.f}, all_sum;
@@ -1306,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
13061313
for (int row = 0; row < N_DST; ++row) {
13071314
all_sum = simd_sum(sumf[row]);
13081315
if (tiisg == 0) {
1309-
dst[r1*ne0+ r2*ne12 + first_row + row] = all_sum;
1316+
dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
13101317
}
13111318
}
13121319
}
@@ -1322,6 +1329,7 @@ kernel void kernel_mul_mat_q5_K_f32(
13221329
constant int64_t & ne10[[buffer(9)]],
13231330
constant int64_t & ne12[[buffer(11)]],
13241331
constant int64_t & ne0[[buffer(15)]],
1332+
constant int64_t & ne1[[buffer(16)]],
13251333
constant uint & gqa[[buffer(17)]],
13261334
uint3 tgpig[[threadgroup_position_in_grid]],
13271335
uint tiisg[[thread_index_in_simdgroup]],
@@ -1334,9 +1342,9 @@ kernel void kernel_mul_mat_q5_K_f32(
13341342
const int r2 = tgpig.z;
13351343

13361344
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
1337-
const uint offset0 = r2/gqa*(ne02/QK_K);
1345+
const uint offset0 = r2/gqa*(nb*ne0);
13381346
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
1339-
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
1347+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
13401348

13411349
float sumf[2]={0.f};
13421350

@@ -1470,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
14701478
for (int row = 0; row < 2; ++row) {
14711479
const float tot = simd_sum(sumf[row]);
14721480
if (tiisg == 0) {
1473-
dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
1481+
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
14741482
}
14751483
}
14761484

@@ -1486,6 +1494,7 @@ kernel void kernel_mul_mat_q6_K_f32(
14861494
constant int64_t & ne10[[buffer(9)]],
14871495
constant int64_t & ne12[[buffer(11)]],
14881496
constant int64_t & ne0[[buffer(15)]],
1497+
constant int64_t & ne1[[buffer(16)]],
14891498
constant uint & gqa[[buffer(17)]],
14901499
uint3 tgpig[[threadgroup_position_in_grid]],
14911500
uint tiisg[[thread_index_in_simdgroup]],
@@ -1503,9 +1512,9 @@ kernel void kernel_mul_mat_q6_K_f32(
15031512
const int r2 = tgpig.z;
15041513

15051514
const int row = 2 * r0 + sgitg;
1506-
const uint offset0 = r2/gqa*(ne02/QK_K);
1515+
const uint offset0 = r2/gqa*(nb*ne0);
15071516
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
1508-
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
1517+
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
15091518

15101519
float sumf = 0;
15111520

@@ -1571,7 +1580,7 @@ kernel void kernel_mul_mat_q6_K_f32(
15711580

15721581
const float tot = simd_sum(sumf);
15731582
if (tiisg == 0) {
1574-
dst[r1*ne0 + r2*ne12 + row] = tot;
1583+
dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
15751584
}
15761585
}
15771586

@@ -1835,7 +1844,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18351844
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
18361845
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
18371846
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
1838-
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
1847+
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
18391848

18401849
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
18411850
//load data and store to threadgroup memory
@@ -1880,7 +1889,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18801889

18811890
if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
18821891
device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
1883-
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne12;
1892+
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
18841893
for (int i = 0; i < 8; i++) {
18851894
simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
18861895
}
@@ -1893,7 +1902,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
18931902
}
18941903

18951904
threadgroup_barrier(mem_flags::mem_threadgroup);
1896-
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne12;
1905+
device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
18971906
if (sgitg==0) {
18981907
for (int i = 0; i < n_rows; i++) {
18991908
for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {

0 commit comments

Comments
 (0)