Skip to content

Commit bb0993e

Browse files
dequantize_mul_mat_vec kernels for q5_1, q8_0, f16
1 parent 5a0ecf7 commit bb0993e

File tree

1 file changed

+81
-6
lines changed

1 file changed

+81
-6
lines changed

ggml-cuda.cu

+81-6
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,19 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs
3636
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
3737
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
3838

39+
// QK = number of values after dequantization
40+
// QR = QK / number of values before dequantization
41+
3942
#define QK4_0 32
43+
#define QR4_0 2
4044
typedef struct {
4145
float d; // delta
4246
uint8_t qs[QK4_0 / 2]; // nibbles / quants
4347
} block_q4_0;
4448
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
4549

4650
#define QK4_1 32
51+
#define QR4_1 2
4752
typedef struct {
4853
float d; // delta
4954
float m; // min
@@ -52,6 +57,7 @@ typedef struct {
5257
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
5358

5459
#define QK5_0 32
60+
#define QR5_0 2
5561
typedef struct {
5662
half d; // delta
5763
uint8_t qh[4]; // 5-th bit of quants
@@ -60,6 +66,7 @@ typedef struct {
6066
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
6167

6268
#define QK5_1 32
69+
#define QR5_1 2
6370
typedef struct {
6471
half d; // delta
6572
half m; // min
@@ -69,6 +76,7 @@ typedef struct {
6976
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
7077

7178
#define QK8_0 32
79+
#define QR8_0 1
7280
typedef struct {
7381
float d; // delta
7482
int8_t qs[QK8_0]; // quants
@@ -124,6 +132,44 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
124132
v1 = x1*d;
125133
}
126134

135+
// Dequantize two q5_1 values from block ib at quant index iqs into v0/v1.
// q5_1 layout: per-block delta d and min m (half), 32 high bits packed in qh,
// low nibbles packed two-per-byte in qs. value = 5-bit quant * d + m.
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const float d = x[ib].d;
    const float m = x[ib].m;

    // qh is stored as a byte array; memcpy avoids unaligned/aliasing issues.
    uint32_t qh;
    memcpy(&qh, x[ib].qh, sizeof(qh));

    // 5th bit for the low-nibble value (bit iqs) and the high-nibble value
    // (bit iqs+16, extracted via >> (iqs+12) so it lands at bit position 4).
    const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10;

    // Combine the 4 low bits from qs with the extracted high bit.
    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1);

    v0 = x0*d + m;
    v1 = x1*d + m;
}
153+
154+
// Dequantize two consecutive q8_0 values from block ib starting at iqs.
// q8_0 stores raw int8 quants scaled by a single per-block delta d.
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q8_0 * blocks = (const block_q8_0 *) vx;
    const block_q8_0 & b = blocks[ib];

    const float d = b.d;

    v0 = b.qs[iqs + 0] * d;
    v1 = b.qs[iqs + 1] * d;
}
165+
166+
// "Dequantize" for f16 data: no quantization involved, just widen two
// consecutive half values to float. Matches dequantize_kernel_t so f16 can
// share the dequantize_mul_mat_vec template (qr == 1, iqs unused here).
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const half * h = (const half *) vx;

    const float lo = __half2float(h[ib]);
    const float hi = __half2float(h[ib + 1]);

    v0 = lo;
    v1 = hi;
}
172+
127173
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
128174
static const int qk = QK4_0;
129175

@@ -224,18 +270,20 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
224270
}
225271
}
226272

227-
template <int block_size, int qk, dequantize_kernel_t dequantize_kernel>
273+
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
228274
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
229275
const int row = blockIdx.x;
230276
const int tid = threadIdx.x;
231277

278+
const int y_offset = qr == 1 ? 1 : qk/2;
279+
232280
__shared__ float tmp[block_size]; // separate sum for each thread
233281
tmp[tid] = 0;
234282

235283
for (int i = 0; i < ncols/block_size; i += 2) {
236284
const int col = i*block_size + 2*tid;
237285
const int ib = (row*ncols + col)/qk; // block index
238-
const int iqs = (col%qk)/2; // quant index
286+
const int iqs = (col%qk)/qr; // quant index
239287
const int iybs = col - col%qk; // y block start index
240288

241289
// dequantize
@@ -244,7 +292,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
244292

245293
// matrix multiplication
246294
tmp[tid] += v0 * y[iybs + iqs + 0];
247-
tmp[tid] += v1 * y[iybs + iqs + qk/2];
295+
tmp[tid] += v1 * y[iybs + iqs + y_offset];
248296
}
249297

250298
// sum up partial sums and write back result
@@ -287,17 +335,32 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
287335

288336
// Launch the fused dequantize + mat-vec kernel for q4_0 data:
// one thread block per matrix row, CUDA_DMMV_BLOCK_SIZE threads per block.
// ncols must be a multiple of CUDA_DMMV_BLOCK_SIZE (asserted).
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
292341

293342
// Launch the fused dequantize + mat-vec kernel for q4_1 data:
// one thread block per matrix row, CUDA_DMMV_BLOCK_SIZE threads per block.
// ncols must be a multiple of CUDA_DMMV_BLOCK_SIZE (asserted).
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
297347

298348
// Launch the fused dequantize + mat-vec kernel for q5_0 data:
// one thread block per matrix row, CUDA_DMMV_BLOCK_SIZE threads per block.
// ncols must be a multiple of CUDA_DMMV_BLOCK_SIZE (asserted).
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
353+
354+
// Launch the fused dequantize + mat-vec kernel for q5_1 data:
// one thread block per matrix row, CUDA_DMMV_BLOCK_SIZE threads per block.
// ncols must be a multiple of CUDA_DMMV_BLOCK_SIZE (asserted).
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
359+
360+
// Launch the fused dequantize + mat-vec kernel for q8_0 data:
// one thread block per matrix row, CUDA_DMMV_BLOCK_SIZE threads per block.
// ncols must be a multiple of CUDA_DMMV_BLOCK_SIZE (asserted).
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
302365

303366
// TODO: optimize
@@ -313,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
313376
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
314377
}
315378

379+
// Launch the mat-vec kernel for plain f16 data (no dequantization):
// one thread block per matrix row, CUDA_DMMV_BLOCK_SIZE threads per block.
// Template args 32 and 1 play the roles of qk and qr; with qr == 1 the
// kernel's y_offset becomes 1 — presumably 32 is an arbitrary chunking
// constant for f16 since there is no real block size. NOTE(review): consider
// naming these constants; confirm intent against the kernel template.
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
384+
316385
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
317386
switch (type) {
318387
case GGML_TYPE_Q4_0:
@@ -340,6 +409,12 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
340409
return dequantize_mul_mat_vec_q4_1_cuda;
341410
case GGML_TYPE_Q5_0:
342411
return dequantize_mul_mat_vec_q5_0_cuda;
412+
case GGML_TYPE_Q5_1:
    return dequantize_mul_mat_vec_q5_1_cuda;
case GGML_TYPE_Q8_0:
    return dequantize_mul_mat_vec_q8_0_cuda;
case GGML_TYPE_F16:
    // BUG FIX: F16 previously returned the q8_0 wrapper, which would
    // reinterpret half-precision data as block_q8_0 and produce garbage;
    // it must dispatch to the f16 conversion wrapper added in this commit
    // (which was otherwise dead code).
    return convert_mul_mat_vec_f16_cuda;
343418
default:
344419
return nullptr;
345420
}

0 commit comments

Comments
 (0)