Skip to content

Commit 792f200

Browse files
committed
examples : add passkey test
1 parent 6e08281 commit 792f200

File tree

4 files changed

+270
-0
lines changed

4 files changed

+270
-0
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ else()
3030
add_subdirectory(quantize-stats)
3131
add_subdirectory(save-load-state)
3232
add_subdirectory(simple)
33+
add_subdirectory(passkey)
3334
add_subdirectory(speculative)
3435
add_subdirectory(train-text-from-scratch)
3536
if (LLAMA_METAL)

examples/batched/batched.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ int main(int argc, char ** argv) {
6969

7070
std::vector<llama_token> tokens_list;
7171
tokens_list = ::llama_tokenize(model, params.prompt, true);
72+
7273
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
7374

7475
// initialize the context

examples/passkey/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
# Build configuration for the `passkey` example.
set(TARGET passkey)

# Single-source executable linked against the common helpers and the core library.
add_executable(${TARGET} passkey.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

# Install the runtime binary alongside the other example tools.
install(TARGETS ${TARGET} RUNTIME)

examples/passkey/passkey.cpp

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstdlib> // rand, srand — previously pulled in only transitively
#include <ctime>   // time — previously pulled in only transitively
#include <string>
#include <vector>
int main(int argc, char ** argv) {
10+
gpt_params params;
11+
12+
if (argc == 1 || argv[1][0] == '-') {
13+
printf("usage: %s MODEL_PATH N_JUNK SEED\n" , argv[0]);
14+
return 1 ;
15+
}
16+
17+
int seed = -1;
18+
19+
int n_junk = 250; // number of times to repeat the junk text
20+
int n_keep = 32; // number of tokens in the prompt prefix
21+
22+
if (argc >= 2) {
23+
params.model = argv[1];
24+
}
25+
26+
if (argc >= 3) {
27+
n_junk = std::stoi(argv[2]);
28+
}
29+
30+
if (argc >= 4) {
31+
seed = std::stoi(argv[3]);
32+
}
33+
34+
const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
35+
36+
if (seed == -1) {
37+
seed = time(NULL);
38+
}
39+
40+
srand(seed);
41+
42+
// generate junk text
43+
params.prompt = prompt_prefix;
44+
45+
const int n_insert = rand() % n_junk;
46+
const int passkey = rand() % 50000 + 1;
47+
48+
for (int i = 0; i < n_junk; i++) {
49+
if (i % n_junk == n_insert) {
50+
params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
51+
}
52+
53+
params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
54+
}
55+
56+
params.prompt += " What is the pass key? The pass key is";
57+
58+
// init LLM
59+
60+
llama_backend_init(params.numa);
61+
62+
// initialize the model
63+
64+
llama_model_params model_params = llama_model_default_params();
65+
66+
model_params.n_gpu_layers = 99; // offload all layers to the GPU
67+
68+
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
69+
70+
if (model == NULL) {
71+
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
72+
return 1;
73+
}
74+
75+
// initialize the context
76+
77+
llama_context_params ctx_params = llama_context_default_params();
78+
79+
ctx_params.seed = seed;
80+
ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep;
81+
ctx_params.n_batch = 512;
82+
ctx_params.n_threads = params.n_threads;
83+
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
84+
85+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
86+
87+
if (ctx == NULL) {
88+
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
89+
return 1;
90+
}
91+
92+
// tokenize the prefix and use it as a sink
93+
const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();
94+
95+
// tokenize the prompt
96+
std::vector<llama_token> tokens_list;
97+
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
98+
99+
// we leave a margin of 16 tokens for the generated text - it should contain just the passkey
100+
const int n_predict = 16;
101+
102+
// total length of the sequences including the prompt
103+
const int n_len = tokens_list.size() + n_predict;
104+
105+
const int n_ctx = llama_n_ctx(ctx) - n_keep;
106+
const int n_kv_req = llama_n_ctx(ctx);
107+
const int n_batch = ctx_params.n_batch;
108+
109+
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
110+
111+
// print the prompt token-by-token
112+
113+
LOG_TEE("\n");
114+
LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
115+
LOG_TEE("prompt tokens: %d\n", (int) tokens_list.size());
116+
//LOG_TEE("prompt: %s\n", params.prompt.c_str());
117+
118+
llama_batch batch = llama_batch_init(512, 0, 1);
119+
120+
// fill the KV cache
121+
for (int i = 0; i < n_ctx; i += n_batch) {
122+
llama_batch_clear(batch);
123+
124+
for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) {
125+
llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
126+
}
127+
128+
if (i + n_batch >= (int) tokens_list.size()) {
129+
batch.logits[batch.n_tokens - 1] = true;
130+
}
131+
132+
if (llama_decode(ctx, batch) != 0) {
133+
LOG_TEE("%s: llama_decode() failed\n", __func__);
134+
return 1;
135+
}
136+
137+
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size()));
138+
139+
if (i + n_batch >= (int) tokens_list.size()) {
140+
break;
141+
}
142+
}
143+
144+
for (int i = n_ctx; i < (int) tokens_list.size(); i += n_batch) {
145+
const int n_discard = n_batch;
146+
147+
LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
148+
149+
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
150+
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
151+
152+
llama_batch_clear(batch);
153+
154+
for (int j = 0; j < n_batch && i + j < (int) tokens_list.size(); j++) {
155+
llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
156+
}
157+
158+
if (i + n_batch >= (int) tokens_list.size()) {
159+
batch.logits[batch.n_tokens - 1] = true;
160+
}
161+
162+
if (llama_decode(ctx, batch) != 0) {
163+
LOG_TEE("%s: llama_decode() failed\n", __func__);
164+
return 1;
165+
}
166+
167+
LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, (int) tokens_list.size()));
168+
}
169+
170+
int n_past = batch.pos[batch.n_tokens - 1];
171+
172+
{
173+
const int n_discard = n_past - n_ctx + n_predict;
174+
175+
if (n_discard > 0) {
176+
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
177+
178+
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
179+
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
180+
181+
n_past -= n_discard;
182+
}
183+
}
184+
185+
LOG_TEE("\n");
186+
187+
// main loop
188+
189+
int n_cur = tokens_list.size();
190+
int n_decode = 0;
191+
192+
const auto t_main_start = ggml_time_us();
193+
194+
while (n_cur <= n_len) {
195+
// sample the next token
196+
{
197+
auto n_vocab = llama_n_vocab(model);
198+
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
199+
200+
std::vector<llama_token_data> candidates;
201+
candidates.reserve(n_vocab);
202+
203+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
204+
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
205+
}
206+
207+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
208+
209+
// sample the most likely token
210+
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
211+
212+
// is it an end of stream?
213+
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
214+
LOG_TEE("\n");
215+
216+
break;
217+
}
218+
219+
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
220+
fflush(stdout);
221+
222+
n_decode += 1;
223+
n_past += 1;
224+
225+
// prepare the next batch
226+
llama_batch_clear(batch);
227+
228+
// push this new token for next evaluation
229+
llama_batch_add(batch, new_token_id, n_past, { 0 }, true);
230+
}
231+
232+
n_cur += 1;
233+
234+
// evaluate the current batch with the transformer model
235+
if (llama_decode(ctx, batch)) {
236+
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
237+
return 1;
238+
}
239+
}
240+
241+
LOG_TEE("\n");
242+
LOG_TEE("%s: passkey = %d, inserted at position %d / %d\n", __func__, passkey, n_insert, n_junk);
243+
244+
LOG_TEE("\n");
245+
246+
const auto t_main_end = ggml_time_us();
247+
248+
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
249+
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
250+
251+
llama_print_timings(ctx);
252+
253+
fprintf(stderr, "\n");
254+
255+
llama_batch_free(batch);
256+
257+
llama_free(ctx);
258+
llama_free_model(model);
259+
260+
llama_backend_free();
261+
262+
return 0;
263+
}

0 commit comments

Comments
 (0)