diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index c36c12d6579ac..1e1070bdba0d6 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -114,6 +114,8 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + enum ggml_status (*graph_compute_entire) (ggml_backend_t backend, struct ggml_cgraph * cgraph); }; struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 273075f4e5455..217b87225d17d 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1577,6 +1577,19 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st } enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + ggml_backend_t prefer_backend = nullptr; + for (size_t idx = 0; idx < GGML_SCHED_MAX_BACKENDS; idx++) { + prefer_backend = sched->backends[idx]; + if (nullptr != prefer_backend) { + if (ggml_backend_dev_type(prefer_backend->device) == GGML_BACKEND_DEVICE_TYPE_CPU) { + continue; + } else { + if (nullptr != prefer_backend->iface.graph_compute_entire) { + return prefer_backend->iface.graph_compute_entire(prefer_backend, graph); + } + } + } + } if (!sched->is_reset && !sched->is_alloc) { ggml_backend_sched_reset(sched); } diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index ec158dfac6e3e..5b51ea55aaa8b 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_compute = */ ggml_backend_blas_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index b8d272cda600c..d252531af67ba 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1897,6 +1897,7 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .graph_compute = */ ggml_backend_cann_graph_compute, /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, + /* .graph_compute_entire = */ NULL, }; /** diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 09f8382b988a4..5a61260647c3e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -182,6 +182,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_cpu_guid(void) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 497de37be8210..9fce55748e5d1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2816,6 +2816,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_cuda_guid() { diff --git a/ggml/src/ggml-kompute/ggml-kompute.cpp b/ggml/src/ggml-kompute/ggml-kompute.cpp index 50579227183d3..fddafc7dc30cb 100644 --- a/ggml/src/ggml-kompute/ggml-kompute.cpp +++ b/ggml/src/ggml-kompute/ggml-kompute.cpp @@ -2058,6 +2058,7 @@ static struct ggml_backend_i kompute_backend_i = { /* .graph_compute = */ ggml_backend_kompute_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_kompute_guid() { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index e51a4169a23bf..413e5001dfa6b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -4806,6 +4806,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { /* .graph_compute = */ ggml_backend_metal_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_metal_guid(void) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 14d9934fb1b73..cae5fd1110e7f 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -1167,6 +1167,7 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .graph_compute = */ ggml_backend_opencl_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; ggml_backend_t ggml_backend_opencl_init(void) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 6c3b80b0883c9..91eaa54813d01 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -697,6 +697,7 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_compute = */ ggml_backend_rpc_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 6977b705e4877..bdde21ca6e914 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3749,6 +3749,7 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .event_record = */ ggml_backend_sycl_event_record, /* .event_wait = */ ggml_backend_sycl_event_wait, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_sycl_guid() { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ff53bdfbe171c..e7279a519821e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8332,6 +8332,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_compute = */ ggml_backend_vk_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .graph_compute_entire = */ NULL, }; static ggml_guid_t ggml_backend_vk_guid() {