@@ -343,16 +343,16 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
 // N_DST, so this is another explicit assumption of the implementation.
 template <typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
-                    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
+                    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * nsg + sgitg) * nr;
-    const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
+    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne12;
+    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
     float yl[16]; // src1 vector cache
     float sumf[nr]={0.f};

@@ -383,7 +383,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne12 + first_row + row] = tot;
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
         }
     }
 }
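The two hunks above carry the substance of the change. With grouped-query attention, src0 has ne02 planes while src1 and dst have ne12 of them, and the per-plane strides differ: src0 advances in quantized blocks (nb = ne00/QK4_0 blocks per row, so nb*ne0 blocks per plane), while src1 and dst advance in floats (ne00*ne1 and ne0*ne1 per plane). The old expressions, ne02/QK4_0 and im*ne12, put plane counts where element strides belong, which is what this commit corrects. A minimal host-side mirror of the corrected arithmetic, in plain C++ (the struct and helper names are illustrative, not part of the kernel):

    #include <cstdint>

    // Mirrors the kernel's index math; only the arithmetic comes from the
    // diff above. QK is the block size of the quantization type (32 for Q4_0).
    struct mul_vec_idx {
        int64_t x_block; // block offset into quantized src0
        int64_t y_elem;  // float offset into src1
        int64_t d_elem;  // float offset into dst
    };

    inline mul_vec_idx index_for(int64_t first_row, int64_t row, int64_t r1, int64_t im,
                                 int64_t ne00, int64_t ne0, int64_t ne1,
                                 int64_t QK, uint32_t gqa) {
        const int64_t nb = ne00 / QK;                    // quantized blocks per src0 row
        return {
            first_row * nb + im / gqa * (nb * ne0),      // offset0: gqa dst planes share one src0 plane
            r1 * ne00 + im * ne00 * ne1,                 // y: src1 planes are ne00*ne1 floats apart (ne10 == ne00 here)
            r1 * ne0 + im * ne0 * ne1 + first_row + row, // dst element produced for this row
        };
    }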
@@ -398,11 +398,12 @@ kernel void kernel_mul_mat_q4_0_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mat_q4_1_f32(
@@ -415,11 +416,12 @@ kernel void kernel_mul_mat_q4_1_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mat_f16_f32(
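Every kernel in this file gains the same argument, constant int64_t & ne1[[buffer(16)]], slotted between ne0 at buffer 15 and gqa at buffer 17, so the host-side encoder has to bind ne1 at index 16 for each of these pipelines. The broadcast itself is the integer division im/gqa (r2/gqa in the K-quant kernels below). A tiny standalone check of that mapping, with head counts chosen purely for illustration and the gqa = ne12/ne02 relation assumed from the host side (this diff only shows the kernel side):

    #include <cassert>
    #include <cstdint>

    int main() {
        // Assumed sizes: src0 has ne02 = 8 planes, src1/dst have ne12 = 32,
        // giving gqa = ne12/ne02 = 4 (how the host computes gqa is an
        // assumption here).
        const uint32_t ne02 = 8, ne12 = 32;
        const uint32_t gqa  = ne12 / ne02;
        for (uint32_t im = 0; im < ne12; ++im) {
            assert(im / gqa < ne02); // every src1/dst plane maps to a valid src0 plane
        }
        assert( 0 / gqa == 0);       // planes 0..3   read src0 plane 0
        assert(31 / gqa == 7);       // planes 28..31 read src0 plane 7
        return 0;
    }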
@@ -800,6 +802,7 @@ kernel void kernel_mul_mat_q2_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -812,9 +815,9 @@ kernel void kernel_mul_mat_q2_K_f32(

     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne12;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;

@@ -927,7 +930,7 @@ kernel void kernel_mul_mat_q2_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -943,6 +946,7 @@ kernel void kernel_mul_mat_q3_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -955,9 +959,9 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int64_t r2 = tgpig.z;

     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne02;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

     float yl[16];

@@ -1045,7 +1049,7 @@ kernel void kernel_mul_mat_q3_K_f32(
         const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
         const float tot = simd_sum(sumf);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
         }
     }
 }
@@ -1060,6 +1064,7 @@ kernel void kernel_mul_mat_q3_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -1072,9 +1077,9 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int64_t r2 = tgpig.z;

     const int row = 2 * r0 + sgitg;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne02;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     const int ix = tiisg/4;
     const int il = 4 * (tiisg%4); // 0, 4, 8, 12
     const int im = il/8;          // 0, 0, 1, 1
@@ -1113,7 +1118,7 @@ kernel void kernel_mul_mat_q3_K_f32(

     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + r2*ne12 + row] = tot;
+        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
     }

 }
@@ -1130,6 +1135,7 @@ kernel void kernel_mul_mat_q4_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -1150,9 +1156,9 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r2 = tgpig.z;
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne12;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     float yl[16];
     float yh[16];
     float sumf[N_DST]={0.f}, all_sum;
@@ -1219,7 +1225,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1234,6 +1240,7 @@ kernel void kernel_mul_mat_q4_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -1248,9 +1255,9 @@ kernel void kernel_mul_mat_q4_K_f32(
     const int r2 = tgpig.z;
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int ib_row = first_row * nb;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne12;
+    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
     float yl[8];
     float yh[8];
     float sumf[N_DST]={0.f}, all_sum;
@@ -1306,7 +1313,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne12 + first_row + row] = all_sum;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
         }
     }
 }
@@ -1322,6 +1329,7 @@ kernel void kernel_mul_mat_q5_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -1334,9 +1342,9 @@ kernel void kernel_mul_mat_q5_K_f32(
     const int r2 = tgpig.z;

     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne12;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

     float sumf[2]={0.f};

@@ -1470,7 +1478,7 @@ kernel void kernel_mul_mat_q5_K_f32(
     for (int row = 0; row < 2; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne12 + first_row + row] = tot;
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
         }
     }

@@ -1486,6 +1494,7 @@ kernel void kernel_mul_mat_q6_K_f32(
         constant   int64_t & ne10[[buffer(9)]],
         constant   int64_t & ne12[[buffer(11)]],
         constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
@@ -1503,9 +1512,9 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int r2 = tgpig.z;

     const int row = 2 * r0 + sgitg;
-    const uint offset0 = r2/gqa*(ne02/QK_K);
+    const uint offset0 = r2/gqa*(nb*ne0);
     device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne12;
+    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

     float sumf = 0;

@@ -1571,7 +1580,7 @@ kernel void kernel_mul_mat_q6_K_f32(

     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + r2*ne12 + row] = tot;
+        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
     }
 }

@@ -1835,7 +1844,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
     device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
     device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
-                           + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
+                           + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;

     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
@@ -1880,7 +1889,7 @@ kernel void kernel_mul_mm(device const uchar * src0,

     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
         device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne12;
+                        + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
@@ -1893,7 +1902,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
     }

     threadgroup_barrier(mem_flags::mem_threadgroup);
-    device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne12;
+    device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
     if (sgitg == 0) {
         for (int i = 0; i < n_rows; i++) {
             for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
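The kernel_mul_mm hunks apply the same plane-stride fix to the matrix-matrix path: dst is laid out as a row-major [ne12][ne1][ne0] tensor, so stepping to plane im means adding im*ne1*ne0 floats rather than im*ne12. A sketch of the resulting output-tile addressing in plain C++ (the 64x32 tile and 32x16 simdgroup quadrants match the BLOCK_SIZE_M/N constants I would expect in this version of ggml-metal.metal, but treat them as assumptions):

    #include <cstdint>

    constexpr int64_t BLOCK_SIZE_M = 64; // assumed threadgroup tile size along dim 0
    constexpr int64_t BLOCK_SIZE_N = 32; // assumed threadgroup tile size along dim 1

    // First element of the quadrant that simdgroup sgitg of threadgroup
    // (r0, r1) writes into plane im of dst; mirrors the pointer math above.
    inline int64_t tile_origin(int64_t r0, int64_t r1, int64_t im, int64_t sgitg,
                               int64_t ne0, int64_t ne1) {
        return BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)          // offset along dim 0
             + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 // offset along dim 1 (rows of ne0 floats)
             + im * ne1 * ne0;                               // plane offset (the fix)
    }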