@@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
428
428
char * src0_ddc = (char *) src0->data ;
429
429
char * src1_ddc = (char *) src1->data ;
430
430
431
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
431
+ if (src0->type == src1->type && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
432
+ GGML_ASSERT (ggml_nbytes (src0) == ggml_nbytes (src1));
433
+ CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
434
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
432
435
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
433
436
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
434
437
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -449,9 +452,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
449
452
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
450
453
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
451
454
} else {
452
- fprintf (stderr, " %s: unsupported type combination (%s to %s)\n " , __func__,
455
+ GGML_ABORT ( " %s: unsupported type combination (%s to %s)\n " , __func__,
453
456
ggml_type_name (src0->type ), ggml_type_name (src1->type ));
454
- GGML_ABORT (" fatal error" );
455
457
}
456
458
}
457
459
@@ -461,29 +463,30 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
461
463
}
462
464
463
465
void * ggml_cuda_cpy_fn (const ggml_tensor * src0, ggml_tensor * src1) {
464
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
465
- return (void *) cpy_f32_f16<cpy_1_f32_f32>;
466
+ if (src0->type == src1->type && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
467
+ return nullptr ;
468
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
469
+ return (void *) cpy_f32_f16<cpy_1_f32_f32>;
466
470
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
467
- return (void *) cpy_f32_f16<cpy_1_f32_f16>;
471
+ return (void *) cpy_f32_f16<cpy_1_f32_f16>;
468
472
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
469
- return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
473
+ return (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
470
474
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
471
- return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
475
+ return (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
472
476
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
473
- return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
477
+ return (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
474
478
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
475
- return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
479
+ return (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
476
480
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
477
- return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
481
+ return (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
478
482
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
479
- return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
483
+ return (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
480
484
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
481
- return (void *) cpy_f32_f16<cpy_1_f32_f16>;
485
+ return (void *) cpy_f32_f16<cpy_1_f32_f16>;
482
486
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
483
- return (void *) cpy_f32_f16<cpy_1_f16_f32>;
487
+ return (void *) cpy_f32_f16<cpy_1_f16_f32>;
484
488
} else {
485
- fprintf (stderr, " %s: unsupported type combination (%s to %s)\n " , __func__,
489
+ GGML_ABORT ( " %s: unsupported type combination (%s to %s)\n " , __func__,
486
490
ggml_type_name (src0->type ), ggml_type_name (src1->type ));
487
- GGML_ABORT (" fatal error" );
488
491
}
489
492
}
0 commit comments