@@ -351,7 +351,7 @@ kernel void kernel_rms_norm(
351
351
352
352
threadgroup_barrier (mem_flags::mem_threadgroup);
353
353
// broadcast, simd group number is ntg / 32
354
- for (int i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
354
+ for (uint i = ntg / 32 / 2 ; i > 0 ; i /= 2 ) {
355
355
if (tpitg < i) {
356
356
sum[tpitg] += sum[tpitg + i];
357
357
}
@@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32(
1339
1339
}
1340
1340
}
1341
1341
1342
+ #if QK_K == 256
1342
1343
// kernel_mul_mat_q3_K_f32, QK_K == 256 variant, shown as a diff hunk:
// lines starting with '-' belong to the removed implementation (per-thread
// partial sums written to threadgroup memory), lines starting with '+' belong
// to the new implementation (reduction with simd_sum), unprefixed lines are
// shared context.
//
// What the new ('+') version does, as far as the visible lines show:
//  - every simdgroup handles two consecutive rows of src0, beginning at
//    first_row = (r0 * N_SIMDGROUP + sgitg) * 2, and multiplies them with the
//    src1 floats starting at offset r1*ne10;
//  - lanes with even tiisg walk the even blocks, lanes with odd tiisg the odd
//    blocks (ix = tiisg%2, block loop stride 2);
//  - quants, hmask and scales are read through 16-bit pointers so that two
//    bytes are handled per load;
//  - the per-lane results are summed with simd_sum and lane 0 stores
//    dst[r1*ne0 + first_row + row].
// NOTE(review): `dst`, `ne00` and `ik` are used below but are declared in lines
// that fall in the gaps between the diff hunks - not visible here.
// NOTE(review): nothing visible guards first_row + row against the number of
// rows; presumably the host dispatch guarantees it - confirm at the call site.
kernel void kernel_mul_mat_q3_K_f32 (
1343
1344
device const void * src0,
1344
1345
device const float * src1,
@@ -1347,40 +1348,41 @@ kernel void kernel_mul_mat_q3_K_f32(
1347
1348
constant int64_t & ne10,
1348
1349
constant int64_t & ne0,
1349
1350
constant int64_t & ne1,
1350
- threadgroup float * sum [[threadgroup(0 )]],
1351
1351
uint2 tgpig[[threadgroup_position_in_grid]],
1352
- uint2 tpitg[[thread_position_in_threadgroup ]],
1353
- uint2 tptg[[threads_per_threadgroup ]]) {
1352
+ uint tiisg[[thread_index_in_simdgroup ]],
1353
+ uint sgitg[[simdgroup_index_in_threadgroup ]]) {
1354
1354
1355
1355
const int nb = ne00/QK_K;
1356
1356
1357
1357
const int64_t r0 = tgpig.x ;
1358
1358
const int64_t r1 = tgpig.y ;
1359
1359
1360
- device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1361
- device const float * yy = (device const float *) src1 + r1*ne10;
1362
-
1363
- const int nth = tptg.x *tptg.y ;
1364
- const int ith = tptg.y *tpitg.x + tpitg.y ;
1360
// new: first of the two rows owned by this simdgroup
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2 ;
1365
1361
1366
- #if QK_K == 256
1362
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
1363
+ device const float * yy = (device const float *) src1 + r1*ne10;
1367
1364
1368
- const uint8_t m3 = 3 ;
1369
- const int8_t m4 = 4 ;
1365
// new: per-lane copy of the 16 src1 values used for one block (filled in the block loop)
+ float yl[16 ];
1370
1366
1371
1367
const uint16_t kmask1 = 0x0303 ;
1372
1368
const uint16_t kmask2 = 0x0f0f ;
1373
1369
1374
- const int tid = tpitg.y ; // expecting 16
1370
// new: two lanes share each tid; ix tells them apart and selects even/odd blocks
+ const int tid = tiisg/2 ;
1371
+ const int ix = tiisg%2 ;
1375
1372
const int ip = tid/8 ; // 0 or 1
1376
1373
const int il = tid/2 - 4 *ip; // 0...3
1377
1374
const int ir = tid%2 ;
1378
1375
const int n = 8 ;
1379
1376
const int l0 = n*ir;
1380
1377
1381
- const uint8_t m = 1 << (4 *ip + il);
1378
// new: m1 tests bit (4*ip + il) in the low byte of a 16-bit hmask load, m2 the same bit in the high byte
+ const uint16_t m1 = 1 << (4 *ip + il);
1379
+ const uint16_t m2 = m1 << 8 ;
1382
1380
1383
1381
const int shift = 2 *il;
1382
// new: qm1/qm2 pick the 2-bit quant at bit offset `shift` in the low/high byte without
// shifting it down; v1/v2 are the value 4 placed at the same bit positions (1024 == 4 << 8)
+ const uint16_t qm1 = 0x0003 << shift;
1383
+ const uint16_t qm2 = 0x0300 << shift;
1384
+ const int32_t v1 = 4 << shift;
1385
+ const int32_t v2 = 1024 << shift;
1384
1386
1385
1387
const uint16_t s_shift1 = 4 *ip;
1386
1388
const uint16_t s_shift2 = s_shift1 + 2 *(il/2 );
@@ -1389,93 +1391,132 @@ kernel void kernel_mul_mat_q3_K_f32(
1389
1391
const int q_offset = 32 *ip + l0;
1390
1392
const int y_offset = 128 *ip + 32 *il + l0;
1391
1393
1392
- // float sumf = 0;
1393
- float sumf1 = 0 , sumf2 = 0 ;
1394
- for (int i = tpitg.x ; i < nb; i += tptg.x ) {
1394
// new: byte size of nb blocks divided by 2; added below to the uint16_t/half pointers
// to move from one row to the next (assumes those element types are 2 bytes wide)
+ const int step = sizeof (block_q3_K) * nb / 2 ;
1395
1395
1396
- const float d_all = (float )(x[i].d );
1397
-
1398
- device const uint8_t * q = x[i].qs + q_offset;
1399
- device const uint8_t * h = x[i].hmask + l0;
1400
- device const float * y = yy + i * QK_K + y_offset;
1396
+ device const float * y1 = yy + ix*QK_K + y_offset;
1401
1397
1402
- device const uint16_t * a = (device const uint16_t *)x[i]. scales ;
1403
- const char2 scales = as_type<char2>(( uint16_t )(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4 )));
1398
// new: one pair of accumulators per row; sumf1 is weighted by the scales, sumf2 is unweighted
+ float sumf1[ 2 ] = { 0 . f }, sumf2[ 2 ] = { 0 . f } ;
1399
+ for ( int i = ix; i < nb; i += 2 ) {
1404
1400
1405
- float s = 0 ;
1406
- for ( int l = 0 ; l < n; ++l) {
1407
- s += y [l+ 0 ] * (( int8_t )((q [l+ 0 ] >> shift) & m3) - ((h[l+ 0 ] & m) ? 0 : m4)) ;
1401
// new: yl[0..7] = y1[0..7], yl[8..15] = y1[16..23]; loaded once and reused for both rows
+ for ( int l = 0 ; l < 8 ; ++l) {
1402
+ yl[l+ 0 ] = y1 [l+ 0 ];
1403
+ yl [l+8 ] = y1 [l+16 ] ;
1408
1404
}
1409
- float d = d_all * s;
1410
- sumf1 += d * scales[0 ];
1411
- sumf2 += d;
1412
- // sumf += d_all * s * (scales[0] - 32);
1413
1405
1414
- s = 0 ;
1415
- for (int l = 0 ; l < n; ++l) {
1416
- s += y[l+16 ] * ((int8_t )((q[l+16 ] >> shift) & m3) - ((h[l+16 ] & m) ? 0 : m4));
1406
// new: 16-bit views into block i of the first row; advanced by `step` at the end of each row iteration
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
1407
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
1408
+ device const uint16_t * a = (device const uint16_t *)(x[i].scales );
1409
+ device const half * dh = &x[i].d ;
1410
+
1411
+ for (int row = 0 ; row < 2 ; ++row) {
1412
+
1413
+ const float d_all = (float )dh[0 ];
1414
+ const char2 scales = as_type<char2>((uint16_t )(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4 )));
1415
+
1416
// new: s1 accumulates the low-byte quants, s2 the high-byte quants of each 16-bit load
+ float s1 = 0 , s2 = 0 ;
1417
+ for (int l = 0 ; l < n; l += 2 ) {
1418
+ const uint16_t qs = q[l/2 ];
1419
+ s1 += yl[l+0 ] * ((int32_t )(qs & qm1) - ((h[l/2 ] & m1) ? 0 : v1));
1420
+ s2 += yl[l+1 ] * ((int32_t )(qs & qm2) - ((h[l/2 ] & m2) ? 0 : v2));
1421
+ }
1422
// new: 1/256 brings the high-byte sum to the low-byte scale; the common factor
// (1 << shift) is divided out once, after the block loop
+ float d = d_all * (s1 + 1 .f /256 .f * s2);
1423
+ sumf1[row] += d * scales[0 ];
1424
+ sumf2[row] += d;
1425
+
1426
// new: second half - same computation on q[.. + 8] / h[.. + 8] (16 bytes further) with yl[8..15] and scales[1]
+ s1 = s2 = 0 ;
1427
+ for (int l = 0 ; l < n; l += 2 ) {
1428
+ const uint16_t qs = q[l/2 +8 ];
1429
+ s1 += yl[l+8 ] * ((int32_t )(qs & qm1) - ((h[l/2 +8 ] & m1) ? 0 : v1));
1430
+ s2 += yl[l+9 ] * ((int32_t )(qs & qm2) - ((h[l/2 +8 ] & m2) ? 0 : v2));
1431
+ }
1432
+ d = d_all * (s1 + 1 .f /256 .f * s2);
1433
+ sumf1[row] += d * scales[1 ];
1434
+ sumf2[row] += d;
1435
+
1436
// new: move all four pointers to block i of the next row
+ q += step;
1437
+ h += step;
1438
+ a += step;
1439
+ dh += step;
1440
+
1417
1441
}
1418
- d = d_all * s;
1419
- sumf1 += d * scales[1 ];
1420
- sumf2 += d;
1421
- // sumf += d_all * s * (scales[1] - 32);
1442
+
1443
// new: this lane's next block is two blocks further on
+ y1 += 2 * QK_K;
1422
1444
1423
1445
}
1424
1446
1425
- // sum[ith] = sumf;
1426
- sum[ith] = sumf1 - 32 .f *sumf2;
1447
// new: sumf1 - 32*sumf2 equals weighting by (scales[k] - 32), as in the removed commented-out
// formula; dividing by (1 << shift) undoes the unshifted masks; simd_sum adds the value over
// the lanes of the simdgroup and lane 0 stores the result
+ for (int row = 0 ; row < 2 ; ++row) {
1448
+ const float sumf = (sumf1[row] - 32 .f *sumf2[row]) / (1 << shift);
1449
+ const float tot = simd_sum (sumf);
1450
+ if (tiisg == 0 ) {
1451
+ dst[r1*ne0 + first_row + row] = tot;
1452
+ }
1453
+ }
1454
+ }
1427
1455
#else
1428
- const int il = 4 * tpitg.x ; // 0, 4, 8, 12
1456
// kernel_mul_mat_q3_K_f32, variant compiled in the #else branch of `#if QK_K == 256`,
// shown as a diff hunk: '+' lines are the new simdgroup-based kernel, '-' lines are
// the removed threadgroup-memory code, unprefixed lines are shared context.
//
// What the new ('+') version does:
//  - each simdgroup computes one output element: row = 2 * r0 + sgitg of src0
//    against the src1 floats starting at offset r1*ne10;
//  - a lane is identified by ix = tiisg/4 (starting block; the block loop strides
//    by 8) and il = 4 * (tiisg%4) (offset inside the block);
//    presumably a 32-lane simdgroup, so ix covers 0..7 - confirm;
//  - qs, hmask and scales are read through 16-bit pointers; fields in the low byte
//    feed sum[0], fields in the high byte feed sum[1], which is scaled by 1/256 at
//    the end;
//  - simd_sum adds the per-lane values and lane 0 writes dst[r1*ne0 + row].
// NOTE(review): nothing visible guards `row` against the number of rows;
// presumably the host dispatch guarantees it - confirm at the call site.
+ kernel void kernel_mul_mat_q3_K_f32 (
1457
+ device const void * src0,
1458
+ device const float * src1,
1459
+ device float * dst,
1460
+ constant int64_t & ne00,
1461
+ constant int64_t & ne10,
1462
+ constant int64_t & ne0,
1463
+ constant int64_t & ne1,
1464
+ uint2 tgpig[[threadgroup_position_in_grid]],
1465
+ uint tiisg[[thread_index_in_simdgroup]],
1466
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
1467
+
1468
+ const int nb = ne00/QK_K;
1469
+
1470
+ const int64_t r0 = tgpig.x ;
1471
+ const int64_t r1 = tgpig.y ;
1472
+
1473
// new: the row handled by this simdgroup
+ const int row = 2 * r0 + sgitg;
1474
+
1475
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
1476
+ device const float * yy = (device const float *) src1 + r1*ne10;
1477
+ const int ix = tiisg/4 ;
1478
+ const int il = 4 * (tiisg%4 );// 0, 4, 8, 12
1429
1479
const int im = il/8 ; // 0, 0, 1, 1
1430
1480
const int in = il%8 ; // 0, 4, 0, 4
1431
1481
1432
- float sumf = 0 ;
1482
// new: sum[0] collects low-byte terms, sum[1] high-byte terms (256x too large until rescaled)
+ float2 sum = { 0 . f , 0 . f } ;
1433
1483
1434
- for (int i = tpitg. y ; i < nb; i += tptg. y ) {
1484
+ for (int i = ix ; i < nb; i += 8 ) {
1435
1485
1436
1486
const float d_all = (float )(x[i].d );
1437
1487
1438
- device const uint8_t * q = x[i].qs + il;
1439
- device const uint8_t * h = x[i].hmask + in;
1440
- device const float * y = yy + i * QK_K + il;
1441
-
1442
- const float d1 = d_all * ((x[i].scales [0 ] & 0xF ) - 8 );
1443
- const float d2 = d_all * ((x[i].scales [0 ] >> 4 ) - 8 );
1444
- const float d3 = d_all * ((x[i].scales [1 ] & 0xF ) - 8 );
1445
- const float d4 = d_all * ((x[i].scales [1 ] >> 4 ) - 8 );
1446
-
1447
- for (int l = 0 ; l < 4 ; ++l) {
1448
- const uint8_t hm = h[l] >> im;
1449
- sumf += y[l+ 0 ] * d1 * ((int8_t )((q[l+0 ] >> 0 ) & 3 ) - ((hm & 0x01 ) ? 0 : 4 ))
1450
- + y[l+16 ] * d2 * ((int8_t )((q[l+0 ] >> 2 ) & 3 ) - ((hm & 0x04 ) ? 0 : 4 ))
1451
- + y[l+32 ] * d3 * ((int8_t )((q[l+0 ] >> 4 ) & 3 ) - ((hm & 0x10 ) ? 0 : 4 ))
1452
- + y[l+48 ] * d4 * ((int8_t )((q[l+0 ] >> 6 ) & 3 ) - ((hm & 0x40 ) ? 0 : 4 ));
1488
+ device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
1489
+ device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
1490
+ device const uint16_t * s = (device const uint16_t *)(x[i].scales );
1491
+ device const float * y = yy + i * QK_K + il;
1492
+
1493
// new: the four 4-bit scales are masked out of one 16-bit load without shifting, so the
// subtracted 8 becomes 8, 128, 2048, 32768. The trailing factors cancel both the unshifted
// scale and the unshifted 2-bit quant field it is paired with below:
// 1/64 = 1/(16*4), 1/4096 = 1/(256*16), 1/262144 = 1/(4096*64)
+ const float d1 = d_all * ((int32_t )(s[0 ] & 0x000F ) - 8 );
1494
+ const float d2 = d_all * ((int32_t )(s[0 ] & 0x00F0 ) - 128 ) * 1 .f /64 .f ;
1495
+ const float d3 = d_all * ((int32_t )(s[0 ] & 0x0F00 ) - 2048 ) * 1 .f /4096 .f ;
1496
+ const float d4 = d_all * ((int32_t )(s[0 ] & 0xF000 ) - 32768 ) * 1 .f /262144 .f ;
1497
+
1498
// new: each iteration consumes one 16-bit load of q and h, i.e. two bytes; the constant
// subtracted when the hmask bit is clear is 4 placed at the bit position of the masked field
+ for (int l = 0 ; l < 4 ; l += 2 ) {
1499
+ const uint16_t hm = h[l/2 ] >> im;
1500
+ sum[0 ] += y[l+ 0 ] * d1 * ((int32_t )(q[l/2 ] & 0x0003 ) - ((hm & 0x0001 ) ? 0 : 4 ))
1501
+ + y[l+16 ] * d2 * ((int32_t )(q[l/2 ] & 0x000c ) - ((hm & 0x0004 ) ? 0 : 16 ))
1502
+ + y[l+32 ] * d3 * ((int32_t )(q[l/2 ] & 0x0030 ) - ((hm & 0x0010 ) ? 0 : 64 ))
1503
+ + y[l+48 ] * d4 * ((int32_t )(q[l/2 ] & 0x00c0 ) - ((hm & 0x0040 ) ? 0 : 256 ));
1504
+ sum[1 ] += y[l+ 1 ] * d1 * ((int32_t )(q[l/2 ] & 0x0300 ) - ((hm & 0x0100 ) ? 0 : 1024 ))
1505
+ + y[l+17 ] * d2 * ((int32_t )(q[l/2 ] & 0x0c00 ) - ((hm & 0x0400 ) ? 0 : 4096 ))
1506
+ + y[l+33 ] * d3 * ((int32_t )(q[l/2 ] & 0x3000 ) - ((hm & 0x1000 ) ? 0 : 16384 ))
1507
+ + y[l+49 ] * d4 * ((int32_t )(q[l/2 ] & 0xc000 ) - ((hm & 0x4000 ) ? 0 : 65536 ));
1453
1508
}
1454
1509
1455
1510
}
1511
// new: bring the high-byte accumulator back to the scale of the low-byte one
+ const float sumf = sum[0 ] + sum[1 ] * 1 .f /256 .f ;
1456
1512
1457
- sum[ith] = sumf;
1458
-
1459
- #endif
1460
-
1461
- //
1462
- // Accumulate the sum from all threads in the threadgroup
1463
- //
1464
- threadgroup_barrier (mem_flags::mem_threadgroup);
1465
- if (ith%4 == 0 ) {
1466
- for (int i = 1 ; i < 4 ; ++i) sum[ith] += sum[ith + i];
1467
- }
1468
- threadgroup_barrier (mem_flags::mem_threadgroup);
1469
- if (ith%16 == 0 ) {
1470
- for (int i = 4 ; i < 16 ; i += 4 ) sum[ith] += sum[ith + i];
1471
- }
1472
- threadgroup_barrier (mem_flags::mem_threadgroup);
1473
- if (ith == 0 ) {
1474
- for (int i = 16 ; i < nth; i += 16 ) sum[0 ] += sum[i];
1475
- dst[r1*ne0 + r0] = sum[0 ];
1513
// new: simd_sum replaces the removed barrier-based reduction; lane 0 writes the result
+ const float tot = simd_sum (sumf);
1514
+ if (tiisg == 0 ) {
1515
+ dst[r1*ne0 + row] = tot;
1476
1516
}
1477
1517
1478
1518
}
1519
+ #endif
1479
1520
1480
1521
#if QK_K == 256
1481
1522
kernel void kernel_mul_mat_q4_K_f32 (
@@ -1773,7 +1814,6 @@ kernel void kernel_mul_mat_q5_K_f32(
1773
1814
1774
1815
for (int i = ix; i < nb; i += 8 ) {
1775
1816
1776
- float4 sumy = {0 .f , 0 .f , 0 .f , 0 .f };
1777
1817
for (int l = 0 ; l < 4 ; ++l) {
1778
1818
yl[l+0 ] = y[l+ 0 ];
1779
1819
yl[l+4 ] = y[l+16 ];
0 commit comments