@@ -58,6 +58,7 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
         /* .buft      = */ buft,
         /* .context   = */ context,
         /* .size      = */ size,
+        /* .usage     = */ GGML_BACKEND_BUFFER_USAGE_ANY
     };
 
     return buffer;
@@ -109,6 +110,10 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
     return ggml_backend_buft_is_host(ggml_backend_buffer_type(buffer));
 }
 
+void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
+    buffer->usage = usage;
+}
+
 ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
     return buffer->buft;
 }
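
Note: ggml_backend_buffer_set_usage above is the public entry point for the new usage field. A minimal sketch of how a model loader might call it; the allocation call is existing ggml API, but model_size and the choice of the CPU buffer type are assumptions for illustration, not part of this diff:

    // hypothetical caller: tag a freshly allocated buffer as holding weights,
    // so the scheduler pins ops reading from it to the buffer's backend
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(
        ggml_backend_cpu_buffer_type(), model_size); // model_size: assumed
    ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
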
@@ -773,7 +778,7 @@ static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_talloc
 }
 
 #if 0
-static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
+static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
 #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
 #define GET_CAUSE(node) causes[hash_id(node)]
 #else
@@ -808,17 +813,25 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
         if (src == NULL) {
             break;
         }
+
         ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
-        if (src_backend != NULL) {
+        if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
+            // operations with weights are always on the same backend as the weights
+            cur_backend = src_backend;
+            SET_CAUSE(node, "1.wgt%d", i);
+            break;
+        }
+
+        //if (src_backend != NULL) {
             int src_prio = sched_backend_prio(sched, src_backend);
             size_t src_size = ggml_nbytes(src);
-            if (src_prio < cur_prio && src_size >= cur_size) {
+            if (/*src_prio < cur_prio &&*/ src_size >= cur_size) {
                 cur_prio = src_prio;
                 cur_size = src_size;
                 cur_backend = src_backend;
                 SET_CAUSE(node, "1.src%d", i);
             }
-        }
+        //}
     }
     return cur_backend;
 }
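
The hunk above changes the backend-selection heuristic in two ways: sources living in buffers tagged GGML_BACKEND_BUFFER_USAGE_WEIGHTS now pin the node to the weights' backend unconditionally, and the backend-priority comparison is disabled so that only source size drives the fallback choice. A sketch of the resulting rule on a simplified tensor model (toy types and names, not from the diff):

    #include <stdbool.h>
    #include <stddef.h>

    // illustrative only: the same decision logic, isolated from the scheduler
    typedef struct { size_t size; int backend; bool is_weight; } toy_src;

    static int pick_backend(const toy_src * srcs, int n_srcs, int fallback) {
        int    cur_backend = fallback;
        size_t cur_size    = 0;
        for (int i = 0; i < n_srcs; i++) {
            if (srcs[i].is_weight) {
                return srcs[i].backend; // weights always win; search stops here
            }
            if (srcs[i].size >= cur_size) { // priority check disabled by this diff
                cur_size    = srcs[i].size;
                cur_backend = srcs[i].backend;
            }
        }
        return cur_backend;
    }
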
@@ -929,6 +942,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
     }
     //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
 
+#if 0
     // pass 2: assign backends to ops from current assignments
     // TODO:
     //  - reuse sched_backend_from_cur
@@ -960,6 +974,23 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             }
         }
     }
+#else
+    // pass 2: assign backends to ops from current assignments
+    // start from the end and assign the same backend to previous ops
+    {
+        ggml_tallocr_t cur_allocr = NULL;
+        for (int i = graph->n_nodes - 1; i >= 0; i--) {
+            struct ggml_tensor * node = graph->nodes[i];
+            ggml_tallocr_t node_allocr = node_allocr(node);
+            if (node_allocr != NULL) {
+                cur_allocr = node_allocr;
+            } else {
+                node_allocr(node) = cur_allocr;
+                SET_CAUSE(node, "2.cur");
+            }
+        }
+    }
+#endif
     //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
 
     // pass 3: assign backends to remaining src from dst (should only be leafs)
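
The replacement pass 2 is a single right-to-left sweep: an unassigned node inherits the allocator of the nearest assigned node that follows it in the graph. The same idea on a plain int array, where -1 marks "unassigned" (illustrative, not diff code):

    // backward backfill: each hole takes the value of the nearest assigned
    // slot to its right; slots with nothing assigned to their right stay -1,
    // just as trailing nodes keep cur_allocr == NULL above
    static void backfill(int * alloc, int n) {
        int cur = -1;
        for (int i = n - 1; i >= 0; i--) {
            if (alloc[i] != -1) {
                cur = alloc[i];
            } else {
                alloc[i] = cur;
            }
        }
    }
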
@@ -1025,9 +1056,21 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
             }
             ggml_tallocr_t src_allocr = node_allocr(src);
             if (src_allocr != node_allocr) {
-                int n_inputs = sched->splits[cur_split].n_inputs++;
-                GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
-                sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
+                // check if the input is already in the split
+                bool found = false;
+                for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
+                    if (sched->splits[cur_split].inputs[k] == src) {
+                        found = true;
+                        break;
+                    }
+                }
+
+                if (!found) {
+                    int n_inputs = sched->splits[cur_split].n_inputs++;
+                    //printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, ggml_backend_name(get_allocr_backend(sched, src_allocr)));
+                    GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
+                    sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
+                }
 
                 // create copies
                 size_t id = hash_id(src);
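
The duplicate check above is a linear scan, which is cheap because GGML_MAX_SPLIT_INPUTS bounds n_inputs to a small constant. The same append-if-absent pattern in isolation (a sketch; the diff enforces capacity with GGML_ASSERT instead of a return code):

    #include <stdbool.h>

    // append ptr to list[0..*n) only if it is not already present;
    // returns false when the list is full
    static bool append_unique(const void ** list, int * n, int cap, const void * ptr) {
        for (int k = 0; k < *n; k++) {
            if (list[k] == ptr) {
                return true; // already present, nothing to do
            }
        }
        if (*n >= cap) {
            return false;
        }
        list[(*n)++] = ptr;
        return true;
    }
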
@@ -1231,6 +1274,10 @@ void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
     sched_reset(sched);
 }
 
+int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
+    return sched->n_splits;
+}
+
 ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
     int backend_index = sched_backend_prio(sched, backend);
     return sched->tallocs[backend_index];
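
A sketch of how a caller might use the new getter to see how fragmented a graph ended up across backends; sched and graph are assumed to have been set up elsewhere with the scheduler API:

    // hypothetical usage: after computing, report how many per-backend
    // splits the scheduler produced for this graph
    ggml_backend_sched_graph_compute(sched, graph);
    fprintf(stderr, "graph computed in %d splits\n",
            ggml_backend_sched_get_n_splits(sched));
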
@@ -1316,6 +1363,7 @@ static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor
 
     struct ggml_tensor * dst = node_copies[id];
     if (dst->view_src != NULL) {
+        graph_init_tensor(hash_set, node_copies, node_init, src->view_src);
         ggml_backend_view_init(dst->view_src->buffer, dst);
     }
     else {
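
The one-line fix above matters for chained views: a view's parent copy must itself be initialized before ggml_backend_view_init can derive the child's address from the parent's buffer, so the function now recurses into view_src first. The dependency order in miniature (toy types, not from the diff):

    #include <stddef.h>

    typedef struct toy_tensor { struct toy_tensor * view_src; int initialized; } toy_tensor;

    // mirror of the fixed logic: initialize a view chain bottom-up
    static void toy_init(toy_tensor * t) {
        if (t->initialized) return;
        if (t->view_src != NULL) {
            toy_init(t->view_src); // the recursion this hunk adds, in miniature
        }
        t->initialized = 1;
    }
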