Commit 58e6c9f

Add support for file load progress reporting callbacks (abetlen#434)
* File load progress reporting
* Move llama_progress_handler into llama_context_params
* Renames
* Use seekg to find file size instead
* More correct load progress
* Call progress callback more frequently
* Fix typo
1 parent 36d0753 commit 58e6c9f

2 files changed: 39 additions and 10 deletions

llama.cpp

Lines changed: 32 additions & 10 deletions
@@ -267,14 +267,16 @@ static void kv_cache_free(struct llama_kv_cache & cache) {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.n_ctx                       =*/ 512,
-        /*.n_parts                     =*/ -1,
-        /*.seed                        =*/ 0,
-        /*.f16_kv                      =*/ false,
-        /*.logits_all                  =*/ false,
-        /*.vocab_only                  =*/ false,
-        /*.use_mlock                   =*/ false,
-        /*.embedding                   =*/ false,
+        /*.n_ctx                       =*/ 512,
+        /*.n_parts                     =*/ -1,
+        /*.seed                        =*/ 0,
+        /*.f16_kv                      =*/ false,
+        /*.logits_all                  =*/ false,
+        /*.vocab_only                  =*/ false,
+        /*.use_mlock                   =*/ false,
+        /*.embedding                   =*/ false,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
     };
 
     return result;
@@ -290,7 +292,9 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only) {
+        bool vocab_only,
+        llama_progress_callback progress_callback,
+        void *progress_callback_user_data) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -576,6 +580,10 @@ static bool llama_model_load(
 
     std::vector<uint8_t> tmp;
 
+    if (progress_callback) {
+        progress_callback(0.0, progress_callback_user_data);
+    }
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -589,6 +597,10 @@ static bool llama_model_load(
 
         fin = std::ifstream(fname_part, std::ios::binary);
         fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+        fin.seekg(0, fin.end);
+        const size_t file_size = fin.tellg();
+
         fin.seekg(file_offset);
 
         // load weights
@@ -764,6 +776,11 @@ static bool llama_model_load(
                 model.n_loaded++;
 
                 // progress
+                if (progress_callback) {
+                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    progress_callback(current_progress, progress_callback_user_data);
+                }
                 if (model.n_loaded % 8 == 0) {
                     fprintf(stderr, ".");
                     fflush(stderr);
@@ -786,6 +803,10 @@ static bool llama_model_load(
 
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
+    if (progress_callback) {
+        progress_callback(1.0, progress_callback_user_data);
+    }
+
     return true;
 }
 
@@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
-                          params.vocab_only)) {
+                          params.vocab_only, params.progress_callback,
+                          params.progress_callback_user_data)) {
        fprintf(stderr, "%s: failed to load model\n", __func__);
        llama_free(ctx);
        return nullptr;
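The per-part progress math in the hunk above is a linear interpolation: the fraction of the current part already read, rescaled into that part's 1/n_parts slice of the whole load. A minimal standalone sketch of the same arithmetic (the load_progress helper below is hypothetical, not part of the commit; bytes_read stands for tellg() - file_offset and part_bytes for file_size - file_offset):

#include <cstdio>

// Same arithmetic as the commit's progress hunk: fraction of the current
// part already read, offset by the number of fully loaded parts.
static double load_progress(int part, size_t bytes_read, size_t part_bytes, int n_parts) {
    double current_file_progress = double(bytes_read) / double(part_bytes); // 0..1 within this part
    return (double(part) + current_file_progress) / double(n_parts);        // 0..1 overall
}

int main() {
    // Halfway through the second of four parts: (1 + 0.5) / 4 = 0.375
    std::printf("%.3f\n", load_progress(1, 512, 1024, 4));
    return 0;
}

The explicit progress_callback(1.0, ...) after the loop ensures callers always observe completion, since the in-loop reports land on tensor boundaries and the last one may fall just short of 1.0.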

llama.h

Lines changed: 7 additions & 0 deletions
@@ -45,6 +45,8 @@ extern "C" {
 
     } llama_token_data;
 
+    typedef void (*llama_progress_callback)(double progress, void *ctx);
+
     struct llama_context_params {
         int n_ctx;   // text context
         int n_parts; // -1 for default
@@ -55,6 +57,11 @@ extern "C" {
         bool vocab_only; // only load the vocabulary, no weights
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
+
+        // called with a progress value between 0 and 1, pass NULL to disable
+        llama_progress_callback progress_callback;
+        // context pointer passed to the progress callback
+        void * progress_callback_user_data;
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
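Putting the new API together, a caller sets the two new fields on the params struct before creating the context. The sketch below is a hypothetical usage example, assuming this vintage's llama_init_from_file(path, params) entry point; the model path is a placeholder:

#include <cstdio>
#include "llama.h"

// Matches the llama_progress_callback typedef added by this commit.
static void on_progress(double progress, void *ctx) {
    (void)ctx; // no user state in this sketch
    std::fprintf(stderr, "\rloading: %3.0f%%", progress * 100.0);
}

int main() {
    llama_context_params params = llama_context_default_params();
    params.progress_callback           = on_progress;
    params.progress_callback_user_data = nullptr; // or a pointer to caller state

    llama_context *ctx = llama_init_from_file("models/7B/ggml-model.bin", params);
    if (!ctx) {
        std::fprintf(stderr, "failed to load model\n");
        return 1;
    }
    std::fprintf(stderr, "\ndone\n");
    llama_free(ctx);
    return 0;
}

Because the callback receives a user-data pointer rather than relying on globals, a GUI or bindings layer (the original motivation in abetlen#434) can route progress to per-instance state.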
