@@ -83,7 +83,8 @@ typedef struct {
83
83
} block_q8_0;
84
84
static_assert (sizeof (block_q8_0) == sizeof(float ) + QK8_0, "wrong q8_0 block size/padding");
85
85
86
- #define CUDA_DMMV_BLOCK_SIZE 32
86
+ #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
87
+ #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
87
88
88
89
static __device__ void dequantize_q4_0 (const void * vx, const int ib, const int iqs, float & v0, float & v1){
89
90
const block_q4_0 * x = (const block_q4_0 *) vx;
@@ -170,104 +171,23 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
170
171
v1 = __half2float (x[ib + 1 ]);
171
172
}
172
173
173
- static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
174
- static const int qk = QK4_0;
174
+ template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
175
+ static __global__ void dequantize_block (const void * vx, float * y, const int k) {
176
+ const int i = blockDim .x *blockIdx .x + 2 *threadIdx .x ;
175
177
176
- const block_q4_0 * x = (const block_q4_0 *) vx;
177
-
178
- const int i = blockIdx .x ;
179
-
180
- const float d = x[i].d ;
181
-
182
- for (int j = 0 ; j < qk/2 ; ++j) {
183
- const int x0 = (x[i].qs [j] & 0xf ) - 8 ;
184
- const int x1 = (x[i].qs [j] >> 4 ) - 8 ;
185
-
186
- y[i*qk + j + 0 ] = x0*d;
187
- y[i*qk + j + qk/2 ] = x1*d;
188
- }
189
- }
190
-
191
- static __global__ void dequantize_block_q4_1 (const void * vx, float * y) {
192
- static const int qk = QK4_1;
193
-
194
- const block_q4_1 * x = (const block_q4_1 *) vx;
195
-
196
- const int i = blockIdx .x ;
197
-
198
- const float d = x[i].d ;
199
- const float m = x[i].m ;
200
-
201
- for (int j = 0 ; j < qk/2 ; ++j) {
202
- const int x0 = (x[i].qs [j] & 0xf );
203
- const int x1 = (x[i].qs [j] >> 4 );
204
-
205
- y[i*qk + j + 0 ] = x0*d + m;
206
- y[i*qk + j + qk/2 ] = x1*d + m;
207
- }
208
- }
209
-
210
- static __global__ void dequantize_block_q5_0 (const void * vx, float * y) {
211
- static const int qk = QK5_0;
212
-
213
- const block_q5_0 * x = (const block_q5_0 *) vx;
214
-
215
- const int i = blockIdx .x ;
216
-
217
- const float d = x[i].d ;
218
-
219
- uint32_t qh;
220
- memcpy (&qh, x[i].qh , sizeof (qh));
221
-
222
- for (int j = 0 ; j < qk/2 ; ++j) {
223
- const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
224
- const uint8_t xh_1 = ((qh >> (j + 12 )) ) & 0x10 ;
225
-
226
- const int32_t x0 = ((x[i].qs [j] & 0xf ) | xh_0) - 16 ;
227
- const int32_t x1 = ((x[i].qs [j] >> 4 ) | xh_1) - 16 ;
228
-
229
- y[i*qk + j + 0 ] = x0*d;
230
- y[i*qk + j + qk/2 ] = x1*d;
178
+ if (i >= k) {
179
+ return ;
231
180
}
232
- }
233
-
234
- static __global__ void dequantize_block_q5_1 (const void * vx, float * y) {
235
- static const int qk = QK5_1;
236
-
237
- const block_q5_1 * x = (const block_q5_1 *) vx;
238
-
239
- const int i = blockIdx .x ;
240
-
241
- const float d = x[i].d ;
242
- const float m = x[i].m ;
243
181
244
- uint32_t qh;
245
- memcpy (&qh, x[i].qh , sizeof (qh));
246
-
247
- for (int j = 0 ; j < qk/2 ; ++j) {
248
- const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
249
- const uint8_t xh_1 = ((qh >> (j + 12 )) ) & 0x10 ;
250
-
251
- const int x0 = (x[i].qs [j] & 0xf ) | xh_0;
252
- const int x1 = (x[i].qs [j] >> 4 ) | xh_1;
253
-
254
- y[i*qk + j + 0 ] = x0*d + m;
255
- y[i*qk + j + qk/2 ] = x1*d + m;
256
- }
257
- }
258
-
259
- static __global__ void dequantize_block_q8_0 (const void * vx, float * y) {
260
- static const int qk = QK8_0;
261
-
262
- const block_q8_0 * x = (const block_q8_0 *) vx;
263
-
264
- const int i = blockIdx .x ;
265
-
266
- const float d = x[i].d ;
182
+ const int ib = i/qk; // block index
183
+ const int iqs = (i%qk)/qr; // quant index
184
+ const int iybs = i - i%qk; // y block start index
185
+ const int y_offset = qr == 1 ? 1 : qk/2 ;
267
186
268
- for (int j = 0 ; j < qk; ++j) {
269
- y[i*qk + j] = x[i].qs [j]*d;
270
- }
187
+ // dequantize
188
+ float & v0 = y[iybs + iqs + 0 ];
189
+ float & v1 = y[iybs + iqs + y_offset];
190
+ dequantize_kernel (vx, ib, iqs, v0, v1);
271
191
}
272
192
273
193
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -308,29 +228,29 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
308
228
}
309
229
}
310
230
311
- static void dequantize_row_q4_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
312
- const int nb = k / QK4_0 ;
313
- dequantize_block_q4_0 <<<nb, 1 , 0 , stream>>> (vx, y);
231
+ static void dequantize_row_q4_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
232
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE ;
233
+ dequantize_block<QK4_0, QR4_0, dequantize_q4_0> <<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE , 0 , stream>>> (vx, y, k );
314
234
}
315
235
316
- static void dequantize_row_q4_1_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
317
- const int nb = k / QK4_1 ;
318
- dequantize_block_q4_1 <<<nb, 1 , 0 , stream>>> (vx, y);
236
+ static void dequantize_row_q4_1_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
237
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE ;
238
+ dequantize_block<QK4_1, QR4_1, dequantize_q4_1> <<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE , 0 , stream>>> (vx, y, k );
319
239
}
320
240
321
- static void dequantize_row_q5_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
322
- const int nb = k / QK5_0 ;
323
- dequantize_block_q5_0 <<<nb, 1 , 0 , stream>>> (vx, y);
241
+ static void dequantize_row_q5_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
242
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE ;
243
+ dequantize_block<QK5_0, QR5_0, dequantize_q5_0> <<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE , 0 , stream>>> (vx, y, k );
324
244
}
325
245
326
- static void dequantize_row_q5_1_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
327
- const int nb = k / QK5_1 ;
328
- dequantize_block_q5_1 <<<nb, 1 , 0 , stream>>> (vx, y);
246
+ static void dequantize_row_q5_1_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
247
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE ;
248
+ dequantize_block<QK5_1, QR5_1, dequantize_q5_1> <<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE , 0 , stream>>> (vx, y, k );
329
249
}
330
250
331
- static void dequantize_row_q8_0_cuda (const void * vx, float * y, int k, cudaStream_t stream) {
332
- const int nb = k / QK8_0 ;
333
- dequantize_block_q8_0 <<<nb, 1 , 0 , stream>>> (vx, y);
251
+ static void dequantize_row_q8_0_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
252
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE ;
253
+ dequantize_block<QK8_0, QR8_0, dequantize_q8_0> <<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE , 0 , stream>>> (vx, y, k );
334
254
}
335
255
336
256
static void dequantize_mul_mat_vec_q4_0_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -363,17 +283,9 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
363
283
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0 , stream>>> (vx, y, dst, ncols);
364
284
}
365
285
366
- // TODO: optimize
367
- static __global__ void convert_fp16_to_fp32 (const void * vx, float * y) {
368
- const half * x = (const half *) vx;
369
-
370
- const int i = blockIdx .x ;
371
-
372
- y[i] = __half2float (x[i]);
373
- }
374
-
375
- static void convert_fp16_to_fp32_cuda (const void * x, float * y, int k, cudaStream_t stream) {
376
- convert_fp16_to_fp32<<<k, 1 , 0 , stream>>> (x, y);
286
+ static void convert_fp16_to_fp32_cuda (const void * vx, float * y, const int k, cudaStream_t stream) {
287
+ const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1 ) / CUDA_DEQUANTIZE_BLOCK_SIZE;
288
+ dequantize_block<32 , 1 , convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0 , stream>>> (vx, y, k);
377
289
}
378
290
379
291
static void convert_mul_mat_vec_f16_cuda (const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
0 commit comments