@@ -32,7 +32,9 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
    } \
} while (0)

+ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);

#define QK4_0 32
typedef struct {
@@ -73,6 +75,37 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");

+ #define CUDA_DMMV_BLOCK_SIZE 32
+
+ static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+     const block_q4_0 * x = (const block_q4_0 *) vx;
+
+     const float d = x[ib].d;
+
+     const uint8_t vui = x[ib].qs[iqs];
+
+     const int8_t vi0 = vui & 0xF;
+     const int8_t vi1 = vui >> 4;
+
+     v0 = (vi0 - 8)*d;
+     v1 = (vi1 - 8)*d;
+ }
+
+ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+     const block_q4_1 * x = (const block_q4_1 *) vx;
+
+     const float d = x[ib].d;
+     const float m = x[ib].m;
+
+     const uint8_t vui = x[ib].qs[iqs];
+
+     const int8_t vi0 = vui & 0xF;
+     const int8_t vi1 = vui >> 4;
+
+     v0 = vi0*d + m;
+     v1 = vi1*d + m;
+ }
+
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
    static const int qk = QK4_0;

@@ -173,10 +206,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
    }
}

- template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
-     const block_q4_0 * x = (const block_q4_0 *) vx;
-     const int qk = QK4_0;
-
+ template <int block_size, int qk, dequantize_kernel_t dequantize_kernel> static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
    const int row = blockIdx.x;
    const int tid = threadIdx.x;

@@ -190,17 +220,8 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
        const int iybs = col - col%qk; // y block start index

        // dequantize
-         const float d = x[ib].d;
-
-         const uint8_t * pp = x[ib].qs;
-
-         const uint8_t vui = pp[iqs];
-
-         const int8_t vi0 = vui & 0xF;
-         const int8_t vi1 = vui >> 4;
-
-         const float v0 = (vi0 - 8)*d;
-         const float v1 = (vi1 - 8)*d;
+         float v0, v1;
+         dequantize_kernel(vx, ib, iqs, v0, v1);

        // matrix multiplication
        tmp[tid] += v0 * y[iybs + iqs + 0];
@@ -244,21 +265,14 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
    dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
}

- static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-     // static int block_size = -1;
-     // if (block_size == -1) {
-     //     int min_grid_size, max_block_size = 1;
-     //     CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
-     //     max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
-     //     block_size = 1;
-     //     while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
-     //         block_size *= 2;
-     //     }
-     // }
-     // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
-     const int block_size = 32;
-     GGML_ASSERT(ncols % block_size == 0);
-     dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
+ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, dequantize_q4_0><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
+ }
+
+ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+     GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
+     dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, dequantize_q4_1><<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}

// TODO: optimize
@@ -293,6 +307,17 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    }
}

+ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
+     switch (type) {
+         case GGML_TYPE_Q4_0:
+             return dequantize_mul_mat_vec_q4_0_cuda;
+         case GGML_TYPE_Q4_1:
+             return dequantize_mul_mat_vec_q4_1_cuda;
+         default:
+             return nullptr;
+     }
+ }
+
// buffer pool for cuda
#define MAX_CUDA_BUFFERS 256

@@ -610,6 +635,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
    char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);

    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
+     dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
    GGML_ASSERT(to_fp32_cuda != nullptr);

    for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -641,7 +667,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));

            // compute
-             dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+             dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
            CUDA_CHECK(cudaGetLastError());

        } else {
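
Aside (not part of the diff): the new dequantize_mul_mat_vec kernel relies on passing a __device__ function (dequantize_q4_0 or dequantize_q4_1) as a non-type template parameter, so each instantiation gets its per-type dequantization resolved at compile time instead of going through a runtime function pointer on the device. The standalone sketch below illustrates only that dispatch pattern; the names in it (apply, times_two, plus_one) are hypothetical and exist purely for this example.

// sketch.cu -- hypothetical, minimal illustration of the template-dispatch pattern above
#include <cstdio>
#include <cuda_runtime.h>

typedef void (*transform_t)(float & v); // plays the role of dequantize_kernel_t

static __device__ void times_two(float & v) { v *= 2.0f; }
static __device__ void plus_one (float & v) { v += 1.0f; }

// The __device__ function is a compile-time template argument, so each
// instantiation of the kernel calls it directly (and can inline it).
template <transform_t transform> static __global__ void apply(float * data, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        transform(data[i]);
    }
}

int main() {
    const int n = 32;
    float host[n];
    for (int i = 0; i < n; ++i) {
        host[i] = (float) i;
    }

    float * dev;
    cudaMalloc((void **) &dev, n*sizeof(float));
    cudaMemcpy(dev, host, n*sizeof(float), cudaMemcpyHostToDevice);

    // one specialization per transform, analogous to the q4_0 / q4_1 wrappers
    apply<times_two><<<1, n>>>(dev, n);
    apply<plus_one><<<1, n>>>(dev, n);

    cudaMemcpy(host, dev, n*sizeof(float), cudaMemcpyDeviceToHost);
    printf("host[3] = %.1f\n", host[3]); // 3*2 + 1 = 7.0

    cudaFree(dev);
    return 0;
}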