@@ -255,13 +255,12 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
 }

 ggml_backend_graph_plan_t ggml_backend_graph_plan_create(
-        ggml_backend_t             backend,
-        const struct ggml_cgraph * cgraph,
-        ggml_compute_threadpool_t  threadpool
+        ggml_backend_t             backend,
+        const struct ggml_cgraph * cgraph
 ) {
     GGML_ASSERT(backend->iface.graph_plan_create != NULL);

-    return backend->iface.graph_plan_create(backend, cgraph, threadpool);
+    return backend->iface.graph_plan_create(backend, cgraph);
 }

 void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@@ -281,20 +280,18 @@ enum ggml_status ggml_backend_graph_plan_compute(

 enum ggml_status ggml_backend_graph_compute(
         ggml_backend_t       backend,
-        struct ggml_cgraph * cgraph,
-        ggml_compute_threadpool_t threadpool
+        struct ggml_cgraph * cgraph
 ) {
-    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph, threadpool);
+    enum ggml_status err = ggml_backend_graph_compute_async(backend, cgraph);
     ggml_backend_synchronize(backend);
     return err;
 }

 enum ggml_status ggml_backend_graph_compute_async(
-        ggml_backend_t       backend,
-        struct ggml_cgraph * cgraph,
-        ggml_compute_threadpool_t threadpool
+        ggml_backend_t       backend,
+        struct ggml_cgraph * cgraph
 ) {
-    return backend->iface.graph_compute(backend, cgraph, threadpool);
+    return backend->iface.graph_compute(backend, cgraph);
 }

 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
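With the threadpool argument gone, callers drive execution with just a backend and a graph; thread configuration becomes backend state. A minimal sketch of the calling convention after this change (not code from this patch; `backend` and `cgraph` are assumed to have been built elsewhere):

    // Sketch: synchronous compute on the reduced signature.
    // ggml_backend_graph_compute() dispatches the async variant
    // and then synchronizes the backend before returning.
    static enum ggml_status compute_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
        enum ggml_status st = ggml_backend_graph_compute(backend, cgraph);
        return st; // GGML_STATUS_SUCCESS on success
    }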
@@ -741,7 +738,9 @@ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
 #endif

 struct ggml_backend_cpu_context {
-    int n_threads;
+    int                       n_threads;
+    ggml_compute_threadpool_t threadpool;
+
     void * work_data;
     size_t work_size;
@@ -774,15 +773,14 @@ struct ggml_backend_plan_cpu {
 };

 GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(
-        ggml_backend_t             backend,
-        const struct ggml_cgraph * cgraph,
-        ggml_compute_threadpool_t  threadpool
+        ggml_backend_t             backend,
+        const struct ggml_cgraph * cgraph
 ) {
     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;

     struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));

-    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, threadpool);
+    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
     cpu_plan->cgraph = *cgraph; // FIXME: deep copy

     if (cpu_plan->cplan.work_size > 0) {
@@ -817,13 +815,12 @@ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backe
 }

 GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(
-        ggml_backend_t       backend,
-        struct ggml_cgraph * cgraph,
-        ggml_compute_threadpool_t threadpool
+        ggml_backend_t       backend,
+        struct ggml_cgraph * cgraph
 ) {
     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;

-    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, threadpool);
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);

     if (cpu_ctx->work_size < cplan.work_size) {
         free(cpu_ctx->work_data);
@@ -892,6 +889,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     }

     ctx->n_threads      = GGML_DEFAULT_N_THREADS;
+    ctx->threadpool     = NULL;
     ctx->work_data      = NULL;
     ctx->work_size      = 0;
     ctx->abort_callback = NULL;
@@ -922,6 +920,13 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
     ctx->n_threads = n_threads;
 }

+void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_compute_threadpool_t threadpool) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->threadpool = threadpool;
+}
+
 void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
     GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
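Because the threadpool now lives in `ggml_backend_cpu_context`, it is configured once per backend instead of being threaded through every compute call. A hedged usage sketch; the threadpool constructor is added elsewhere in this PR and is not shown in these hunks, so `threadpool` is assumed to already exist:

    // Sketch: attach a threadpool to a CPU backend once, then compute as usual.
    ggml_backend_t cpu = ggml_backend_cpu_init();
    ggml_backend_cpu_set_n_threads(cpu, 8);           // still honored by ggml_graph_plan
    ggml_backend_cpu_set_threadpool(cpu, threadpool); // stored in the backend context
    // A NULL threadpool (the ggml_backend_cpu_init default) keeps the old behavior.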
@@ -1653,10 +1658,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
     return true;
 }

-static enum ggml_status ggml_backend_sched_compute_splits(
-        ggml_backend_sched_t      sched,
-        ggml_compute_threadpool_t threadpool
-) {
+static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
     struct ggml_backend_sched_split * splits = sched->splits;

     for (int i = 0; i < sched->n_splits; i++) {
@@ -1690,7 +1692,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(
         }

         if (!sched->callback_eval) {
-            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph, threadpool);
+            enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
             if (ec != GGML_STATUS_SUCCESS) {
                 return ec;
             }
@@ -1712,7 +1714,7 @@ static enum ggml_status ggml_backend_sched_compute_splits(

                 struct ggml_cgraph gv = ggml_graph_view(&split->graph, j0, j1 + 1);

-                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv, threadpool);
+                enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &gv);
                 if (ec != GGML_STATUS_SUCCESS) {
                     return ec;
                 }
@@ -1852,19 +1854,17 @@ bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgra
 }

 enum ggml_status ggml_backend_sched_graph_compute(
-        ggml_backend_sched_t      sched,
-        struct ggml_cgraph *      graph,
-        ggml_compute_threadpool_t threadpool
+        ggml_backend_sched_t sched,
+        struct ggml_cgraph * graph
 ) {
-    enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph, threadpool);
+    enum ggml_status err = ggml_backend_sched_graph_compute_async(sched, graph);
     ggml_backend_sched_synchronize(sched);
     return err;
 }

 enum ggml_status ggml_backend_sched_graph_compute_async(
-        ggml_backend_sched_t      sched,
-        struct ggml_cgraph *      graph,
-        ggml_compute_threadpool_t threadpool
+        ggml_backend_sched_t sched,
+        struct ggml_cgraph * graph
 ) {
     if (!sched->is_reset && !sched->is_alloc) {
         ggml_backend_sched_reset(sched);
@@ -1876,7 +1876,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(
         }
     }

-    return ggml_backend_sched_compute_splits(sched, threadpool);
+    return ggml_backend_sched_compute_splits(sched);
 }

 void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
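The scheduler entry points shed the parameter the same way: each split now runs on its backend with whatever threadpool (if any) that backend was configured with. A sketch of the updated call, assuming `sched` and `graph` are set up as before:

    // Sketch: scheduler-driven compute after the signature change.
    enum ggml_status st = ggml_backend_sched_graph_compute(sched, graph);
    // Equivalent to ggml_backend_sched_graph_compute_async(sched, graph)
    // followed by ggml_backend_sched_synchronize(sched).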
@@ -2115,8 +2115,8 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
         struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
         struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);

-        ggml_backend_graph_compute(backend1, &g1v, NULL);
-        ggml_backend_graph_compute(backend2, &g2v, NULL);
+        ggml_backend_graph_compute(backend1, &g1v);
+        ggml_backend_graph_compute(backend2, &g2v);

         if (ggml_is_view_op(t1->op)) {
             continue;