
cuda : synchronize graph capture and cublas handle destruction #14288

Merged
merged 1 commit on Jun 20, 2025
ggml/src/ggml-cuda/common.cuh (18 changes: 2 additions & 16 deletions)

@@ -19,10 +19,10 @@
 #endif
 #include "ggml-common.h"

-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>

@@ -767,21 +767,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }

-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();

     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
ggml/src/ggml-cuda/ggml-cuda.cu (44 changes: 41 additions & 3 deletions)

@@ -47,16 +47,16 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }

+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer

 struct ggml_backend_cuda_buffer_context {
@@ -2685,6 +2712,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx

             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
@@ -2760,7 +2792,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }

-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }

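Taken together, the hunks above form a simple capture-count latch: each graph capture increments ggml_cuda_lock_counter under ggml_cuda_lock before cudaStreamBeginCapture and decrements it after cudaStreamEndCapture, notifying ggml_cuda_lock_cv when the count reaches zero, while the context destructor waits on the condition variable until no capture is in flight before destroying streams and cuBLAS handles. The following is a minimal standalone sketch of that pattern, not part of the PR; the CUDA/cuBLAS calls are replaced by prints, and the names fake_graph_capture and fake_context_teardown are illustrative only.

// Standalone sketch (not from the PR) of the capture-count latch described above.
// Builds with a plain C++17 compiler, e.g.: g++ -std=c++17 -pthread latch_sketch.cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

static std::mutex              lock_mutex;    // plays the role of ggml_cuda_lock
static std::condition_variable lock_cv;       // plays the role of ggml_cuda_lock_cv
static std::atomic<int>        lock_counter;  // plays the role of ggml_cuda_lock_counter

// stands in for the graph-capture path in ggml_backend_cuda_graph_compute /
// evaluate_and_capture_cuda_graph
static void fake_graph_capture(int id) {
    {   // increment under the mutex, like the block before cudaStreamBeginCapture
        std::lock_guard<std::mutex> lock(lock_mutex);
        lock_counter.fetch_add(1, std::memory_order_relaxed);
    }
    std::printf("thread %d: capture started\n", id);
    std::this_thread::sleep_for(std::chrono::milliseconds(100)); // "capturing"
    std::printf("thread %d: capture ended\n", id);

    // decrement and wake any waiting destructor, like the block after cudaStreamEndCapture
    std::lock_guard<std::mutex> lock(lock_mutex);
    if (lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
        lock_cv.notify_all();
    }
}

// stands in for ~ggml_backend_cuda_context: wait until no capture is in flight,
// then it is safe to destroy the "handles"
static void fake_context_teardown() {
    std::unique_lock<std::mutex> lock(lock_mutex);
    lock_cv.wait(lock, [] { return lock_counter.load(std::memory_order_relaxed) == 0; });
    std::printf("teardown: no capture in flight, destroying handles\n");
}

int main() {
    std::thread t1(fake_graph_capture, 1);
    std::thread t2(fake_graph_capture, 2);
    std::this_thread::sleep_for(std::chrono::milliseconds(10)); // let the captures start first
    fake_context_teardown();
    t1.join();
    t2.join();
    return 0;
}

One property of this arrangement, which the patch appears to rely on as well, is that the mutex is only held briefly around the counter updates and the destructor's wait, so concurrent captures do not serialize against each other; only the teardown path blocks, and only until every in-flight capture has called notify.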