@@ -36,14 +36,19 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs
36
36
typedef void (*to_fp32_cuda_t )(const void * x, float * y, int k, cudaStream_t stream);
37
37
typedef void (*dequantize_mul_mat_vec_cuda_t )(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
38
38
39
+ // QK = number of values after dequantization
40
+ // QR = QK / number of values before dequantization
41
+
39
42
#define QK4_0 32 // values per block after dequantization
#define QR4_0 2  // QK / values before dequantization (2 quants per byte)
// q4_0: 4-bit quantization — one fp32 scale per block of 32 weights,
// quants stored as packed nibbles (two per byte).
typedef struct {
    float   d;              // delta (per-block scale)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
45
49
46
50
#define QK4_1 32
51
+ #define QR4_1 2
47
52
typedef struct {
48
53
float d; // delta
49
54
float m; // min
@@ -52,6 +57,7 @@ typedef struct {
52
57
static_assert (sizeof (block_q4_1) == sizeof(float ) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
53
58
54
59
#define QK5_0 32
60
+ #define QR5_0 2
55
61
typedef struct {
56
62
half d; // delta
57
63
uint8_t qh[4 ]; // 5-th bit of quants
@@ -60,6 +66,7 @@ typedef struct {
60
66
static_assert (sizeof (block_q5_0) == sizeof(ggml_fp16_t ) + sizeof(uint32_t ) + QK5_0 / 2, "wrong q5_0 block size/padding");
61
67
62
68
#define QK5_1 32
69
+ #define QR5_1 2
63
70
typedef struct {
64
71
half d; // delta
65
72
half m; // min
@@ -69,6 +76,7 @@ typedef struct {
69
76
static_assert (sizeof (block_q5_1) == 2 * sizeof(ggml_fp16_t ) + sizeof(uint32_t ) + QK5_1 / 2, "wrong q5_1 block size/padding");
70
77
71
78
#define QK8_0 32
79
+ #define QR8_0 1
72
80
typedef struct {
73
81
float d; // delta
74
82
int8_t qs[QK8_0]; // quants
@@ -124,6 +132,44 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
124
132
v1 = x1*d;
125
133
}
126
134
135
// Dequantize two q5_1 values from block ib at quant index iqs.
// Each quant is 5 bits: the low nibble lives in qs (two quants share a byte)
// and the 5th bit is packed into the 32-bit qh word. Result: v = q*d + m.
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q5_1 * blk = (const block_q5_1 *) vx;

    const float d = blk[ib].d; // scale
    const float m = blk[ib].m; // minimum

    // assemble the 32 packed high bits into a single word
    uint32_t qh;
    memcpy(&qh, blk[ib].qh, sizeof(qh));

    // recover bit 4 for the low-nibble and high-nibble quants
    const uint8_t hb0 = ((qh >> (iqs +  0)) << 4) & 0x10;
    const uint8_t hb1 =  (qh >> (iqs + 12))       & 0x10;

    const int32_t q0 = (blk[ib].qs[iqs] & 0xf) | hb0;
    const int32_t q1 = (blk[ib].qs[iqs] >>  4) | hb1;

    v0 = q0*d + m;
    v1 = q1*d + m;
}
153
+
154
// Dequantize two adjacent q8_0 values: each int8 quant is scaled by the
// block delta (no offset). Note the pair is (iqs, iqs+1), i.e. consecutive —
// hence qr == 1 for this format.
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q8_0 * blk = (const block_q8_0 *) vx;

    const float d = blk[ib].d; // per-block scale

    v0 = d * blk[ib].qs[iqs + 0];
    v1 = d * blk[ib].qs[iqs + 1];
}
165
+
166
// Not a real dequantization: widens two consecutive fp16 values to fp32.
// ib is used directly as the element index into the half array; iqs is
// ignored (the caller must instantiate the kernel so that ib is already
// the flat element index).
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const half * x = (const half *) vx;

    const half lo = x[ib + 0];
    const half hi = x[ib + 1];

    v0 = __half2float(lo);
    v1 = __half2float(hi);
}
172
+
127
173
static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
128
174
static const int qk = QK4_0;
129
175
@@ -224,18 +270,20 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
224
270
}
225
271
}
226
272
227
- template <int block_size, int qk, dequantize_kernel_t dequantize_kernel>
273
+ template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
228
274
static __global__ void dequantize_mul_mat_vec (const void * vx, const float * y, float * dst, const int ncols) {
229
275
const int row = blockIdx .x ;
230
276
const int tid = threadIdx .x ;
231
277
278
+ const int y_offset = qr == 1 ? 1 : qk/2 ;
279
+
232
280
__shared__ float tmp[block_size]; // separate sum for each thread
233
281
tmp[tid] = 0 ;
234
282
235
283
for (int i = 0 ; i < ncols/block_size; i += 2 ) {
236
284
const int col = i*block_size + 2 *tid;
237
285
const int ib = (row*ncols + col)/qk; // block index
238
- const int iqs = (col%qk)/2 ; // quant index
286
+ const int iqs = (col%qk)/qr ; // quant index
239
287
const int iybs = col - col%qk; // y block start index
240
288
241
289
// dequantize
@@ -244,7 +292,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
244
292
245
293
// matrix multiplication
246
294
tmp[tid] += v0 * y[iybs + iqs + 0 ];
247
- tmp[tid] += v1 * y[iybs + iqs + qk/ 2 ];
295
+ tmp[tid] += v1 * y[iybs + iqs + y_offset ];
248
296
}
249
297
250
298
// sum up partial sums and write back result
@@ -287,17 +335,32 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
287
335
288
336
// Launch the dequantize+mat-vec kernel for q4_0: one thread block per output
// row, CUDA_DMMV_BLOCK_SIZE threads accumulating partial dot products.
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 grid_dims(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
        <<<grid_dims, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
292
341
293
342
// Launch the dequantize+mat-vec kernel for q4_1 (scale + min format):
// one thread block per output row.
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 grid_dims(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
        <<<grid_dims, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
297
347
298
348
// Launch the dequantize+mat-vec kernel for q5_0 (5-bit quants, high bits in
// qh): one thread block per output row.
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 grid_dims(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
        <<<grid_dims, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
353
+
354
// Launch the dequantize+mat-vec kernel for q5_1 (5-bit quants with scale and
// min): one thread block per output row.
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 grid_dims(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
        <<<grid_dims, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
359
+
360
// Launch the dequantize+mat-vec kernel for q8_0 (8-bit quants, qr == 1):
// one thread block per output row.
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 grid_dims(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
        <<<grid_dims, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
302
365
303
366
// TODO: optimize
@@ -313,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
313
376
convert_fp16_to_fp32<<<k, 1 , 0 , stream>>> (x, y);
314
377
}
315
378
379
// Launch dequantize_mul_mat_vec with the fp16 "convert" kernel: one thread
// block per output row.
//
// Bug fix: this was instantiated with <..., 32, 1, convert_f16>, but
// convert_f16 reads x[ib + 0] / x[ib + 1], i.e. it treats ib as the flat
// element index. With qk == 32 the kernel computes ib = (row*ncols + col)/32,
// so every thread covering the same 32-column span loaded the SAME two fp16
// values while multiplying different y elements — wrong results. With
// qk == 1 (and qr == 1) the kernel yields ib = row*ncols + col and iqs = 0,
// which makes the x and y indexing consistent (y_offset is 1 either way
// since qr == 1).
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 1, 1, convert_f16>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
384
+
316
385
static to_fp32_cuda_t ggml_get_to_fp32_cuda (ggml_type type) {
317
386
switch (type) {
318
387
case GGML_TYPE_Q4_0:
@@ -340,6 +409,12 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
340
409
return dequantize_mul_mat_vec_q4_1_cuda;
341
410
case GGML_TYPE_Q5_0:
342
411
return dequantize_mul_mat_vec_q5_0_cuda;
412
+ case GGML_TYPE_Q5_1:
413
+ return dequantize_mul_mat_vec_q5_1_cuda;
414
+ case GGML_TYPE_Q8_0:
415
+ return dequantize_mul_mat_vec_q8_0_cuda;
416
+ case GGML_TYPE_F16:
417
+ return dequantize_mul_mat_vec_q8_0_cuda;
343
418
default :
344
419
return nullptr ;
345
420
}
0 commit comments