@@ -514,6 +514,10 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
514
514
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg (device));
515
515
}
516
516
517
// Global bookkeeping for in-flight CUDA graph captures.
//
// ggml_cuda_lock_counter is incremented (while holding ggml_cuda_lock) right
// before cudaStreamBeginCapture and decremented (again under the lock) right
// after cudaStreamEndCapture. When the decrement brings the count back to
// zero, ggml_cuda_lock_cv is notified with notify_all().
// NOTE(review): the code that waits on ggml_cuda_lock_cv is not visible in
// this chunk - confirm the waiter's predicate against the rest of the file.
std::mutex ggml_cuda_lock;
std::condition_variable ggml_cuda_lock_cv;
// Explicitly start at zero: before C++20 the default constructor of
// std::atomic<int> performs no initialization, so spell the value out rather
// than relying on static zero-initialization.
std::atomic<int> ggml_cuda_lock_counter{0};
520
+
517
521
// cuda buffer
518
522
519
523
struct ggml_backend_cuda_buffer_context {
@@ -2685,6 +2689,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2685
2689
2686
2690
CUDA_CHECK (cudaStreamEndCapture (cuda_ctx->stream (), &cuda_ctx->cuda_graph ->graph ));
2687
2691
graph_evaluated_or_captured = true ; // CUDA graph has been captured
2692
+
2693
+ std::lock_guard<std::mutex> lock (ggml_cuda_lock);
2694
+ if (ggml_cuda_lock_counter.fetch_sub (1 , std::memory_order_relaxed) == 1 ) {
2695
+ ggml_cuda_lock_cv.notify_all ();
2696
+ }
2688
2697
} else {
2689
2698
graph_evaluated_or_captured = true ; // ggml graph has been directly evaluated
2690
2699
}
@@ -2760,7 +2769,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
2760
2769
}
2761
2770
}
2762
2771
2763
- if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
2772
+ if (use_cuda_graph && cuda_graph_update_required) {
2773
+ // Start CUDA graph capture
2774
+ {
2775
+ std::lock_guard<std::mutex> lock (ggml_cuda_lock);
2776
+ ggml_cuda_lock_counter.fetch_add (1 , std::memory_order_relaxed);
2777
+ }
2778
+
2764
2779
CUDA_CHECK (cudaStreamBeginCapture (cuda_ctx->stream (), cudaStreamCaptureModeRelaxed));
2765
2780
}
2766
2781
0 commit comments