@@ -97,7 +97,9 @@ struct llama_context {
97
97
llama_model model;
98
98
llama_vocab vocab;
99
99
100
- size_t mem_per_token = 0 ;
100
+ // used to estimate memory requirements experimentally
101
+ size_t mem_at_token0 = 0 ; // first time
102
+ size_t mem_at_token1 = 0 ; // second time
101
103
102
104
// decode output (2-dimensional array: [n_tokens][n_vocab])
103
105
std::vector<float > logits;
@@ -626,14 +628,24 @@ static bool llama_eval_internal(
626
628
const int n_vocab = hparams.n_vocab ;
627
629
const int n_rot = hparams.n_embd /hparams.n_head ;
628
630
629
- auto & mem_per_token = lctx.mem_per_token ;
631
+ auto & mem_at_token0 = lctx.mem_at_token0 ;
632
+ auto & mem_at_token1 = lctx.mem_at_token1 ;
630
633
631
634
// TODO: fix this hardcoded size
632
- static size_t buf_size = 512u *1024 *1024 ;
635
+ static size_t buf_size = size_t (n_ctx) *1024 *1024 ;
633
636
static void * buf = malloc (buf_size);
634
637
635
- if (mem_per_token > 0 && mem_per_token*N > buf_size) {
636
- const size_t buf_size_new = 1.3 *(mem_per_token*N); // add 30% to account for ggml object overhead
638
+ const size_t C0 = mem_at_token0; // ~base
639
+ const int64_t C1 = mem_at_token1 - mem_at_token0; // delta 0,1
640
+
641
+ // TODO(Green-Sky): determine relation to N (batch size)
642
+ // const size_t size_estimate = C0 + size_t(C1 * (n_past + N));
643
+ const size_t size_estimate = C0 + C1 * n_past;
644
+
645
+ // fprintf(stderr, "\n%s: size_estimate %zu bytes (%zu | %zu)\n", __func__, size_estimate, mem_per_token0, mem_per_token1);
646
+
647
+ if (mem_at_token0 > 0 && mem_at_token1 > 0 && size_estimate > buf_size) {
648
+ const size_t buf_size_new = 1.1 *size_estimate; // just grow by 10%
637
649
// fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
638
650
639
651
// reallocate
@@ -830,10 +842,13 @@ static bool llama_eval_internal(
830
842
memcpy (logits_out.data (), (float *) ggml_get_data (inpL) + (n_vocab*(N-1 )), sizeof (float )*n_vocab);
831
843
}
832
844
833
- if (mem_per_token == 0 ) {
834
- mem_per_token = ggml_used_mem (ctx0)/N;
845
+ if (mem_at_token0 == 0 ) {
846
+ mem_at_token0 = ggml_used_mem (ctx0);
847
+ } else if (mem_at_token1 == 0 ) {
848
+ mem_at_token1 = ggml_used_mem (ctx0);
835
849
}
836
850
// fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
851
+ // fprintf(stderr, "estimate/used_mem = %f\n", double(size_estimate) / ggml_used_mem(ctx0));
837
852
838
853
ggml_free (ctx0);
839
854
0 commit comments