@@ -81,6 +81,17 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
    }
}

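+// dequantize one block_q8_0 (QK8_0 = 32 int8 quants sharing one fp16 scale d)
+// back into 32 consecutive floats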
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+    const block_q8_0 * xi = (const block_q8_0 *) cxi;
+    float * dsti = (float *) cdsti;
+
+    const float d = (float)xi->d;
+
+    for (int j = 0; j < QK8_0; j++) {
+        dsti[j] = xi->qs[j] * d;
+    }
+}
+
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
    const float * xi = (const float *) cxi;
    block_q4_0 * dsti = (block_q4_0 *) cdsti;
@@ -288,6 +299,32 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
    cpy_blck(cx + x_offset, cdst + dst_offset);
}

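+// generic dequantizing copy kernel, the mirror image of cpy_f32_q above:
+// each thread takes one block of qk quantized values and expands it to qk floats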
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
+    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+    if (i >= ne) {
+        return;
+    }
+
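+    // unflatten the linear element index i into source coordinates (i00..i03),
+    // then convert them to a byte offset with the nb0* strides; i00 is divided
+    // by qk because nb00 strides over whole quantized blocks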
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
+
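+    // same unflattening for the destination, which holds plain f32, so no /qk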
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+    cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
static void ggml_cpy_f16_f32_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -329,6 +366,16 @@ static void ggml_cpy_f32_q8_0_cuda(
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

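+// host launcher: a grid of ne single-thread blocks; the kernel scales the
+// block index by QK8_0 and returns once i >= ne, so only the first ne/QK8_0
+// blocks do real work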
+static void ggml_cpy_q8_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = ne;
+    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static void ggml_cpy_f32_q4_0_cuda(
    const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -437,6 +484,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
        ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
@@ -471,6 +520,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
        return (void *) cpy_f32_f16<cpy_1_f32_f16>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+        return (void *) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {