Skip to content

Commit 72ef3a7

Browse files
agray3 authored and Nexesenex committed
Update CUDA graph on scale change plus clear nodes/params (ggml-org#9550)
* Avoid using saved CUDA graph if scale changes and reset nodes/params on update

Fixes ggml-org#9451

* clear before resize
1 parent 63c87b0 commit 72ef3a7

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2480,6 +2480,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
24802480
for (int i = 0; i < GGML_MAX_SRC; i++) {
24812481
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
24822482
}
2483+
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
24832484
}
24842485

24852486
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2511,6 +2512,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
25112512
return false;
25122513
}
25132514
}
2515+
2516+
if (node->op == GGML_OP_SCALE &&
2517+
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2518+
return false;
2519+
}
2520+
25142521
return true;
25152522
}
25162523

@@ -2721,7 +2728,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
27212728
// First call with null argument gets number of nodes in graph
27222729
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
27232730
// Subsequent call with non-null argument gets nodes
2731+
cuda_ctx->cuda_graph->nodes.clear();
27242732
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
2733+
cuda_ctx->cuda_graph->params.clear();
27252734
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
27262735
if (cuda_ctx->cuda_graph->num_nodes > 0) {
27272736
CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));

ggml/src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,7 @@ struct ggml_graph_node_properties {
666666
int64_t ne[GGML_MAX_DIMS];
667667
size_t nb[GGML_MAX_DIMS];
668668
void * src_address[GGML_MAX_SRC];
669+
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
669670
};
670671

671672
struct ggml_cuda_graph {

0 commit comments

Comments
 (0)