@@ -36,6 +36,46 @@ static uint64_t get_time_ns() {
36
36
return std::chrono::nanoseconds (clock ::now ().time_since_epoch ()).count ();
37
37
}
38
38
39
+ static bool tensor_buft_override_equal (const llama_model_tensor_buft_override& a, const llama_model_tensor_buft_override& b) {
40
+ if (a.pattern != b.pattern ) {
41
+ // cString comparison that may be null
42
+ if (a.pattern == nullptr || b.pattern == nullptr ) {
43
+ return false ;
44
+ }
45
+ if (strcmp (a.pattern , b.pattern ) != 0 ) {
46
+ return false ;
47
+ }
48
+ }
49
+ if (a.buft != b.buft ) {
50
+ return false ;
51
+ }
52
+ return true ;
53
+ }
54
+
55
+ static bool vec_tensor_buft_override_equal (const std::vector<llama_model_tensor_buft_override>& a, const std::vector<llama_model_tensor_buft_override>& b) {
56
+ if (a.size () != b.size ()) {
57
+ return false ;
58
+ }
59
+ for (size_t i = 0 ; i < a.size (); i++) {
60
+ if (!tensor_buft_override_equal (a[i], b[i])) {
61
+ return false ;
62
+ }
63
+ }
64
+ return true ;
65
+ }
66
+
67
+ static bool vec_vec_tensor_buft_override_equal (const std::vector<std::vector<llama_model_tensor_buft_override>>& a, const std::vector<std::vector<llama_model_tensor_buft_override>>& b) {
68
+ if (a.size () != b.size ()) {
69
+ return false ;
70
+ }
71
+ for (size_t i = 0 ; i < a.size (); i++) {
72
+ if (!vec_tensor_buft_override_equal (a[i], b[i])) {
73
+ return false ;
74
+ }
75
+ }
76
+ return true ;
77
+ }
78
+
39
79
template <class T > static std::string join (const std::vector<T> & values, const std::string & delim) {
40
80
std::ostringstream str;
41
81
for (size_t i = 0 ; i < values.size (); i++) {
@@ -175,6 +215,7 @@ struct cmd_params {
175
215
std::vector<bool > no_kv_offload;
176
216
std::vector<bool > flash_attn;
177
217
std::vector<std::vector<float >> tensor_split;
218
+ std::vector<std::vector<llama_model_tensor_buft_override>> tensor_buft_overrides;
178
219
std::vector<bool > use_mmap;
179
220
std::vector<bool > embeddings;
180
221
ggml_numa_strategy numa;
@@ -207,6 +248,7 @@ static const cmd_params cmd_params_defaults = {
207
248
/* no_kv_offload */ { false },
208
249
/* flash_attn */ { false },
209
250
/* tensor_split */ { std::vector<float >(llama_max_devices (), 0 .0f ) },
251
+ /* tensor_buft_overrides*/ { std::vector<llama_model_tensor_buft_override>{{nullptr ,nullptr }} },
210
252
/* use_mmap */ { true },
211
253
/* embeddings */ { false },
212
254
/* numa */ GGML_NUMA_STRATEGY_DISABLED,
@@ -265,6 +307,7 @@ static void print_usage(int /* argc */, char ** argv) {
265
307
printf (" -embd, --embeddings <0|1> (default: %s)\n " ,
266
308
join (cmd_params_defaults.embeddings , " ," ).c_str ());
267
309
printf (" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n " );
310
+ printf (" -ot --override-tensors <tensor name pattern>=<buffer type>;... (default: disabled)\n " );
268
311
printf (" -r, --repetitions <n> (default: %d)\n " , cmd_params_defaults.reps );
269
312
printf (" --prio <0|1|2|3> (default: %d)\n " , cmd_params_defaults.prio );
270
313
printf (" --delay <0...N> (seconds) (default: %d)\n " , cmd_params_defaults.delay );
@@ -557,6 +600,87 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
557
600
}
558
601
params.tensor_split .push_back (tensor_split);
559
602
}
603
+ } else if (arg == " -ot" || arg == " --override-tensor" ) {
604
+ if (++i >= argc) {
605
+ invalid_param = true ;
606
+ break ;
607
+ }
608
+ auto value = argv[i];
609
+ /* static */ std::map<std::string, ggml_backend_buffer_type_t > buft_list;
610
+ if (buft_list.empty ()) {
611
+ // enumerate all the devices and add their buffer types to the list
612
+ for (size_t i = 0 ; i < ggml_backend_dev_count (); ++i) {
613
+ auto * dev = ggml_backend_dev_get (i);
614
+ auto * buft = ggml_backend_dev_buffer_type (dev);
615
+ if (buft) {
616
+ buft_list[ggml_backend_buft_name (buft)] = buft;
617
+ }
618
+ }
619
+ }
620
+ auto override_group_span_len = std::strcspn (value, " ," );
621
+ bool last_group = false ;
622
+ do {
623
+ if (override_group_span_len == 0 ) {
624
+ // Adds an empty override-tensors for an empty span
625
+ params.tensor_buft_overrides .push_back ({{}});
626
+ if (value[override_group_span_len] == ' \0 ' ) {
627
+ value = &value[override_group_span_len];
628
+ last_group = true ;
629
+ } else {
630
+ value = &value[override_group_span_len + 1 ];
631
+ override_group_span_len = std::strcspn (value, " ," );
632
+ }
633
+ continue ;
634
+ }
635
+ // Stamps null terminators into the argv
636
+ // value for this option to avoid the
637
+ // memory leak present in the implementation
638
+ // over in arg.cpp. Acceptable because we
639
+ // only parse these args once in this program.
640
+ auto override_group = value;
641
+ if (value[override_group_span_len] == ' \0 ' ) {
642
+ value = &value[override_group_span_len];
643
+ last_group = true ;
644
+ } else {
645
+ value[override_group_span_len] = ' \0 ' ;
646
+ value = &value[override_group_span_len + 1 ];
647
+ }
648
+ std::vector<llama_model_tensor_buft_override> group_tensor_buft_overrides{};
649
+ auto override_span_len = std::strcspn (override_group, " ;" );
650
+ while (override_span_len > 0 ) {
651
+ auto override = override_group;
652
+ if (override_group[override_span_len] != ' \0 ' ) {
653
+ override_group[override_span_len] = ' \0 ' ;
654
+ override_group = &override_group[override_span_len + 1 ];
655
+ } else {
656
+ override_group = &override_group[override_span_len];
657
+ }
658
+ auto tensor_name_span_len = std::strcspn (override , " =" );
659
+ if (tensor_name_span_len >= override_span_len) {
660
+ invalid_param = true ;
661
+ break ;
662
+ }
663
+ override [tensor_name_span_len] = ' \0 ' ;
664
+ auto tensor_name = override ;
665
+ auto buffer_type = &override [tensor_name_span_len + 1 ];
666
+ if (buft_list.find (buffer_type) == buft_list.end ()) {
667
+ printf (" Available buffer types:\n " );
668
+ for (const auto & it : buft_list) {
669
+ printf (" %s\n " , ggml_backend_buft_name (it.second ));
670
+ }
671
+ invalid_param = true ;
672
+ break ;
673
+ }
674
+ group_tensor_buft_overrides.push_back ({tensor_name, buft_list.at (buffer_type)});
675
+ override_span_len = std::strcspn (override_group, " ;" );
676
+ }
677
+ if (invalid_param) {
678
+ break ;
679
+ }
680
+ group_tensor_buft_overrides.push_back ({nullptr ,nullptr });
681
+ params.tensor_buft_overrides .push_back (group_tensor_buft_overrides);
682
+ override_group_span_len = std::strcspn (value, " ," );
683
+ } while (!last_group);
560
684
} else if (arg == " -r" || arg == " --repetitions" ) {
561
685
if (++i >= argc) {
562
686
invalid_param = true ;
@@ -648,6 +772,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
648
772
if (params.tensor_split .empty ()) {
649
773
params.tensor_split = cmd_params_defaults.tensor_split ;
650
774
}
775
+ if (params.tensor_buft_overrides .empty ()) {
776
+ params.tensor_buft_overrides = cmd_params_defaults.tensor_buft_overrides ;
777
+ }
651
778
if (params.use_mmap .empty ()) {
652
779
params.use_mmap = cmd_params_defaults.use_mmap ;
653
780
}
@@ -689,6 +816,7 @@ struct cmd_params_instance {
689
816
bool no_kv_offload;
690
817
bool flash_attn;
691
818
std::vector<float > tensor_split;
819
+ std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
692
820
bool use_mmap;
693
821
bool embeddings;
694
822
@@ -733,13 +861,20 @@ struct cmd_params_instance {
733
861
mparams.tensor_split = tensor_split.data ();
734
862
mparams.use_mmap = use_mmap;
735
863
864
+ if (tensor_buft_overrides.empty ()) {
865
+ mparams.tensor_buft_overrides = nullptr ;
866
+ } else {
867
+ GGML_ASSERT (tensor_buft_overrides.back ().pattern == nullptr && " Tensor buffer overrides not terminated with empty pattern" );
868
+ mparams.tensor_buft_overrides = tensor_buft_overrides.data ();
869
+ }
870
+
736
871
return mparams;
737
872
}
738
873
739
874
bool equal_mparams (const cmd_params_instance & other) const {
740
875
return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
741
876
split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
742
- tensor_split == other.tensor_split ;
877
+ tensor_split == other.tensor_split && vec_tensor_buft_override_equal (tensor_buft_overrides, other. tensor_buft_overrides ) ;
743
878
}
744
879
745
880
llama_context_params to_llama_cparams () const {
@@ -769,6 +904,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
769
904
for (const auto & sm : params.split_mode )
770
905
for (const auto & mg : params.main_gpu )
771
906
for (const auto & ts : params.tensor_split )
907
+ for (const auto & ot : params.tensor_buft_overrides )
772
908
for (const auto & mmp : params.use_mmap )
773
909
for (const auto & embd : params.embeddings )
774
910
for (const auto & nb : params.n_batch )
@@ -804,6 +940,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
804
940
/* .no_kv_offload= */ nkvo,
805
941
/* .flash_attn = */ fa,
806
942
/* .tensor_split = */ ts,
943
+ /* .tensor_buft_overrides = */ ot,
807
944
/* .use_mmap = */ mmp,
808
945
/* .embeddings = */ embd,
809
946
};
@@ -833,6 +970,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
833
970
/* .no_kv_offload= */ nkvo,
834
971
/* .flash_attn = */ fa,
835
972
/* .tensor_split = */ ts,
973
+ /* .tensor_buft_overrides = */ ot,
836
974
/* .use_mmap = */ mmp,
837
975
/* .embeddings = */ embd,
838
976
};
@@ -862,6 +1000,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
862
1000
/* .no_kv_offload= */ nkvo,
863
1001
/* .flash_attn = */ fa,
864
1002
/* .tensor_split = */ ts,
1003
+ /* .tensor_buft_overrides = */ ot,
865
1004
/* .use_mmap = */ mmp,
866
1005
/* .embeddings = */ embd,
867
1006
};
@@ -896,6 +1035,7 @@ struct test {
896
1035
bool no_kv_offload;
897
1036
bool flash_attn;
898
1037
std::vector<float > tensor_split;
1038
+ std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
899
1039
bool use_mmap;
900
1040
bool embeddings;
901
1041
int n_prompt;
@@ -927,6 +1067,7 @@ struct test {
927
1067
no_kv_offload = inst.no_kv_offload ;
928
1068
flash_attn = inst.flash_attn ;
929
1069
tensor_split = inst.tensor_split ;
1070
+ tensor_buft_overrides = inst.tensor_buft_overrides ;
930
1071
use_mmap = inst.use_mmap ;
931
1072
embeddings = inst.embeddings ;
932
1073
n_prompt = inst.n_prompt ;
@@ -972,9 +1113,9 @@ struct test {
972
1113
" build_commit" , " build_number" , " cpu_info" , " gpu_info" , " backends" , " model_filename" ,
973
1114
" model_type" , " model_size" , " model_n_params" , " n_batch" , " n_ubatch" , " n_threads" ,
974
1115
" cpu_mask" , " cpu_strict" , " poll" , " type_k" , " type_v" , " n_gpu_layers" ,
975
- " split_mode" , " main_gpu" , " no_kv_offload" , " flash_attn" , " tensor_split" , " use_mmap " ,
976
- " embeddings " , " n_prompt " , " n_gen " , " test_time " , " avg_ns " , " stddev_ns " ,
977
- " avg_ts" , " stddev_ts" ,
1116
+ " split_mode" , " main_gpu" , " no_kv_offload" , " flash_attn" , " tensor_split" , " tensor_buft_overrides " ,
1117
+ " use_mmap " , " embeddings " , " n_prompt " , " n_gen " , " test_time " , " avg_ns " ,
1118
+ " stddev_ns " , " avg_ts" , " stddev_ts" ,
978
1119
};
979
1120
return fields;
980
1121
}
@@ -1000,6 +1141,7 @@ struct test {
1000
1141
1001
1142
std::vector<std::string> get_values () const {
1002
1143
std::string tensor_split_str;
1144
+ std::string tensor_buft_overrides_str;
1003
1145
int max_nonzero = 0 ;
1004
1146
for (size_t i = 0 ; i < llama_max_devices (); i++) {
1005
1147
if (tensor_split[i] > 0 ) {
@@ -1014,6 +1156,26 @@ struct test {
1014
1156
tensor_split_str += " /" ;
1015
1157
}
1016
1158
}
1159
+ if (tensor_buft_overrides.size () == 1 ) {
1160
+ // Last element of tensor_buft_overrides is always a null pattern
1161
+ // so if it is only one element long, it must be a null pattern.
1162
+ GGML_ASSERT (tensor_buft_overrides[0 ].pattern == nullptr );
1163
+ tensor_buft_overrides_str += " none" ;
1164
+ } else {
1165
+ for (size_t i = 0 ; i < tensor_buft_overrides.size ()-1 ; i++) {
1166
+ // Last element of tensor_buft_overrides is always a null pattern
1167
+ if (tensor_buft_overrides[i].pattern == nullptr ) {
1168
+ tensor_buft_overrides_str += " none" ;
1169
+ } else {
1170
+ tensor_buft_overrides_str += tensor_buft_overrides[i].pattern ;
1171
+ tensor_buft_overrides_str += " =" ;
1172
+ tensor_buft_overrides_str += ggml_backend_buft_name (tensor_buft_overrides[i].buft );
1173
+ }
1174
+ if (i + 2 < tensor_buft_overrides.size ()) {
1175
+ tensor_buft_overrides_str += " ;" ;
1176
+ }
1177
+ }
1178
+ }
1017
1179
std::vector<std::string> values = { build_commit,
1018
1180
std::to_string (build_number),
1019
1181
cpu_info,
@@ -1037,6 +1199,7 @@ struct test {
1037
1199
std::to_string (no_kv_offload),
1038
1200
std::to_string (flash_attn),
1039
1201
tensor_split_str,
1202
+ tensor_buft_overrides_str,
1040
1203
std::to_string (use_mmap),
1041
1204
std::to_string (embeddings),
1042
1205
std::to_string (n_prompt),
@@ -1254,6 +1417,9 @@ struct markdown_printer : public printer {
1254
1417
if (field == " tensor_split" ) {
1255
1418
return " ts" ;
1256
1419
}
1420
+ if (field == " tensor_buft_overrides" ) {
1421
+ return " ot" ;
1422
+ }
1257
1423
return field;
1258
1424
}
1259
1425
@@ -1307,6 +1473,9 @@ struct markdown_printer : public printer {
1307
1473
if (params.tensor_split .size () > 1 || params.tensor_split != cmd_params_defaults.tensor_split ) {
1308
1474
fields.emplace_back (" tensor_split" );
1309
1475
}
1476
+ if (params.tensor_buft_overrides .size () > 1 || !vec_vec_tensor_buft_override_equal (params.tensor_buft_overrides , cmd_params_defaults.tensor_buft_overrides )) {
1477
+ fields.emplace_back (" tensor_buft_overrides" );
1478
+ }
1310
1479
if (params.use_mmap .size () > 1 || params.use_mmap != cmd_params_defaults.use_mmap ) {
1311
1480
fields.emplace_back (" use_mmap" );
1312
1481
}
0 commit comments