Skip to content

Commit 1b494f8

Browse files
slaren
authored and arthw committed
cuda : fix defrag with quantized KV (ggml-org#9319)
1 parent 316fe3d commit 1b494f8

File tree

3 files changed

+40
-19
lines changed

3 files changed

+40
-19
lines changed

ggml/src/ggml-backend.c

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,11 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
11651165
}
11661166
}
11671167

1168+
if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
1169+
// since the tensor is pre-allocated, it cannot be moved to another backend
1170+
GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
1171+
}
1172+
11681173
// graph input
11691174
if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
11701175
cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
@@ -1644,7 +1649,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
16441649
sched->prev_leaf_backend_ids = tmp;
16451650
}
16461651

1647-
int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
1652+
int graph_size = MAX(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;
16481653
if (sched->graph.size < graph_size) {
16491654
sched->graph.size = graph_size;
16501655
sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
@@ -1696,6 +1701,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
16961701
for (int c = 0; c < sched->n_copies; c++) {
16971702
struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
16981703
sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
1704+
assert(graph_copy->size > graph_copy->n_leafs);
16991705
graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
17001706
}
17011707
}
@@ -1709,6 +1715,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
17091715
for (int c = 0; c < sched->n_copies; c++) {
17101716
struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
17111717
sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
1718+
assert(graph_copy->size > graph_copy->n_leafs);
17121719
graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
17131720
}
17141721
}
@@ -1719,6 +1726,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
17191726
for (int i = 0; i < graph->n_leafs; i++) {
17201727
struct ggml_tensor * leaf = graph->leafs[i];
17211728
sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
1729+
assert(graph_copy->size > graph_copy->n_leafs);
17221730
graph_copy->leafs[graph_copy->n_leafs++] = leaf;
17231731
}
17241732
}

ggml/src/ggml-cuda.cu

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2572,8 +2572,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25722572
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
25732573
// store a pointer to each copy op CUDA kernel to identify it later
25742574
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2575-
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
2576-
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
2575+
if (!ptr) {
2576+
use_cuda_graph = false;
2577+
#ifndef NDEBUG
2578+
GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
2579+
#endif
2580+
} else {
2581+
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
2582+
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
2583+
}
25772584
}
25782585
}
25792586

@@ -2842,6 +2849,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28422849
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
28432850
return true;
28442851
}
2852+
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
2853+
return true;
2854+
}
28452855
return false;
28462856
} break;
28472857
case GGML_OP_DUP:

ggml/src/ggml-cuda/cpy.cu

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
428428
char * src0_ddc = (char *) src0->data;
429429
char * src1_ddc = (char *) src1->data;
430430

431-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
431+
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
432+
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
433+
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
434+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
432435
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
433436
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
434437
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -449,9 +452,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
449452
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
450453
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
451454
} else {
452-
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
455+
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
453456
ggml_type_name(src0->type), ggml_type_name(src1->type));
454-
GGML_ABORT("fatal error");
455457
}
456458
}
457459

@@ -461,29 +463,30 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
461463
}
462464

463465
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
464-
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
465-
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
466+
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
467+
return nullptr;
468+
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
469+
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
466470
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
467-
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
471+
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
468472
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
469-
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
473+
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
470474
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
471-
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
475+
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
472476
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
473-
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
477+
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
474478
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
475-
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
479+
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
476480
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
477-
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
481+
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
478482
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
479-
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
483+
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
480484
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
481-
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
485+
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
482486
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
483-
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
487+
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
484488
} else {
485-
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
489+
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
486490
ggml_type_name(src0->type), ggml_type_name(src1->type));
487-
GGML_ABORT("fatal error");
488491
}
489492
}

0 commit comments

Comments
 (0)