Skip to content

Commit 971920e

Browse files
ggml_cuda_compute_forward
1 parent 071dcd3 commit 971920e

File tree

4 files changed

+65
-52
lines changed

4 files changed

+65
-52
lines changed

ggml-cuda.cu

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
862862
}
863863
}
864864

865+
bool ggml_cuda_can_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
866+
return src1->backend == GGML_BACKEND_CUDA;
867+
}
868+
865869
void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
866870
GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
867871
ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul>(src0, src1, dst);
@@ -968,3 +972,34 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const
968972
free(buf_host);
969973
fclose(fp);
970974
}
975+
976+
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
977+
switch (tensor->op) {
978+
case GGML_OP_MUL:
979+
if (!ggml_cuda_can_mul(tensor->src0, tensor->src1, tensor)) {
980+
return false;
981+
}
982+
if (params->ith != 0) {
983+
return true;
984+
}
985+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
986+
return true;
987+
}
988+
ggml_cuda_mul(tensor->src0, tensor->src1, tensor);
989+
return true;
990+
case GGML_OP_MUL_MAT:
991+
if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
992+
return false;
993+
}
994+
if (params->ith != 0) {
995+
return true;
996+
}
997+
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
998+
return true;
999+
}
1000+
ggml_cuda_mul_mat(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
1001+
return true;
1002+
default:
1003+
return false;
1004+
}
1005+
}

ggml-cuda.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
1616
void ggml_cuda_host_free(void * ptr);
1717

1818
void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
19+
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
1920

2021
#ifdef __cplusplus
2122
}

ggml.c

Lines changed: 11 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3647,26 +3647,6 @@ struct ggml_context_container {
36473647
struct ggml_context context;
36483648
};
36493649

3650-
//
3651-
// compute types
3652-
//
3653-
3654-
enum ggml_task_type {
3655-
GGML_TASK_INIT = 0,
3656-
GGML_TASK_COMPUTE,
3657-
GGML_TASK_FINALIZE,
3658-
};
3659-
3660-
struct ggml_compute_params {
3661-
enum ggml_task_type type;
3662-
3663-
int ith, nth;
3664-
3665-
// work buffer for all threads
3666-
size_t wsize;
3667-
void * wdata;
3668-
};
3669-
36703650
//
36713651
// ggml state
36723652
//
@@ -8166,14 +8146,7 @@ static void ggml_compute_forward_mul_f32(
81668146
const int ith = params->ith;
81678147
const int nth = params->nth;
81688148

8169-
#ifdef GGML_USE_CUBLAS
8170-
if (src1->backend == GGML_BACKEND_CUDA) {
8171-
if (ith == 0) {
8172-
ggml_cuda_mul(src0, src1, dst);
8173-
}
8174-
return;
8175-
}
8176-
#elif defined(GGML_USE_CLBLAST)
8149+
#ifdef GGML_USE_CLBLAST
81778150
if (src1->backend == GGML_BACKEND_CL) {
81788151
if (ith == 0) {
81798152
ggml_cl_mul(src0, src1, dst);
@@ -9614,14 +9587,7 @@ static void ggml_compute_forward_mul_mat_f32(
96149587
// nb01 >= nb00 - src0 is not transposed
96159588
// compute by src0 rows
96169589

9617-
#if defined(GGML_USE_CUBLAS)
9618-
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9619-
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9620-
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9621-
}
9622-
return;
9623-
}
9624-
#elif defined(GGML_USE_CLBLAST)
9590+
#if defined(GGML_USE_CLBLAST)
96259591
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
96269592
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
96279593
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9786,14 +9752,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
97869752
// nb01 >= nb00 - src0 is not transposed
97879753
// compute by src0 rows
97889754

9789-
#if defined(GGML_USE_CUBLAS)
9790-
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
9791-
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
9792-
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
9793-
}
9794-
return;
9795-
}
9796-
#elif defined(GGML_USE_CLBLAST)
9755+
#if defined(GGML_USE_CLBLAST)
97979756
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
97989757
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
97999758
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9998,14 +9957,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
99989957
// nb01 >= nb00 - src0 is not transposed
99999958
// compute by src0 rows
100009959

10001-
#if defined(GGML_USE_CUBLAS)
10002-
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
10003-
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
10004-
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
10005-
}
10006-
return;
10007-
}
10008-
#elif defined(GGML_USE_CLBLAST)
9960+
#if defined(GGML_USE_CLBLAST)
100099961
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
100109962
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
100119963
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -12931,6 +12883,13 @@ static void ggml_compute_forward_map_binary(
1293112883
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
1293212884
GGML_ASSERT(params);
1293312885

12886+
#ifdef GGML_USE_CUBLAS
12887+
bool used_cuda = ggml_cuda_compute_forward(params, tensor);
12888+
if (used_cuda) {
12889+
return;
12890+
}
12891+
#endif // GGML_USE_CUBLAS
12892+
1293412893
switch (tensor->op) {
1293512894
case GGML_OP_DUP:
1293612895
{

ggml.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,24 @@ extern "C" {
413413
bool no_alloc; // don't allocate memory for the tensor data
414414
};
415415

416+
417+
// compute types
418+
// Pass identifier handed to each op implementation via
// ggml_compute_params.type. NOTE(review): the names suggest an
// init / compute / finalize phase split per op during graph evaluation —
// confirm against ggml_graph_compute; the visible CUDA path only runs
// work during GGML_TASK_COMPUTE.
enum ggml_task_type {
    GGML_TASK_INIT = 0,
    GGML_TASK_COMPUTE,
    GGML_TASK_FINALIZE,
};
423+
424+
// Per-invocation parameters passed to every op's forward implementation.
struct ggml_compute_params {
    enum ggml_task_type type; // which pass this call is for (see ggml_task_type)

    int ith, nth; // thread index and total thread count; index 0 performs
                  // single-threaded work such as the CUDA offload

    // work buffer for all threads
    size_t wsize; // size of wdata in bytes (forwarded to e.g. ggml_cuda_mul_mat)
    void * wdata;
};
433+
416434
// misc
417435

418436
GGML_API void ggml_time_init(void); // call this once at the beginning of the program

0 commit comments

Comments
 (0)