Skip to content

Commit 4d76a5f

Browse files
ikawrakow and Kawrakow
authored
Faster Q3_K implementation on Metal (#2307)
* Faster Q3_K on Metal * Additional Q3_K speedup on Metal * Q3_K for QK_K = 64 * Better Q3_K for QK_K = 64 21.6 ms/t -> 21.1 ms/t --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0db14fe commit 4d76a5f

File tree

2 files changed

+127
-84
lines changed

2 files changed

+127
-84
lines changed

ggml-metal.m

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,8 @@ void ggml_metal_graph_compute(
685685
GGML_ASSERT(ne02 == 1);
686686
GGML_ASSERT(ne12 == 1);
687687

688-
nth0 = 4;
689-
nth1 = 16;
688+
nth0 = 2;
689+
nth1 = 32;
690690
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
691691
} break;
692692
case GGML_TYPE_Q4_K:
@@ -743,15 +743,18 @@ void ggml_metal_graph_compute(
743743
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
744744
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
745745
}
746+
else if (src0t == GGML_TYPE_Q3_K) {
747+
#ifdef GGML_QKK_64
748+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
749+
#else
750+
[encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751+
#endif
752+
}
746753
else if (src0t == GGML_TYPE_Q5_K) {
747754
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
748755
}
749756
else if (src0t == GGML_TYPE_Q6_K) {
750757
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
751-
}
752-
else if (src0t == GGML_TYPE_Q3_K) {
753-
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
754-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
755758
} else {
756759
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
757760
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];

ggml-metal.metal

Lines changed: 118 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ kernel void kernel_rms_norm(
351351

352352
threadgroup_barrier(mem_flags::mem_threadgroup);
353353
// broadcast, simd group number is ntg / 32
354-
for (int i = ntg / 32 / 2; i > 0; i /= 2) {
354+
for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
355355
if (tpitg < i) {
356356
sum[tpitg] += sum[tpitg + i];
357357
}
@@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32(
13391339
}
13401340
}
13411341

1342+
#if QK_K == 256
13421343
kernel void kernel_mul_mat_q3_K_f32(
13431344
device const void * src0,
13441345
device const float * src1,
@@ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
13471348
constant int64_t & ne10,
13481349
constant int64_t & ne0,
13491350
constant int64_t & ne1,
1350-
threadgroup float * sum [[threadgroup(0)]],
13511351
uint2 tgpig[[threadgroup_position_in_grid]],
1352-
uint2 tpitg[[thread_position_in_threadgroup]],
1353-
uint2 tptg[[threads_per_threadgroup]]) {
1352+
uint tiisg[[thread_index_in_simdgroup]],
1353+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
13541354

13551355
const int nb = ne00/QK_K;
13561356

13571357
const int64_t r0 = tgpig.x;
13581358
const int64_t r1 = tgpig.y;
13591359

1360-
device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1361-
device const float * yy = (device const float *) src1 + r1*ne10;
1362-
1363-
const int nth = tptg.x*tptg.y;
1364-
const int ith = tptg.y*tpitg.x + tpitg.y;
1360+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
13651361

1366-
#if QK_K == 256
1362+
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363+
device const float * yy = (device const float *) src1 + r1*ne10;
13671364

1368-
const uint8_t m3 = 3;
1369-
const int8_t m4 = 4;
1365+
float yl[16];
13701366

13711367
const uint16_t kmask1 = 0x0303;
13721368
const uint16_t kmask2 = 0x0f0f;
13731369

1374-
const int tid = tpitg.y; // expecting 16
1370+
const int tid = tiisg/2;
1371+
const int ix = tiisg%2;
13751372
const int ip = tid/8; // 0 or 1
13761373
const int il = tid/2 - 4*ip; // 0...3
13771374
const int ir = tid%2;
13781375
const int n = 8;
13791376
const int l0 = n*ir;
13801377

1381-
const uint8_t m = 1 << (4*ip + il);
1378+
const uint16_t m1 = 1 << (4*ip + il);
1379+
const uint16_t m2 = m1 << 8;
13821380

13831381
const int shift = 2*il;
1382+
const uint16_t qm1 = 0x0003 << shift;
1383+
const uint16_t qm2 = 0x0300 << shift;
1384+
const int32_t v1 = 4 << shift;
1385+
const int32_t v2 = 1024 << shift;
13841386

13851387
const uint16_t s_shift1 = 4*ip;
13861388
const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32(
13891391
const int q_offset = 32*ip + l0;
13901392
const int y_offset = 128*ip + 32*il + l0;
13911393

1392-
//float sumf = 0;
1393-
float sumf1 = 0, sumf2 = 0;
1394-
for (int i = tpitg.x; i < nb; i += tptg.x) {
1394+
const int step = sizeof(block_q3_K) * nb / 2;
13951395

1396-
const float d_all = (float)(x[i].d);
1397-
1398-
device const uint8_t * q = x[i].qs + q_offset;
1399-
device const uint8_t * h = x[i].hmask + l0;
1400-
device const float * y = yy + i * QK_K + y_offset;
1396+
device const float * y1 = yy + ix*QK_K + y_offset;
14011397

1402-
device const uint16_t * a = (device const uint16_t *)x[i].scales;
1403-
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1398+
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
1399+
for (int i = ix; i < nb; i += 2) {
14041400

1405-
float s = 0;
1406-
for (int l = 0; l < n; ++l) {
1407-
s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4));
1401+
for (int l = 0; l < 8; ++l) {
1402+
yl[l+0] = y1[l+ 0];
1403+
yl[l+8] = y1[l+16];
14081404
}
1409-
float d = d_all * s;
1410-
sumf1 += d * scales[0];
1411-
sumf2 += d;
1412-
//sumf += d_all * s * (scales[0] - 32);
14131405

1414-
s = 0;
1415-
for (int l = 0; l < n; ++l) {
1416-
s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4));
1406+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408+
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
1409+
device const half * dh = &x[i].d;
1410+
1411+
for (int row = 0; row < 2; ++row) {
1412+
1413+
const float d_all = (float)dh[0];
1414+
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
1415+
1416+
float s1 = 0, s2 = 0;
1417+
for (int l = 0; l < n; l += 2) {
1418+
const uint16_t qs = q[l/2];
1419+
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
1420+
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
1421+
}
1422+
float d = d_all * (s1 + 1.f/256.f * s2);
1423+
sumf1[row] += d * scales[0];
1424+
sumf2[row] += d;
1425+
1426+
s1 = s2 = 0;
1427+
for (int l = 0; l < n; l += 2) {
1428+
const uint16_t qs = q[l/2+8];
1429+
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
1430+
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
1431+
}
1432+
d = d_all * (s1 + 1.f/256.f * s2);
1433+
sumf1[row] += d * scales[1];
1434+
sumf2[row] += d;
1435+
1436+
q += step;
1437+
h += step;
1438+
a += step;
1439+
dh += step;
1440+
14171441
}
1418-
d = d_all * s;
1419-
sumf1 += d * scales[1];
1420-
sumf2 += d;
1421-
//sumf += d_all * s * (scales[1] - 32);
1442+
1443+
y1 += 2 * QK_K;
14221444

14231445
}
14241446

1425-
//sum[ith] = sumf;
1426-
sum[ith] = sumf1 - 32.f*sumf2;
1447+
for (int row = 0; row < 2; ++row) {
1448+
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
1449+
const float tot = simd_sum(sumf);
1450+
if (tiisg == 0) {
1451+
dst[r1*ne0 + first_row + row] = tot;
1452+
}
1453+
}
1454+
}
14271455
#else
1428-
const int il = 4 * tpitg.x; // 0, 4, 8, 12
1456+
kernel void kernel_mul_mat_q3_K_f32(
1457+
device const void * src0,
1458+
device const float * src1,
1459+
device float * dst,
1460+
constant int64_t & ne00,
1461+
constant int64_t & ne10,
1462+
constant int64_t & ne0,
1463+
constant int64_t & ne1,
1464+
uint2 tgpig[[threadgroup_position_in_grid]],
1465+
uint tiisg[[thread_index_in_simdgroup]],
1466+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
1467+
1468+
const int nb = ne00/QK_K;
1469+
1470+
const int64_t r0 = tgpig.x;
1471+
const int64_t r1 = tgpig.y;
1472+
1473+
const int row = 2 * r0 + sgitg;
1474+
1475+
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1476+
device const float * yy = (device const float *) src1 + r1*ne10;
1477+
const int ix = tiisg/4;
1478+
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
14291479
const int im = il/8; // 0, 0, 1, 1
14301480
const int in = il%8; // 0, 4, 0, 4
14311481

1432-
float sumf = 0;
1482+
float2 sum = {0.f, 0.f};
14331483

1434-
for (int i = tpitg.y; i < nb; i += tptg.y) {
1484+
for (int i = ix; i < nb; i += 8) {
14351485

14361486
const float d_all = (float)(x[i].d);
14371487

1438-
device const uint8_t * q = x[i].qs + il;
1439-
device const uint8_t * h = x[i].hmask + in;
1440-
device const float * y = yy + i * QK_K + il;
1441-
1442-
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1443-
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1444-
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1445-
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1446-
1447-
for (int l = 0; l < 4; ++l) {
1448-
const uint8_t hm = h[l] >> im;
1449-
sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1450-
+ y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1451-
+ y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1452-
+ y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1488+
device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489+
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490+
device const uint16_t * s = (device const uint16_t *)(x[i].scales);
1491+
device const float * y = yy + i * QK_K + il;
1492+
1493+
const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
1494+
const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
1495+
const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
1496+
const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;
1497+
1498+
for (int l = 0; l < 4; l += 2) {
1499+
const uint16_t hm = h[l/2] >> im;
1500+
sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4))
1501+
+ y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16))
1502+
+ y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64))
1503+
+ y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256));
1504+
sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024))
1505+
+ y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096))
1506+
+ y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
1507+
+ y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
14531508
}
14541509

