@@ -83,10 +83,16 @@ typedef struct {
 } block_q8_0;
 static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
 
+#define WARP_SIZE 32
+
 #define CUDA_MUL_BLOCK_SIZE 256
+
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
 // dmmv = dequantize_mul_mat_vec
-#define GGML_CUDA_DMMV_BLOCK_X 32
+#ifndef GGML_CUDA_DMMV_BLOCK_X
+#define GGML_CUDA_DMMV_BLOCK_X 32 // can be set by compiler option LLAMA_CUDA_BY
+#endif
 #ifndef GGML_CUDA_DMMV_BLOCK_Y
 #define GGML_CUDA_DMMV_BLOCK_Y 1 // can be set by compiler option LLAMA_CUDA_BY
 #endif
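Note on this hunk: WARP_SIZE pins the thread-block x dimension to one warp, and GGML_CUDA_DMMV_BLOCK_X becomes a compile-time override point (the new #ifndef guard lets a -DGGML_CUDA_DMMV_BLOCK_X=... build flag take effect). A minimal sketch of the launch geometry the later hunks build from these macros (illustrative, not code from the patch):

    // One warp per matrix row in x; GGML_CUDA_DMMV_BLOCK_Y rows packed per block in y.
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
    const int  num_blocks = nrows / GGML_CUDA_DMMV_BLOCK_Y; // grid size, hence the nrows assert below
    // dequantize_mul_mat_vec<...><<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);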
@@ -204,32 +210,40 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
     dequantize_kernel(vx, ib, iqs, v0, v1);
 }
 
-template <int ncols, int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst) {
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
+static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
+    // qk = quantized weights per x block
+    // qr = number of quantized weights per data value in x block
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
+    const int iter_stride = 2*GGML_CUDA_DMMV_BLOCK_X;
+    const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
     const int y_offset = qr == 1 ? 1 : qk/2;
 
-
     float tmp = 0; // partial sum for thread in warp
 
-#ifdef GGML_CUDA_UNROLL
-#pragma unroll
-#endif
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
-        const int ib = (row*ncols + col)/qk; // block index
-        const int iqs = (col%qk)/qr; // quant index
+    for (int i = 0; i < ncols; i += iter_stride) {
+        const int col = i + vals_per_iter*tid;
+        const int ib = (row*ncols + col)/qk; // x block index
+        const int iqs = (col%qk)/qr; // x quant index
         const int iybs = col - col%qk; // y block start index
 
-        // dequantize
-        float v0, v1;
-        dequantize_kernel(vx, ib, iqs, v0, v1);
-
-        // matrix multiplication
-        tmp += v0 * y[iybs + iqs + 0];
-        tmp += v1 * y[iybs + iqs + y_offset];
+        // processing >2 values per i iter is faster for fast GPUs
+#pragma unroll
+        for (int j = 0; j < vals_per_iter; j += 2) {
+            // process 2 vals per j iter
+
+            // dequantize
+            float v0, v1;
+            dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
+            // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+
+            // matrix multiplication
+            tmp += v0 * y[iybs + iqs + j/qr + 0];
+            tmp += v1 * y[iybs + iqs + j/qr + y_offset];
+            // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+        }
     }
 
     // sum up partial sums and write back result
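To see what the rewritten loop does, it helps to trace the index math for one quantization type. The sketch below replays it on the host for q4_0, assuming the values defined elsewhere in this file (QK4_0 = 32, QR4_0 = 2) and the default GGML_CUDA_DMMV_BLOCK_X = 32, so iter_stride = 64 and each of the 32 lanes handles vals_per_iter = 2 quantized values per i iteration:

    // Host-side replay of the kernel's index math (illustrative only).
    #include <stdio.h>

    int main(void) {
        const int qk = 32, qr = 2;                  // QK4_0, QR4_0
        const int iter_stride   = 2*32;             // 2*GGML_CUDA_DMMV_BLOCK_X
        const int vals_per_iter = iter_stride / 32; // iter_stride / WARP_SIZE = 2
        const int y_offset = qr == 1 ? 1 : qk/2;    // 16 for q4_0
        const int ncols = 256, row = 0, tid = 5;    // small example: lane 5 of the warp

        for (int i = 0; i < ncols; i += iter_stride) {
            const int col  = i + vals_per_iter*tid; // first quant this lane touches
            const int ib   = (row*ncols + col)/qk;  // x block index
            const int iqs  = (col%qk)/qr;           // x quant index
            const int iybs = col - col%qk;          // y block start index
            printf("i=%3d col=%3d ib=%d iqs=%d iybs=%3d (pairs y[%d], y[%d])\n",
                   i, col, ib, iqs, iybs, iybs + iqs, iybs + iqs + y_offset);
        }
        return 0;
    }

For lane 5 this prints col = 10, 74, 138, 202: block 0 quant 5 paired with y[5] and y[21], then block 2 quant 5 paired with y[69] and y[85], and so on. For qr = 2 each lane thus reads two nibbles that share one quant byte together with the two y values they contribute to.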
@@ -274,72 +288,44 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
     dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
-template <dequantize_kernel_t dequantize_kernel, int qk, int qr>
-static void dequantize_mul_mat_vec_cuda(const void * vx, const float * y, float * dst,
-                                        const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
     GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
-    const dim3 block_dims(GGML_CUDA_DMMV_BLOCK_X, GGML_CUDA_DMMV_BLOCK_Y, 1);
-
-    // Use a switch statement for ncols so the compiler can unroll all loops:
-    switch (ncols) {
-        case 4096:
-            dequantize_mul_mat_vec<4096, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 5120:
-            dequantize_mul_mat_vec<5120, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 6656:
-            dequantize_mul_mat_vec<6656, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 8192:
-            dequantize_mul_mat_vec<8192, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 11008:
-            dequantize_mul_mat_vec<11008, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 13824:
-            dequantize_mul_mat_vec<13824, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 17920:
-            dequantize_mul_mat_vec<17920, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        case 22016:
-            dequantize_mul_mat_vec<22016, GGML_CUDA_DMMV_BLOCK_X, qk, qr, dequantize_kernel>
-                <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst);
-            break;
-        default:
-            fprintf(stderr, "Tell the devs to add a switch case for this: ncols=%d\n", ncols);
-            GGML_ASSERT(false);
-            break;
-    }
-}
-
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q4_0, QK4_0, QR4_0>(vx, y, dst, ncols, nrows, stream);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q4_1, QK4_1, QR4_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q5_0, QK5_0, QR5_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q5_1, QK5_1, QR5_1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<dequantize_q8_0, QK8_0, QR8_0>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cu
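All five quantized launchers now share the same shape: assert divisibility, build a warp-wide block, launch with ncols as a runtime argument. A hypothetical call-site sketch for the q4_0 variant (the buffer names d_weights, d_x, d_y are illustrative, not from the patch):

    // Multiply an nrows x ncols q4_0 matrix by a dense FP32 vector.
    const int ncols = 4096, nrows = 4096; // must satisfy the GGML_ASSERTs above

    void  * d_weights; // quantized matrix: nrows*ncols/QK4_0 blocks of block_q4_0
    float * d_x;       // input vector, ncols floats
    float * d_y;       // output vector, nrows floats
    cudaMalloc(&d_weights, (size_t)nrows*ncols/QK4_0*sizeof(block_q4_0));
    cudaMalloc(&d_x, ncols*sizeof(float));
    cudaMalloc(&d_y, nrows*sizeof(float));
    // ... upload quantized weights and the input vector (omitted) ...

    cudaStream_t stream;
    cudaStreamCreate(&stream);
    dequantize_mul_mat_vec_q4_0_cuda(d_weights, d_x, d_y, ncols, nrows, stream);
    cudaStreamSynchronize(stream);

The trade-off versus the deleted switch: ncols is no longer a template constant, so the compiler cannot fully unroll the column loop, but the inner j loop over vals_per_iter is still unrolled, and any column count that is a multiple of GGML_CUDA_DMMV_BLOCK_X now works without adding a new switch case.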
@@ -348,7 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-    dequantize_mul_mat_vec_cuda<convert_f16, 1, 1>(vx, y, dst, ncols, nrows, stream);
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_BLOCK_X == 0);
+    GGML_ASSERT(nrows % GGML_CUDA_DMMV_BLOCK_Y == 0);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_BLOCK_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_f16>
+        <<<nrows/GGML_CUDA_DMMV_BLOCK_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
 }
 
 static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
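The kernel's tail ("sum up partial sums and write back result") lies outside the hunks shown. For readers unfamiliar with the pattern, here is a generic warp-level butterfly reduction of the kind such a tail typically uses, as a standalone sketch (an illustration, not the file's actual code):

    #include <cstdio>

    #define WARP_SIZE 32

    // Each lane holds one partial sum; after log2(WARP_SIZE) shuffle steps
    // every lane holds the warp-wide total.
    static __global__ void warp_sum(const float * x, float * out) {
        const int tid = threadIdx.x;
        float tmp = x[tid];
    #pragma unroll
        for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
            tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
        }
        if (tid == 0) {
            *out = tmp; // lane 0 writes the reduced value
        }
    }

    int main(void) {
        float h_x[WARP_SIZE], h_out = 0.0f;
        for (int i = 0; i < WARP_SIZE; ++i) h_x[i] = 1.0f; // expected total: 32
        float *d_x, *d_out;
        cudaMalloc(&d_x, sizeof(h_x));
        cudaMalloc(&d_out, sizeof(float));
        cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
        warp_sum<<<1, WARP_SIZE>>>(d_x, d_out);
        cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("sum = %.1f\n", h_out);
        cudaFree(d_x);
        cudaFree(d_out);
        return 0;
    }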