Skip to content

Commit 99291ce

Browse files
q8_1 half2 ds
1 parent 8435636 commit 99291ce

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

ggml-cuda.cu

Lines changed: 23 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -127,8 +127,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
127127
#define QR8_1 1
128128
#define QI8_1 (QK8_1 / (4 * QR8_1))
129129
typedef struct {
130-
half d; // delta
131-
half s; // unquantized sum
130+
half2 ds; // ds.x = delta, ds.y = sum
132131
int8_t qs[QK8_0]; // quants
133132
} block_q8_1;
134133
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
@@ -1258,8 +1257,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
12581257
return;
12591258
}
12601259

1261-
y[ib].d = d;
1262-
y[ib].s = sum;
1260+
y[ib].ds.x = d;
1261+
y[ib].ds.y = sum;
12631262
}
12641263

12651264
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -1284,18 +1283,18 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
12841283
}
12851284

12861285
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
1287-
const int & vi, const int & ui0, const int & ui1, const float & d4, const float & d8) {
1286+
const int & vi, const int & ui0, const int & ui1, const half & d4, const half2 & ds8) {
12881287

12891288
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
12901289
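// unpack the 4-bit quantized values without subtracting 8 here; the -8 offset is applied in the return expression via the q8_1 sum (ds8.y)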
1291-
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
1292-
const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
1290+
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
1291+
const int vi1 = (vi >> 4) & 0x0F0F0F0F;
12931292

12941293
// SIMD dot product of quantized values
12951294
int sumi = __dp4a(vi0, ui0, 0);
12961295
sumi = __dp4a(vi1, ui1, sumi);
12971296

1298-
return sumi*d4*d8;
1297+
return __half2float(d4) * (sumi * __half2float(ds8.x) - (8/QI4_0) * __half2float(ds8.y));
12991298
#else
13001299
return 0.0f; // only to satisfy the compiler
13011300
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1311,7 +1310,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
13111310
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
13121311
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
13131312

1314-
return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, __half2float(bq4_0->d), __half2float(bq8_1->d));
1313+
return vec_dot_q4_0_q8_1_impl(vi, ui0, ui1, bq4_0->d, bq8_1->ds);
13151314
}
13161315

13171316
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
@@ -1324,9 +1323,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
13241323
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
13251324
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
13261325

1327-
const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
1326+
const float d = __half2float(bq4_1->d) * __half2float(bq8_1->ds.x);
13281327
const float m = bq4_1->m;
1329-
const float s = bq8_1->s;
1328+
const float s = bq8_1->ds.y;
13301329

13311330
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
13321331
const int vi1 = (vi >> 4) & 0x0F0F0F0F;
@@ -1354,7 +1353,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
13541353
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
13551354
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
13561355

1357-
const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
1356+
const float d = __half2float(bq5_0->d) * __half2float(bq8_1->ds.x);
13581357

13591358
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
13601359
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
@@ -1390,9 +1389,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
13901389
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
13911390
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
13921391

1393-
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
1392+
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->ds.x);
13941393
const float m = bq5_1->m;
1395-
const float s = bq8_1->s;
1394+
const float s = bq8_1->ds.y;
13961395

13971396
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
13981397
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
@@ -1424,7 +1423,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
14241423
memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
14251424
const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
14261425

1427-
const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
1426+
const float d = __half2float(bq8_0->d) * __half2float(bq8_1->ds.x);
14281427

14291428
// SIMD dot product of quantized values
14301429
int sumi = __dp4a(vi, ui, 0);
@@ -1456,7 +1455,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
14561455
const int sc = bq2_K->scales[scale_offset + 2*i];
14571456

14581457
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
1459-
const float d8i = bq8i->d;
1458+
const float d8i = bq8i->ds.x;
14601459

14611460
const int vi = (v >> (2*i)) & 0x03030303;
14621461
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
@@ -1507,7 +1506,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
15071506

15081507
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
15091508
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
1510-
const float d8i = bq8i->d;
1509+
const float d8i = bq8i->ds.x;
15111510

15121511
const int vil = (vl >> (2*i)) & 0x03030303;
15131512

@@ -1548,7 +1547,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
15481547

15491548
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
15501549
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
1551-
const float d8i = bq8i->d;
1550+
const float d8i = bq8i->ds.x;
15521551

15531552
const int vi = (v >> (4*i)) & 0x0F0F0F0F;
15541553

@@ -1588,7 +1587,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
15881587

15891588
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
15901589
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
1591-
const float d8i = bq8i->d;
1590+
const float d8i = bq8i->ds.x;
15921591

15931592
const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
15941593

@@ -1631,7 +1630,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
16311630

16321631
const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
16331632
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
1634-
const float d8i = bq8i->d;
1633+
const float d8i = bq8i->ds.x;
16351634

16361635
const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
16371636

@@ -1673,7 +1672,7 @@ static __global__ void mul_mat_q(
16731672
__shared__ int tile_x_qs[WARP_SIZE][WARP_SIZE + 1];
16741673
__shared__ half tile_x_d[WARP_SIZE][WARP_SIZE/QI4_0];
16751674
__shared__ int tile_y_qs[WARP_SIZE][2*WARP_SIZE];
1676-
__shared__ half tile_y_d[WARP_SIZE][2*WARP_SIZE/QI8_1];
1675+
__shared__ half2 tile_y_ds[WARP_SIZE][2*WARP_SIZE/QI8_1];
16771676
float sum[4] = {0.0f};
16781677

16791678
for (int ib0 = 0; ib0 < blocks_per_row; ib0 += blocks_per_warp) {
@@ -1694,12 +1693,12 @@ static __global__ void mul_mat_q(
16941693
const block_q8_1 * __restrict__ by0 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby0];
16951694

16961695
tile_y_qs[tid_y + i][tid_x] = *((int *) &by0->qs[iqsy]);
1697-
tile_y_d[tid_y + i][iby0] = by0->d;
1696+
tile_y_ds[tid_y + i][iby0] = by0->ds;
16981697

16991698
const block_q8_1 * __restrict__ by1 = &y[(col_y_0 + tid_y + i)*blocks_per_row + ib0 + iby1];
17001699

17011700
tile_y_qs[tid_y + i][tid_x + WARP_SIZE] = *((int *) &by1->qs[iqsy]);
1702-
tile_y_d[tid_y + i][iby1] = by1->d;
1701+
tile_y_ds[tid_y + i][iby1] = by1->ds;
17031702
}
17041703

17051704
__syncthreads();
@@ -1709,7 +1708,7 @@ static __global__ void mul_mat_q(
17091708
for (int j = 0; j < WARP_SIZE; j += 8) {
17101709
sum[j/8] += vec_dot_q4_0_q8_1_impl(
17111710
tile_x_qs[tid_x][k], tile_y_qs[tid_y + j][iqsy + 0], tile_y_qs[tid_y + j][iqsy + (QI8_1/2)],
1712-
tile_x_d[tid_x][k / QI4_0], tile_y_d[tid_y + j][2 * k / QI8_1]);
1711+
tile_x_d[tid_x][k / QI4_0], tile_y_ds[tid_y + j][2 * k / QI8_1]);
17131712
}
17141713
}
17151714

0 commit comments

Comments
 (0)