@@ -127,8 +127,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
127
127
#define QR8_1 1
128
128
#define QI8_1 (QK8_1 / (4 * QR8_1))
129
129
typedef struct { // q8_1 block: packed scale/sum pair followed by QK8_0 signed 8-bit quants
130
- half d; // delta
131
- half s; // unquantized sum
130
+ half2 ds; // ds.x = delta (scale), ds.y = unquantized sum
132
131
int8_t qs[QK8_0]; // quantized values, QK8_0 entries per block
133
132
} block_q8_1;
134
133
static_assert (sizeof (block_q8_1) == 2*sizeof(ggml_fp16_t ) + QK8_0, "wrong q8_1 block size/padding");
@@ -1258,8 +1257,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
1258
1257
return ;
1259
1258
}
1260
1259
1261
- y[ib].d = d;
1262
- y[ib].s = sum;
1260
+ y[ib].ds . x = d;
1261
+ y[ib].ds . y = sum;
1263
1262
}
1264
1263
1265
1264
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -1284,18 +1283,18 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
1284
1283
}
1285
1284
1286
1285
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl (
1287
- const int & vi, const int & ui0, const int & ui1, const float & d4, const float & d8 ) {
1286
+ const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8 ) {
1288
1287
1289
1288
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
1290
1289
// subtract 8 from each quantized value
1291
- const int vi0 = __vsub4 (( vi >> 0 ) & 0x0F0F0F0F , 0x08080808 ) ;
1292
- const int vi1 = __vsub4 (( vi >> 4 ) & 0x0F0F0F0F , 0x08080808 ) ;
1290
+ const int vi0 = ( vi >> 0 ) & 0x0F0F0F0F ;
1291
+ const int vi1 = ( vi >> 4 ) & 0x0F0F0F0F ;
1293
1292
1294
1293
// SIMD dot product of quantized values
1295
1294
int sumi = __dp4a (vi0, ui0, 0 );
1296
1295
sumi = __dp4a (vi1, ui1, sumi);
1297
1296
1298
- return sumi*d4*d8 ;
1297
+ return __half2float (d4) * ( sumi * __half2float (ds8. x ) - ( 8 /QI4_0) * __half2float (ds8. y )) ;
1299
1298
#else
1300
1299
return 0 .0f ; // only to satisfy the compiler
1301
1300
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1311,7 +1310,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
1311
1310
const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
1312
1311
const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI4_0)]);
1313
1312
1314
- return vec_dot_q4_0_q8_1_impl (vi, ui0, ui1, __half2float ( bq4_0->d ), __half2float ( bq8_1->d ) );
1313
+ return vec_dot_q4_0_q8_1_impl (vi, ui0, ui1, bq4_0->d , bq8_1->ds );
1315
1314
}
1316
1315
1317
1316
static __device__ __forceinline__ float vec_dot_q4_1_q8_1 (
@@ -1324,9 +1323,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
1324
1323
const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
1325
1324
const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI4_1)]);
1326
1325
1327
- const float d = __half2float (bq4_1->d ) * __half2float (bq8_1->d );
1326
+ const float d = __half2float (bq4_1->d ) * __half2float (bq8_1->ds . x );
1328
1327
const float m = bq4_1->m ;
1329
- const float s = bq8_1->s ;
1328
+ const float s = bq8_1->ds . y ;
1330
1329
1331
1330
const int vi0 = (vi >> 0 ) & 0x0F0F0F0F ;
1332
1331
const int vi1 = (vi >> 4 ) & 0x0F0F0F0F ;
@@ -1354,7 +1353,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
1354
1353
const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
1355
1354
const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI5_0)]);
1356
1355
1357
- const float d = __half2float (bq5_0->d ) * __half2float (bq8_1->d );
1356
+ const float d = __half2float (bq5_0->d ) * __half2float (bq8_1->ds . x );
1358
1357
1359
1358
int vi0 = (qs >> 0 ) & 0x0F0F0F0F ; // lower 4 qs bits, still need qh0 as 5th bits
1360
1359
vi0 |= (qh0 << 4 ) & 0x00000010 ; // 1 -> 5
@@ -1390,9 +1389,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
1390
1389
const int ui0 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
1391
1390
const int ui1 = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + QI5_1)]);
1392
1391
1393
- const float d = __half2float (bq5_1->d ) * __half2float (bq8_1->d );
1392
+ const float d = __half2float (bq5_1->d ) * __half2float (bq8_1->ds . x );
1394
1393
const float m = bq5_1->m ;
1395
- const float s = bq8_1->s ;
1394
+ const float s = bq8_1->ds . y ;
1396
1395
1397
1396
int vi0 = (qs >> 0 ) & 0x0F0F0F0F ; // lower 4 qs bits, still need qh0 as 5th bits
1398
1397
vi0 |= (qh0 << 4 ) & 0x00000010 ; // 1 -> 5
@@ -1424,7 +1423,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
1424
1423
memcpy (&vi, &bq8_0->qs [sizeof (int ) * (iqs + 0 )], sizeof (int ));
1425
1424
const int ui = *((int *) &bq8_1->qs [sizeof (int ) * (iqs + 0 )]);
1426
1425
1427
- const float d = __half2float (bq8_0->d ) * __half2float (bq8_1->d );
1426
+ const float d = __half2float (bq8_0->d ) * __half2float (bq8_1->ds . x );
1428
1427
1429
1428
// SIMD dot product of quantized values
1430
1429
int sumi = __dp4a (vi, ui, 0 );
@@ -1456,7 +1455,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
1456
1455
const int sc = bq2_K->scales [scale_offset + 2 *i];
1457
1456
1458
1457
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1459
- const float d8i = bq8i->d ;
1458
+ const float d8i = bq8i->ds . x ;
1460
1459
1461
1460
const int vi = (v >> (2 *i)) & 0x03030303 ;
1462
1461
const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
@@ -1507,7 +1506,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
1507
1506
1508
1507
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1509
1508
const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
1510
- const float d8i = bq8i->d ;
1509
+ const float d8i = bq8i->ds . x ;
1511
1510
1512
1511
const int vil = (vl >> (2 *i)) & 0x03030303 ;
1513
1512
@@ -1548,7 +1547,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
1548
1547
1549
1548
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1550
1549
const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
1551
- const float d8i = bq8i->d ;
1550
+ const float d8i = bq8i->ds . x ;
1552
1551
1553
1552
const int vi = (v >> (4 *i)) & 0x0F0F0F0F ;
1554
1553
@@ -1588,7 +1587,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
1588
1587
1589
1588
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1590
1589
const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % QI8_1)]);
1591
- const float d8i = bq8i->d ;
1590
+ const float d8i = bq8i->ds . x ;
1592
1591
1593
1592
const int vil = (vl >> (4 *i)) & 0x0F0F0F0F ;
1594
1593
@@ -1631,7 +1630,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
1631
1630
1632
1631
const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2 *i;
1633
1632
const int ui = *((int *) &bq8i->qs [sizeof (int ) * (iqs % (QI8_1))]);
1634
- const float d8i = bq8i->d ;
1633
+ const float d8i = bq8i->ds . x ;
1635
1634
1636
1635
const int vil = (vl >> (4 *i)) & 0x0F0F0F0F ;
1637
1636
@@ -1673,7 +1672,7 @@ static __global__ void mul_mat_q(
1673
1672
__shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1 ];
1674
1673
__shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
1675
1674
__shared__ int tile_y_qs[WARP_SIZE][2 *WARP_SIZE];
1676
- __shared__ half tile_y_d [WARP_SIZE][2 *WARP_SIZE/QI8_1];
1675
+ __shared__ half2 tile_y_ds [WARP_SIZE][2 *WARP_SIZE/QI8_1];
1677
1676
float sum[4 ] = {0 .0f };
1678
1677
1679
1678
for (int ib0 = 0 ; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
@@ -1694,12 +1693,12 @@ static __global__ void mul_mat_q(
1694
1693
const block_q8_1 * __restrict__ by0 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby0];
1695
1694
1696
1695
tile_y_qs[tid_y + i][tid_x] = *((int *) &by0->qs [iqsy]);
1697
- tile_y_d [tid_y + i][iby0] = by0->d ;
1696
+ tile_y_ds [tid_y + i][iby0] = by0->ds ;
1698
1697
1699
1698
const block_q8_1 * __restrict__ by1 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby1];
1700
1699
1701
1700
tile_y_qs[tid_y + i][tid_x + WARP_SIZE] = *((int *) &by1->qs [iqsy]);
1702
- tile_y_d [tid_y + i][iby1] = by1->d ;
1701
+ tile_y_ds [tid_y + i][iby1] = by1->ds ;
1703
1702
}
1704
1703
1705
1704
__syncthreads ();
@@ -1709,7 +1708,7 @@ static __global__ void mul_mat_q(
1709
1708
for (int j = 0 ; j < WARP_SIZE; j += 8 ) {
1710
1709
sum[j/8 ] += vec_dot_q4_0_q8_1_impl (
1711
1710
tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0 ], tile_y_qs[tid_y + j][iqsy + (QI8_1/2 )],
1712
- tile_x_d[tid_x][k / QI4_0], tile_y_d [tid_y + j][2 * k / QI8_1]);
1711
+ tile_x_d[tid_x][k / QI4_0], tile_y_ds [tid_y + j][2 * k / QI8_1]);
1713
1712
}
1714
1713
}
1715
1714
0 commit comments