Commit eb9f15f — "With mechanism to fall back if graph capture fails"
Parent: d44e0fb

File tree: 2 files changed, +30 −6 lines

ggml-cuda.cu

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,20 @@
4848

4949
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
5050

51-
[[noreturn]]
51+
static bool disable_cuda_graphs_due_to_failed_capture = false;
52+
5253
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
5354
int id = -1; // in case cudaGetDevice fails
5455
cudaGetDevice(&id);
5556

57+
if(strcmp(msg,"operation not permitted when stream is capturing")==0 ||
58+
strcmp(msg,"operation failed due to a previous error during capture")==0) {
59+
// CUDA graph capture has failed, but we can fall back to regular stream-based CUDA
60+
// so mark as failed, clear the error and return.
61+
disable_cuda_graphs_due_to_failed_capture = true;
62+
cudaGetLastError();
63+
return;
64+
}
5665
fprintf(stderr, "CUDA error: %s\n", msg);
5766
fprintf(stderr, " current device: %d, in function %s at %s:%d\n", id, func, file, line);
5867
fprintf(stderr, " %s\n", stmt);
@@ -2428,6 +2437,7 @@ struct ggml_cuda_graph {
24282437
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
24292438
bool disable_due_to_gpu_arch = false;
24302439
bool disable_due_to_too_many_updates = false;
2440+
bool disable_due_to_failed_graph_capture = false;
24312441
int number_consecutive_updates = 0;
24322442
ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH];
24332443
};
@@ -2481,9 +2491,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24812491
}
24822492
}
24832493

2484-
// Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly.
2494+
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
2495+
// or previous graph capture failure.
24852496
// Also disable for multi-gpu for now. TO DO investigate
2486-
if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates ||
2497+
if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch ||
2498+
cuda_graph.disable_due_to_too_many_updates || cuda_graph.disable_due_to_failed_graph_capture ||
24872499
ggml_backend_cuda_get_device_count() > 1){
24882500
use_cuda_graph = false;
24892501
}
@@ -2540,11 +2552,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25402552
bool use_cuda_graph = false;
25412553
bool cuda_graph_update_required = false;
25422554
#endif // USE_CUDA_GRAPH
2543-
2555+
2556+
bool graph_evaluated_or_captured = false;
2557+
2558+
while(!graph_evaluated_or_captured) {
2559+
// Temporarily avoid indenting here (and below the following if) to make code review easier
2560+
25442561
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
25452562
// With the use of CUDA graphs, the execution will be performed by the graph launch.
25462563
if(!use_cuda_graph || cuda_graph_update_required) {
2547-
//temporarily avoid indenting here to make code review easier
2564+
25482565
for (int i = 0; i < cgraph->n_nodes; i++) {
25492566
ggml_tensor * node = cgraph->nodes[i];
25502567

@@ -2572,6 +2589,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25722589
#ifdef USE_CUDA_GRAPH
25732590
if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture
25742591
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph));
2592+
if(disable_cuda_graphs_due_to_failed_capture) {
2593+
use_cuda_graph = false;
2594+
cuda_graph.disable_due_to_failed_graph_capture = true;
2595+
}
2596+
}
2597+
else {
2598+
graph_evaluated_or_captured = true;
2599+
}
25752600
}
25762601
if(use_cuda_graph){
25772602

ggml-cuda/common.cuh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,6 @@
172172

173173
#define GGML_CUDA_MAX_STREAMS 8
174174

175-
[[noreturn]]
176175
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
177176

178177
#define CUDA_CHECK_GEN(err, success, error_fn) \

Comments (0)