14551510
}
1511+
const float sumf = sum[0] + sum[1] * 1.f/256.f;
14561512

1457-
sum[ith] = sumf;
1458-
1459-
#endif
1460-
1461-
//
1462-
// Accumulate the sum from all threads in the threadgroup
1463-
//
1464-
threadgroup_barrier(mem_flags::mem_threadgroup);
1465-
if (ith%4 == 0) {
1466-
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
1467-
}
1468-
threadgroup_barrier(mem_flags::mem_threadgroup);
1469-
if (ith%16 == 0) {
1470-
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
1471-
}
1472-
threadgroup_barrier(mem_flags::mem_threadgroup);
1473-
if (ith == 0) {
1474-
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1475-
dst[r1*ne0 + r0] = sum[0];
1513+
const float tot = simd_sum(sumf);
1514+
if (tiisg == 0) {
1515+
dst[r1*ne0 + row] = tot;
14761516
}
14771517

14781518
}
1519+
#endif
14791520

14801521
#if QK_K == 256
14811522
kernel void kernel_mul_mat_q4_K_f32(
@@ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32(
17731814

17741815
for (int i = ix; i < nb; i += 8) {
17751816

1776-
float4 sumy = {0.f, 0.f, 0.f, 0.f};
17771817
for (int l = 0; l < 4; ++l) {
17781818
yl[l+0] = y[l+ 0];
17791819
yl[l+4] = y[l+16];

0 commit comments

Comments (0)