 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 
-[[noreturn]]
+static bool disable_cuda_graphs_due_to_failed_capture = false;
+
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
     cudaGetDevice(&id);
 
+    if(strcmp(msg,"operation not permitted when stream is capturing")==0 ||
+       strcmp(msg,"operation failed due to a previous error during capture")==0) {
+        // CUDA graph capture has failed, but we can fall back to regular stream-based CUDA
+        // so mark as failed, clear the error and return.
+        disable_cuda_graphs_due_to_failed_capture = true;
+        cudaGetLastError();
+        return;
+    }
     fprintf(stderr, "CUDA error: %s\n", msg);
     fprintf(stderr, "  current device: %d, in function %s at %s:%d\n", id, func, file, line);
     fprintf(stderr, "  %s\n", stmt);
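This hunk makes the two capture-related runtime errors non-fatal: instead of logging and aborting, `ggml_cuda_error` sets a global flag, clears the pending error with `cudaGetLastError()`, and returns so execution can continue on the regular stream. The sketch below is not part of the patch; it simply provokes a capture failure on purpose to show where the two matched strings come from (the exact messages are those reported by recent CUDA runtimes) and how clearing the error works.

```cpp
// Minimal standalone sketch (illustration only): trigger a capture failure and
// clear it the same way the patched ggml_cuda_error does.
#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>

int main() {
    cudaStream_t stream;
    cudaGraph_t graph = nullptr;
    cudaStreamCreate(&stream);

    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

    // Synchronizing calls are not permitted while a global-mode capture is active.
    cudaError_t err = cudaDeviceSynchronize();
    const char * msg = cudaGetErrorString(err);
    std::printf("during capture: %s\n", msg);  // "operation not permitted when stream is capturing"

    if (std::strcmp(msg, "operation not permitted when stream is capturing") == 0) {
        cudaGetLastError();                    // clear the error instead of aborting
    }

    // The capture sequence is now invalidated, so ending it reports the second message.
    err = cudaStreamEndCapture(stream, &graph);
    std::printf("end capture: %s\n", cudaGetErrorString(err));
    cudaGetLastError();                        // clear this one too before carrying on

    cudaStreamDestroy(stream);
    return 0;
}
```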
@@ -2428,6 +2437,7 @@ struct ggml_cuda_graph {
     cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
     bool disable_due_to_gpu_arch = false;
     bool disable_due_to_too_many_updates = false;
+    bool disable_due_to_failed_graph_capture = false;
     int number_consecutive_updates = 0;
     ggml_graph_node_properties ggml_graph_properties[MAX_NODES_IN_CUDA_GRAPH];
 };
@@ -2481,9 +2491,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         }
     }
 
-    // Disable CUDA graphs in presence of env var, old GPU or use-case which is changing too rapidly.
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
     // Also disable for multi-gpu for now. TO DO investigate
-    if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || cuda_graph.disable_due_to_too_many_updates ||
+    if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch ||
+        cuda_graph.disable_due_to_too_many_updates || cuda_graph.disable_due_to_failed_graph_capture ||
         ggml_backend_cuda_get_device_count() > 1){
         use_cuda_graph = false;
     }
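The gate above only consumes flags that are computed elsewhere in the file (an environment-variable opt-out, a GPU-architecture check, the update-frequency heuristic, and now the new capture-failure flag). For context, the sketch below shows the general shape such checks usually take; the environment-variable name, the compute-capability threshold and the function names are illustrative assumptions, not the ones used in ggml-cuda.cu.

```cpp
// Illustrative sketch only: how per-device "disable CUDA graphs" flags are
// typically derived. Names and thresholds are assumptions, not ggml's.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

static bool graphs_disabled_by_env() {
    // Hypothetical opt-out variable; ggml uses its own (not shown in this diff).
    return std::getenv("MYAPP_DISABLE_CUDA_GRAPHS") != nullptr;
}

static bool gpu_too_old_for_graphs(int device) {
    cudaDeviceProp prop{};
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        return true;                       // be conservative if the query fails
    }
    return prop.major < 8;                 // e.g. require Ampere or newer (illustrative)
}

int main() {
    int device_count = 0;
    cudaGetDeviceCount(&device_count);

    bool use_cuda_graph = true;
    if (graphs_disabled_by_env() || gpu_too_old_for_graphs(0) || device_count > 1) {
        use_cuda_graph = false;            // mirrors the shape of the gate in the hunk above
    }
    std::printf("use_cuda_graph = %d\n", (int) use_cuda_graph);
    return 0;
}
```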
@@ -2540,11 +2552,16 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;
 #endif // USE_CUDA_GRAPH
-
+
+    bool graph_evaluated_or_captured = false;
+
+    while(!graph_evaluated_or_captured) {
+    // Temporarily avoid indenting here (and below the following if) to make code review easier
+
     // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
     // With the use of CUDA graphs, the execution will be performed by the graph launch.
     if(!use_cuda_graph || cuda_graph_update_required) {
-    //temporarily avoid indenting here to make code review easier
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
 
@@ -2572,6 +2589,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 #ifdef USE_CUDA_GRAPH
     if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture
         CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph));
+        if(disable_cuda_graphs_due_to_failed_capture) {
+            use_cuda_graph = false;
+            cuda_graph.disable_due_to_failed_graph_capture = true;
+        }
+    }
+    else {
+        graph_evaluated_or_captured = true;
+    }
     }
     if(use_cuda_graph){
 
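Taken together, the last two hunks wrap graph evaluation in a retry loop: the node loop either executes eagerly or is recorded into a capture, and if `cudaStreamEndCapture` reported a failure the backend sets `disable_due_to_failed_graph_capture`, turns graphs off, and goes around the loop once more so the same nodes run as ordinary stream launches (work recorded into an invalidated capture was never executed, so nothing runs twice). The standalone sketch below shows that control-flow pattern with a toy kernel; the kernel, helper and variable names are illustrative, not the ggml ones, and `cudaGraphInstantiate` is called with its CUDA 12 signature.

```cpp
// Illustration of the capture-with-fallback loop (not the ggml implementation).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float * x) { x[threadIdx.x] += 1.0f; }

// Stands in for the per-node launches inside ggml_backend_cuda_graph_compute.
static void enqueue_work(cudaStream_t stream, float * x) {
    dummy_kernel<<<1, 32, 0, stream>>>(x);
}

int main() {
    float * x = nullptr;
    cudaMalloc(&x, 32 * sizeof(float));
    cudaMemset(x, 0, 32 * sizeof(float));

    cudaStream_t stream;
    cudaStreamCreate(&stream);

    bool use_cuda_graph = true;               // assume graphs until capture fails
    bool graph_evaluated_or_captured = false;

    while (!graph_evaluated_or_captured) {
        if (use_cuda_graph) {
            cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
        }

        enqueue_work(stream, x);              // same code path for capture and eager execution

        if (use_cuda_graph) {
            cudaGraph_t graph = nullptr;
            if (cudaStreamEndCapture(stream, &graph) != cudaSuccess) {
                cudaGetLastError();           // clear the capture error
                use_cuda_graph = false;       // fall back and re-run the work eagerly
                continue;
            }
            cudaGraphExec_t instance = nullptr;
            cudaGraphInstantiate(&instance, graph, 0);   // CUDA 12 signature
            cudaGraphLaunch(instance, stream);
            cudaGraphExecDestroy(instance);
            cudaGraphDestroy(graph);
        }
        graph_evaluated_or_captured = true;   // graph launched, or work ran eagerly
    }

    cudaStreamSynchronize(stream);
    cudaFree(x);
    cudaStreamDestroy(stream);
    return 0;
}
```

In this toy example the capture succeeds, so the fallback branch is never taken; it exists only to mirror the structure the patch adds. Keeping the eager path and the capture path on the same enqueue code is presumably why the patch wraps the existing node loop in the `while` rather than duplicating it.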