@@ -2480,6 +2480,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
2480
2480
for (int i = 0 ; i < GGML_MAX_SRC; i++) {
2481
2481
graph_node_properties->src_address [i] = node->src [i] ? node->src [i]->data : nullptr ;
2482
2482
}
2483
+ memcpy (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS);
2483
2484
}
2484
2485
2485
2486
static bool ggml_graph_node_has_matching_properties (ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -2511,6 +2512,12 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
2511
2512
return false ;
2512
2513
}
2513
2514
}
2515
+
2516
+ if (node->op == GGML_OP_SCALE &&
2517
+ memcmp (graph_node_properties->op_params , node->op_params , GGML_MAX_OP_PARAMS) != 0 ) {
2518
+ return false ;
2519
+ }
2520
+
2514
2521
return true ;
2515
2522
}
2516
2523
@@ -2721,7 +2728,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2721
2728
// First call with null argument gets number of nodes in graph
2722
2729
CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , nullptr , &cuda_ctx->cuda_graph ->num_nodes ));
2723
2730
// Subsequent call with non-null argument gets nodes
2731
+ cuda_ctx->cuda_graph ->nodes .clear ();
2724
2732
cuda_ctx->cuda_graph ->nodes .resize (cuda_ctx->cuda_graph ->num_nodes );
2733
+ cuda_ctx->cuda_graph ->params .clear ();
2725
2734
cuda_ctx->cuda_graph ->params .resize (cuda_ctx->cuda_graph ->num_nodes );
2726
2735
if (cuda_ctx->cuda_graph ->num_nodes > 0 ) {
2727
2736
CUDA_CHECK (cudaGraphGetNodes (cuda_ctx->cuda_graph ->graph , cuda_ctx->cuda_graph ->nodes .data (), &cuda_ctx->cuda_graph ->num_nodes ));
0 commit comments