Commit 16bc66d

llama.cpp : split llama_context_params into model and context params (#3301)
* llama.cpp : split llama_context_params into model and context params

ggml-ci

* fix metal build
* fix freq_base/scale default to model value
* llama-bench : keep the same model between tests when possible
* move n_threads to llama_context_params, add n_threads_batch
* fix mpi build
* remove kv_size(), cuda scratch fixes
* remove low-vram option
* add n_threads_batch to system info, refactor to get_system_info()
* add documentation about --threads-batch to the READMEs
* llama-bench fix
* main : fix rope freq/scale warning
* llama.cpp : add llama_get_model
  common : add llama_tokenize from model
* remove duplicated ctx/model functions

ggml-ci

* cuda : print total VRAM used
1 parent 0512d66 commit 16bc66d

27 files changed, +713 −633 lines
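
For context on the API change in the diffs below: model-level settings now live in llama_model_params (passed to llama_load_model_from_file), while per-context settings, including the new n_threads and n_threads_batch fields, live in llama_context_params (passed to llama_new_context_with_model), and llama_decode no longer takes a thread-count argument. A minimal sketch of the resulting flow, based on the calls visible in this diff; the model path and the numeric values are placeholders, not part of the commit:

#include "llama.h"

int main(void) {
    llama_backend_init(false /*numa*/);

    // model-level settings: weights, GPU offload, mmap/mlock
    llama_model_params mparams = llama_model_default_params();
    // mparams.n_gpu_layers = 99; // optionally offload all layers to the GPU

    llama_model * model = llama_load_model_from_file("model.gguf" /*placeholder*/, mparams);
    if (model == NULL) {
        return 1;
    }

    // context-level settings: context size, batch size, threads, seed
    llama_context_params cparams = llama_context_default_params();
    cparams.seed            = 1234;
    cparams.n_ctx           = 2048; // placeholder; 0 = use the model's value
    cparams.n_threads       = 8;    // generation threads
    cparams.n_threads_batch = 8;    // batch/prompt-processing threads (new in this PR)

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... build a llama_batch and call llama_decode(ctx, batch) -- no n_threads argument anymore ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}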

common/common.cpp

+72 −40
@@ -129,6 +129,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             if (params.n_threads <= 0) {
                 params.n_threads = std::thread::hardware_concurrency();
             }
+        } else if (arg == "-tb" || arg == "--threads-batch") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_threads_batch = std::stoi(argv[i]);
+            if (params.n_threads_batch <= 0) {
+                params.n_threads_batch = std::thread::hardware_concurrency();
+            }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -451,12 +460,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.mul_mat_q = false;
 #else
             fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
-#endif // GGML_USE_CUBLAS
-        } else if (arg == "--low-vram" || arg == "-lv") {
-#ifdef GGML_USE_CUBLAS
-            params.low_vram = true;
-#else
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
@@ -630,7 +633,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" (can be specified more than once for multiple prompts).\n");
     printf(" --color colorise output to distinguish prompt and user input from generations\n");
     printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
-    printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
+    printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
+    printf(" -tb N, --threads-batch N\n");
+    printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
     printf(" -p PROMPT, --prompt PROMPT\n");
     printf(" prompt to start generation with (default: empty)\n");
     printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
@@ -645,7 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" -f FNAME, --file FNAME\n");
     printf(" prompt file to start generation.\n");
     printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
-    printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
+    printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
     printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
     printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
     printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
@@ -705,7 +710,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" -ts SPLIT --tensor-split SPLIT\n");
     printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
-    printf(" -lv, --low-vram don't allocate VRAM scratch buffer\n");
 #ifdef GGML_USE_CUBLAS
     printf(" -nommq, --no-mul-mat-q\n");
     printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
@@ -726,6 +730,18 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("\n");
 }
 
+std::string get_system_info(const gpt_params & params) {
+    std::ostringstream os;
+
+    os << "system_info: n_threads = " << params.n_threads;
+    if (params.n_threads_batch != -1) {
+        os << " (n_threads_batch = " << params.n_threads_batch << ")";
+    }
+    os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+
+    return os.str();
+}
+
 std::string gpt_random_prompt(std::mt19937 & rng) {
     const int r = rng() % 10;
     switch (r) {
@@ -749,40 +765,50 @@ std::string gpt_random_prompt(std::mt19937 & rng) {
 // Model utils
 //
 
-struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
-    auto lparams = llama_context_default_params();
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+    auto mparams = llama_model_default_params();
 
-    lparams.n_ctx = params.n_ctx;
-    lparams.n_batch = params.n_batch;
     if (params.n_gpu_layers != -1) {
-        lparams.n_gpu_layers = params.n_gpu_layers;
+        mparams.n_gpu_layers = params.n_gpu_layers;
     }
-    lparams.main_gpu = params.main_gpu;
-    lparams.tensor_split = params.tensor_split;
-    lparams.low_vram = params.low_vram;
-    lparams.mul_mat_q = params.mul_mat_q;
-    lparams.seed = params.seed;
-    lparams.f16_kv = params.memory_f16;
-    lparams.use_mmap = params.use_mmap;
-    lparams.use_mlock = params.use_mlock;
-    lparams.logits_all = params.logits_all;
-    lparams.embedding = params.embedding;
-    lparams.rope_freq_base = params.rope_freq_base;
-    lparams.rope_freq_scale = params.rope_freq_scale;
-
-    return lparams;
+    mparams.main_gpu = params.main_gpu;
+    mparams.tensor_split = params.tensor_split;
+    mparams.use_mmap = params.use_mmap;
+    mparams.use_mlock = params.use_mlock;
+
+    return mparams;
+}
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+    auto cparams = llama_context_default_params();
+
+    cparams.n_ctx = params.n_ctx;
+    cparams.n_batch = params.n_batch;
+    cparams.n_threads = params.n_threads;
+    cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    cparams.mul_mat_q = params.mul_mat_q;
+    cparams.seed = params.seed;
+    cparams.f16_kv = params.memory_f16;
+    cparams.logits_all = params.logits_all;
+    cparams.embedding = params.embedding;
+    cparams.rope_freq_base = params.rope_freq_base;
+    cparams.rope_freq_scale = params.rope_freq_scale;
+
+    return cparams;
 }
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
-    auto lparams = llama_context_params_from_gpt_params(params);
+    auto mparams = llama_model_params_from_gpt_params(params);
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return std::make_tuple(nullptr, nullptr);
     }
 
-    llama_context * lctx = llama_new_context_with_model(model, lparams);
+    auto cparams = llama_context_params_from_gpt_params(params);
+
+    llama_context * lctx = llama_new_context_with_model(model, cparams);
     if (lctx == NULL) {
         fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
         llama_free_model(model);
@@ -815,7 +841,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         LOG("warming up the model with an empty run\n");
 
         std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
+        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_tokens_rm(lctx, -1, -1);
         llama_reset_timings(lctx);
     }
@@ -828,16 +854,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 //
 
 std::vector<llama_token> llama_tokenize(
-        struct llama_context * ctx,
+        const struct llama_context * ctx,
+        const std::string & text,
+        bool add_bos) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos);
+}
+
+std::vector<llama_token> llama_tokenize(
+        const struct llama_model * model,
         const std::string & text,
         bool add_bos) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(ctx, text.data(), text.length(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -847,10 +880,10 @@ std::vector<llama_token> llama_tokenize(
 
 std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(ctx, token, result.data(), result.size());
+    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(ctx, token, result.data(), result.size());
+        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -905,7 +938,7 @@ llama_token llama_sample_token(
         std::vector<llama_token_data> & candidates,
         int idx) {
     const int n_ctx = llama_n_ctx(ctx);
-    const int n_vocab = llama_n_vocab(ctx);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
 
     const float temp = params.temp;
     const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
@@ -1191,7 +1224,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
 #endif // NDEBUG
 
     fprintf(stream, "model_desc: %s\n", model_desc);
-    fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(lctx));
+    fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
 
 #ifdef __OPTIMIZE__
     fprintf(stream, "optimize: true\n");
@@ -1258,7 +1291,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
         fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
-    fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false");
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
     fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat);
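
The helpers added above (get_system_info and the llama_tokenize overload that takes a llama_model) can be called roughly as follows; this is a hypothetical caller written for illustration, not code from the commit:

#include <cstdio>
#include <vector>
#include "common.h"
#include "llama.h"

// hypothetical caller: report threading info and tokenize a prompt without needing a llama_context
static std::vector<llama_token> tokenize_prompt(const llama_model * model, const gpt_params & params) {
    fprintf(stderr, "%s\n", get_system_info(params).c_str());
    // the new overload tokenizes from the model directly
    return llama_tokenize(model, params.prompt, /*add_bos=*/true);
}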

common/common.h

+10 −2
@@ -36,6 +36,7 @@ int32_t get_num_physical_cores();
 struct gpt_params {
     uint32_t seed = -1; // RNG seed
     int32_t n_threads = get_num_physical_cores();
+    int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 512; // context size
     int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
@@ -95,7 +96,6 @@ struct gpt_params {
     bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
     size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
 
-    bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
     bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
     bool memory_f16 = true; // use f16 instead of f32 for memory kv
     bool random_prompt = false; // do not randomize prompt if none provided
@@ -126,6 +126,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
 
 void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
 
+std::string get_system_info(const gpt_params & params);
+
 std::string gpt_random_prompt(std::mt19937 & rng);
 
 void process_escapes(std::string& input);
@@ -135,6 +137,7 @@ void process_escapes(std::string& input);
 //
 
 std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params);
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params);
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
 
 //
@@ -144,7 +147,12 @@ struct llama_context_params llama_context_params_from_gpt_param
 // tokenizes a string into a vector of tokens
 // should work similar to Python's `tokenizer.encode`
 std::vector<llama_token> llama_tokenize(
-        struct llama_context * ctx,
+        const struct llama_context * ctx,
+        const std::string & text,
+        bool add_bos);
+
+std::vector<llama_token> llama_tokenize(
+        const struct llama_model * model,
         const std::string & text,
         bool add_bos);
 
common/train.cpp

+5 −5
@@ -858,7 +858,7 @@ size_t tokenize_file(
     out_tokens.resize(buf.size() + n_max_tokens_overhead);
 
     int n_tokens = llama_tokenize(
-        lctx,
+        llama_get_model(lctx),
         buf.data(),
         (int) buf.size(),
         out_tokens.data(),
@@ -867,7 +867,7 @@
     if (n_tokens < 0) {
         out_tokens.resize(-n_tokens);
         n_tokens = llama_tokenize(
-            lctx,
+            llama_get_model(lctx),
             buf.data(),
             (int) buf.size(),
             out_tokens.data(),
@@ -920,7 +920,7 @@ size_t tokenize_file(
     size_t found_max_sample_size = 0;
 
     size_t max_token_text_size = 0;
-    int n_vocab = llama_n_vocab(lctx);
+    int n_vocab = llama_n_vocab(llama_get_model(lctx));
     for (llama_token token=0; token < n_vocab; ++token) {
         max_token_text_size = std::max(
             max_token_text_size,
@@ -961,15 +961,15 @@ size_t tokenize_file(
 
             // tokenize the sample
             tok_sample.resize(buf_sample.size() + n_max_tokens_overhead);
-            int n_tokens = llama_tokenize(lctx,
+            int n_tokens = llama_tokenize(llama_get_model(lctx),
                 buf_sample.data(),
                 (int) buf_sample.size(),
                 tok_sample.data(),
                 (int) tok_sample.size(),
                 false);
             if (n_tokens < 0) {
                 tok_sample.resize(-n_tokens);
-                n_tokens = llama_tokenize(lctx,
+                n_tokens = llama_tokenize(llama_get_model(lctx),
                     buf_sample.data(),
                     (int) buf_sample.size(),
                     tok_sample.data(),

examples/batched/batched.cpp

+24 −15
@@ -40,34 +40,43 @@ int main(int argc, char ** argv) {
 
     llama_backend_init(params.numa);
 
-    llama_context_params ctx_params = llama_context_default_params();
+    // initialize the model
 
-    ctx_params.seed = 1234;
-    ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
-    ctx_params.n_batch = std::max(n_len, n_parallel);
-    // ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
+    llama_model_params model_params = llama_model_default_params();
 
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
+    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
         return 1;
     }
 
+    // tokenize the prompt
+
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(model, params.prompt, true);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
+
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed = 1234;
+    ctx_params.n_ctx = n_kv_req;
+    ctx_params.n_batch = std::max(n_len, n_parallel);
+    ctx_params.n_threads = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
         return 1;
     }
 
-    // tokenize the prompt
-
-    std::vector<llama_token> tokens_list;
-    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
-
     const int n_ctx = llama_n_ctx(ctx);
-    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
 
@@ -106,7 +115,7 @@ int main(int argc, char ** argv) {
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+    if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
         return 1;
     }
@@ -146,7 +155,7 @@ int main(int argc, char ** argv) {
                 continue;
             }
 
-            auto n_vocab = llama_n_vocab(ctx);
+            auto n_vocab = llama_n_vocab(model);
             auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
 
             std::vector<llama_token_data> candidates;
@@ -210,7 +219,7 @@ int main(int argc, char ** argv) {
         n_cur += 1;
 
         // evaluate the current batch with the transformer model
-        if (llama_decode(ctx, batch, params.n_threads)) {
+        if (llama_decode(ctx, batch)) {
            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }

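As a standalone illustration of the thread handling shown in this file: thread counts are now fixed when the context is created, so llama_decode(ctx, batch) is called without an n_threads argument. A hypothetical helper, assuming an already loaded llama_model *:

#include "llama.h"

// hypothetical helper: generation and batch/prompt-processing thread counts are set once on the
// context, since llama_decode no longer takes an n_threads argument after this change
static llama_context * make_context(llama_model * model, int n_threads, int n_threads_batch) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_threads       = n_threads;
    cparams.n_threads_batch = n_threads_batch > 0 ? n_threads_batch : n_threads;
    return llama_new_context_with_model(model, cparams);
}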