@@ -10,14 +10,15 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n", argv[0]);
+        printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n", argv[0]);
         return 1;
     }
 
     int seed = -1;
 
     int n_junk = 250; // number of times to repeat the junk text
     int n_keep = 32;  // number of tokens in the prompt prefix
+    int n_grp  = 1;   // if more than 1 - perform LongLM SelfExtend
     int i_pos  = -1;  // position of the passkey in the junk text
 
     if (argc >= 2) {
@@ -29,11 +30,15 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 4) {
-        i_pos = std::stoi(argv[3]);
+        n_grp = std::stoi(argv[3]);
    }
 
     if (argc >= 5) {
-        seed = std::stoi(argv[4]);
+        i_pos = std::stoi(argv[4]);
+    }
+
+    if (argc >= 6) {
+        seed = std::stoi(argv[5]);
     }
 
     if (seed == -1) {
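
Note: with this change the positional arguments are MODEL_PATH N_JUNK N_GRP I_POS SEED, so the group factor is now the third argument and existing invocations that passed I_POS there need updating. As a purely illustrative call (binary name and model path are placeholders), passkey ./model.gguf 250 4 would repeat the junk text 250 times and enable SelfExtend with a group factor of 4, leaving I_POS and SEED at their defaults.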
@@ -86,11 +91,13 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed  = seed;
-    ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep;
+    ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
     ctx_params.n_batch = 512;
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
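
The context is now sized to n_grp times the model's training context (plus the n_keep prefix): grouping positions by n_grp is what lets the cache hold that many tokens without exceeding the trained position range. With an illustrative training context of 4096 and n_grp = 4, n_ctx becomes 4*4096 + 32 = 16416. The new assert protects the per-batch arithmetic further down, which only divides evenly when n_batch (512 here) is a multiple of n_grp.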
@@ -113,9 +120,10 @@ int main(int argc, char ** argv) {
     // total length of the sequences including the prompt
     const int n_len = n_tokens_all + n_predict;
 
-    const int n_ctx = llama_n_ctx(ctx) - n_keep;
-    const int n_kv_req = llama_n_ctx(ctx);
-    const int n_batch = ctx_params.n_batch;
+    const int n_ctx       = llama_n_ctx(ctx) - n_keep;
+    const int n_kv_req    = llama_n_ctx(ctx);
+    const int n_batch     = ctx_params.n_batch;
+    const int n_batch_grp = ctx_params.n_batch/n_grp;
 
     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
 
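
n_batch_grp is the number of distinct positions a full batch of n_batch tokens collapses to once its positions are divided by n_grp; with the illustrative n_grp = 4 that is 512/4 = 128. It is used below to work out how far n_past has to be pulled back after each batch.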
@@ -132,6 +140,16 @@ int main(int argc, char ** argv) {
 
     // fill the KV cache
     for (int i = 0; i < n_ctx; i += n_batch) {
+        if (i > 0 && n_grp > 1) {
+            const int ib = i/n_batch - 1;
+            const int bd = n_batch_grp*(n_grp - 1);
+
+            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+
+            n_past -= bd;
+        }
+
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
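
To make the shift/div bookkeeping above easier to follow, here is a minimal, self-contained sketch of the same position arithmetic applied to a plain integer array rather than the real KV cache, so it does not call the llama.cpp API. The values n_batch = 8, n_grp = 4 and n_tokens = 24 are illustrative only; the two range updates stand in for llama_kv_cache_seq_shift and llama_kv_cache_seq_div.

#include <cstdio>
#include <vector>

int main() {
    const int n_batch     = 8;               // illustrative; the example above uses 512
    const int n_grp       = 4;               // SelfExtend group factor (n_batch % n_grp == 0)
    const int n_batch_grp = n_batch / n_grp; // positions one batch occupies after grouping
    const int n_tokens    = 24;              // three illustrative batches of "junk" tokens

    std::vector<int> pos;                    // stand-in for the positions stored in the KV cache
    int n_past = 0;

    for (int i = 0; i < n_tokens; i += n_batch) {
        if (i > 0 && n_grp > 1) {
            const int ib = i/n_batch - 1;
            const int bd = n_batch_grp*(n_grp - 1);

            // stand-in for llama_kv_cache_seq_shift: move the previous batch up by ib*bd
            for (int & p : pos) {
                if (p >= n_past - n_batch && p < n_past) p += ib*bd;
            }
            // stand-in for llama_kv_cache_seq_div: compress the shifted range by n_grp
            for (int & p : pos) {
                if (p >= n_past - n_batch + ib*bd && p < n_past + ib*bd) p /= n_grp;
            }

            n_past -= bd;
        }

        // stand-in for llama_batch_add: each new token is placed at position n_past++
        for (int j = 0; j < n_batch && i + j < n_tokens; j++) {
            pos.push_back(n_past++);
        }
    }

    for (int p : pos) printf("%d ", p);
    printf("\n");
}

Running it prints 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 5 6 7 8 9 10 11: every earlier group of n_grp tokens ends up sharing one position while the most recent batch keeps consecutive positions, which is the grouped-attention-plus-neighbor-window behaviour that LongLM SelfExtend relies on.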