@@ -267,14 +267,16 @@ static void kv_cache_free(struct llama_kv_cache & cache) {
 
 struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
-        /*.n_ctx      =*/ 512,
-        /*.n_parts    =*/ -1,
-        /*.seed       =*/ 0,
-        /*.f16_kv     =*/ false,
-        /*.logits_all =*/ false,
-        /*.vocab_only =*/ false,
-        /*.use_mlock  =*/ false,
-        /*.embedding  =*/ false,
+        /*.n_ctx                       =*/ 512,
+        /*.n_parts                     =*/ -1,
+        /*.seed                        =*/ 0,
+        /*.f16_kv                      =*/ false,
+        /*.logits_all                  =*/ false,
+        /*.vocab_only                  =*/ false,
+        /*.use_mlock                   =*/ false,
+        /*.embedding                   =*/ false,
+        /*.progress_callback           =*/ nullptr,
+        /*.progress_callback_user_data =*/ nullptr,
     };
 
     return result;
@@ -290,7 +292,9 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only) {
+        bool vocab_only,
+        llama_progress_callback progress_callback,
+        void *progress_callback_user_data) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -576,6 +580,10 @@ static bool llama_model_load(
 
     std::vector<uint8_t> tmp;
 
+    if (progress_callback) {
+        progress_callback(0.0, progress_callback_user_data);
+    }
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -589,6 +597,10 @@ static bool llama_model_load(
 
         fin = std::ifstream(fname_part, std::ios::binary);
         fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
+
+        fin.seekg(0, fin.end);
+        const size_t file_size = fin.tellg();
+
         fin.seekg(file_offset);
 
         // load weights
@@ -764,6 +776,11 @@ static bool llama_model_load(
                 model.n_loaded++;
 
                 // progress
+                if (progress_callback) {
+                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
+                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    progress_callback(current_progress, progress_callback_user_data);
+                }
                 if (model.n_loaded % 8 == 0) {
                     fprintf(stderr, ".");
                     fflush(stderr);
@@ -786,6 +803,10 @@ static bool llama_model_load(
 
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
+    if (progress_callback) {
+        progress_callback(1.0, progress_callback_user_data);
+    }
+
     return true;
 }
 
@@ -1617,7 +1638,8 @@ struct llama_context * llama_init_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, memory_type,
-                          params.vocab_only)) {
+                          params.vocab_only, params.progress_callback,
+                          params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
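
The hunks above add two fields to `llama_context_params` (`progress_callback` and `progress_callback_user_data`, both defaulting to `nullptr`) and thread them through `llama_model_load`: the callback is invoked with `0.0` before any part is read, with `(i + current_file_progress) / n_parts` after each tensor, and with `1.0` once loading finishes. Below is a minimal sketch of how a caller might hook into this, assuming the `llama.h` declarations match what the diff uses (`llama_progress_callback` taking a `double` progress value and a `void *` user-data pointer); the `print_progress` helper, the `"loading"` tag, and the model path are illustrative only.

```cpp
#include <cstdio>

#include "llama.h"

// Assumed callback signature, matching how the diff invokes it:
// progress runs from 0.0 (before the first part) to 1.0 (all parts loaded).
static void print_progress(double progress, void * user_data) {
    const char * tag = static_cast<const char *>(user_data);
    fprintf(stderr, "\r%s: %3.0f%%", tag, progress * 100.0);
    if (progress >= 1.0) {
        fprintf(stderr, "\n");
    }
}

int main() {
    llama_context_params params = llama_context_default_params();
    // Both fields default to nullptr, so existing callers are unaffected.
    params.progress_callback           = print_progress;
    params.progress_callback_user_data = (void *) "loading";

    // Placeholder model path.
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model-q4_0.bin", params);
    if (ctx == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_free(ctx);
    return 0;
}
```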