Skip to content

cuda : fix defrag with quantized KV #9319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ggml/src/ggml-backend.c
Original file line number Diff line number Diff line change
Expand Up @@ -1165,6 +1165,11 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
}
}

if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
// since the tensor is pre-allocated, it cannot be moved to another backend
GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
}

// graph input
if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
cur_backend_id = sched->n_backends - 1; // last backend (assumed CPU)
Expand Down Expand Up @@ -1644,7 +1649,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
sched->prev_leaf_backend_ids = tmp;
}

int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2;
int graph_size = MAX(graph->n_nodes, graph->n_leafs) + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2*sched->n_copies;
if (sched->graph.size < graph_size) {
sched->graph.size = graph_size;
sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *));
Expand Down Expand Up @@ -1696,6 +1701,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
for (int c = 0; c < sched->n_copies; c++) {
struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
assert(graph_copy->size > graph_copy->n_leafs);
graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
}
}
Expand All @@ -1709,6 +1715,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
for (int c = 0; c < sched->n_copies; c++) {
struct ggml_tensor * input_cpy = tensor_id_copy(id, backend_id, c);
sched->leaf_backend_ids[graph_copy->n_leafs] = backend_id;
assert(graph_copy->size > graph_copy->n_leafs);
graph_copy->leafs[graph_copy->n_leafs++] = input_cpy;
}
}
Expand All @@ -1719,6 +1726,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
for (int i = 0; i < graph->n_leafs; i++) {
struct ggml_tensor * leaf = graph->leafs[i];
sched->leaf_backend_ids[graph_copy->n_leafs] = tensor_backend_id(leaf);
assert(graph_copy->size > graph_copy->n_leafs);
graph_copy->leafs[graph_copy->n_leafs++] = leaf;
}
}
Expand Down
14 changes: 12 additions & 2 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2572,8 +2572,15 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_CUDA_LOG_WARN("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
} else {
if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
ggml_cuda_cpy_fn_ptrs.push_back(ptr);
}
}
}

Expand Down Expand Up @@ -2842,6 +2849,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true;
}
return false;
} break;
case GGML_OP_DUP:
Expand Down
35 changes: 19 additions & 16 deletions ggml/src/ggml-cuda/cpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;

if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
Expand All @@ -449,9 +452,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}

Expand All @@ -461,29 +463,30 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
}

void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
return nullptr;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f32_f32>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ABORT("fatal error");
}
}
Loading