Skip to content

Commit 87a4f95

Browse files
committed
cuda : synchronize graph capture and cublas handle destruction
Works around an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread. ggml-ci
1 parent 381174b commit 87a4f95

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
#endif
2020
#include "ggml-common.h"
2121

22-
#include <cstdio>
2322
#include <array>
23+
#include <atomic>
2424
#include <cassert>
2525
#include <cfloat>
26+
#include <condition_variable>
27+
#include <cstdio>
28+
#include <mutex>
2629
#include <string>
2730
#include <vector>
2831

@@ -752,6 +755,12 @@ struct ggml_cuda_graph {
752755
#endif
753756
};
754757

758+
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
759+
// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
760+
extern std::mutex ggml_cuda_lock;
761+
extern std::condition_variable ggml_cuda_lock_cv;
762+
extern std::atomic<int> ggml_cuda_lock_counter;
763+
755764
struct ggml_backend_cuda_context {
756765
int device;
757766
std::string name;
@@ -768,6 +777,9 @@ struct ggml_backend_cuda_context {
768777
}
769778

770779
~ggml_backend_cuda_context() {
780+
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
781+
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
782+
771783
if (copy_event != nullptr) {
772784
CUDA_CHECK(cudaEventDestroy(copy_event));
773785
}

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,10 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
514514
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
515515
}
516516

517+
std::mutex ggml_cuda_lock;
518+
std::condition_variable ggml_cuda_lock_cv;
519+
std::atomic<int> ggml_cuda_lock_counter;
520+
517521
// cuda buffer
518522

519523
struct ggml_backend_cuda_buffer_context {
@@ -2685,6 +2689,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
26852689

26862690
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
26872691
graph_evaluated_or_captured = true; // CUDA graph has been captured
2692+
2693+
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
2694+
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
2695+
ggml_cuda_lock_cv.notify_all();
2696+
}
26882697
} else {
26892698
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
26902699
}
@@ -2760,7 +2769,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
27602769
}
27612770
}
27622771

2763-
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
2772+
if (use_cuda_graph && cuda_graph_update_required) {
2773+
// Start CUDA graph capture
2774+
{
2775+
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
2776+
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
2777+
}
2778+
27642779
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
27652780
}
27662781

0 commit comments

Comments
 (0)