@@ -10,14 +10,15 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n", argv[0]);
+        printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n", argv[0]);
         return 1;
     }
 
     int seed = -1;
 
     int n_junk = 250; // number of times to repeat the junk text
     int n_keep = 32;  // number of tokens in the prompt prefix
+    int n_grp  = 1;   // if more than 1 - perform LongLM SelfExtend
     int i_pos  = -1;  // position of the passkey in the junk text
 
     if (argc >= 2) {
@@ -29,11 +30,15 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 4) {
-        i_pos = std::stoi(argv[3]);
+        n_grp = std::stoi(argv[3]);
     }
 
     if (argc >= 5) {
-        seed = std::stoi(argv[4]);
+        i_pos = std::stoi(argv[4]);
+    }
+
+    if (argc >= 6) {
+        seed = std::stoi(argv[5]);
     }
 
     if (seed == -1) {
@@ -86,11 +91,13 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed    = seed;
-    ctx_params.n_ctx   = llama_n_ctx_train(model) + n_keep;
+    ctx_params.n_ctx   = llama_n_ctx_train(model)*n_grp + n_keep;
     ctx_params.n_batch = 512;
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
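
Note: with SelfExtend the context (and thus the KV cache) is sized for the extended window, n_ctx_train*n_grp plus the n_keep prefix, and the batch size must split evenly into n_grp groups, which is what the new GGML_ASSERT enforces. A minimal standalone sketch of that sizing arithmetic, assuming a hypothetical 4096-token training context and n_grp = 4 (these numbers are not from the patch):

    // sketch only - the concrete numbers are assumptions for illustration
    #include <cassert>
    #include <cstdio>

    int main() {
        const int n_ctx_train = 4096; // assumed training context of the model
        const int n_keep      = 32;   // prompt prefix kept, as in the example
        const int n_grp       = 4;    // hypothetical SelfExtend group factor
        const int n_batch     = 512;  // batch size used by the example

        assert(n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");

        const int n_ctx = n_ctx_train*n_grp + n_keep; // 4096*4 + 32 = 16416
        printf("n_ctx = %d, compressed positions per batch = %d\n", n_ctx, n_batch/n_grp);

        return 0;
    }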
@@ -113,9 +120,10 @@ int main(int argc, char ** argv) {
     // total length of the sequences including the prompt
     const int n_len = n_tokens_all + n_predict;
 
-    const int n_ctx    = llama_n_ctx(ctx) - n_keep;
-    const int n_kv_req = llama_n_ctx(ctx);
-    const int n_batch  = ctx_params.n_batch;
+    const int n_ctx       = llama_n_ctx(ctx) - n_keep;
+    const int n_kv_req    = llama_n_ctx(ctx);
+    const int n_batch     = ctx_params.n_batch;
+    const int n_batch_grp = ctx_params.n_batch/n_grp;
 
     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
 
@@ -132,6 +140,16 @@ int main(int argc, char ** argv) {
 
     // fill the KV cache
     for (int i = 0; i < n_ctx; i += n_batch) {
+        if (i > 0 && n_grp > 1) {
+            const int ib = i/n_batch - 1;
+            const int bd = n_batch_grp*(n_grp - 1);
+
+            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+
+            n_past -= bd;
+        }
+
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
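
Note: the block added at the top of the fill loop is the LongLM SelfExtend step: after each batch of prompt tokens is ingested, the positions of the previous batch are shifted into place and then integer-divided by n_grp, so each previously ingested batch ends up occupying only n_batch_grp KV cells. A rough standalone sketch of the position bookkeeping, with the llama_kv_cache_seq_shift/llama_kv_cache_seq_div calls replaced by plain arithmetic and the concrete numbers chosen as assumptions:

    // sketch only - traces how n_past evolves; not part of the patch
    #include <cstdio>

    int main() {
        const int n_batch     = 512;           // as in the example
        const int n_grp       = 4;             // hypothetical group factor
        const int n_batch_grp = n_batch/n_grp; // 128 compressed positions per batch

        int n_past = 0;

        for (int i = 0; i < 4*n_batch; i += n_batch) {
            if (i > 0 && n_grp > 1) {
                const int ib = i/n_batch - 1;
                const int bd = n_batch_grp*(n_grp - 1);

                // the real code calls llama_kv_cache_seq_shift + llama_kv_cache_seq_div here:
                // the previous n_batch positions are shifted by ib*bd and divided by n_grp,
                // after which that batch spans only n_batch_grp KV cells
                printf("  compress batch %d: shift by %d, divide positions by %d\n",
                       ib + 1, ib*bd, n_grp);

                n_past -= bd;
            }

            n_past += n_batch; // next batch occupies positions [old n_past, old n_past + n_batch)
            printf("batch %d ingested, n_past = %d\n", i/n_batch + 1, n_past);
        }

        // each earlier batch now costs n_batch_grp positions instead of n_batch, which is
        // roughly how n_ctx_train*n_grp prompt tokens fit inside the trained position range
        return 0;
    }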