Commit b0034d9

examples : add passkey test (#3856)
* examples : add passkey test
* passkey : better prints
* passkey : select pass key pos from CLI
* passkey : simplify n_past logic
* make : add passkey target
* passkey : add "self-extend"-like context extension (#4810)
* llama : "self-extend"-like context extension
* passkey : add comment
* passkey : add readme
1 parent: b7e7982

File tree

9 files changed (+361, -1 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -51,6 +51,7 @@ models-mnt
 /lookup
 /main
 /metal
+/passkey
 /perplexity
 /q8dot
 /quantize

Makefile

Lines changed: 4 additions & 1 deletion
@@ -2,7 +2,7 @@
 BUILD_TARGETS = \
 	main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
 	simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
-	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o
+	speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
 
 # Binaries only useful for tests
 TEST_TARGETS = \

@@ -665,6 +665,9 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
 lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
+passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+
 ifdef LLAMA_METAL
 metal: examples/metal/metal.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
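
With this rule in place, `make passkey` builds the new example on its own, and since `passkey` is now part of `BUILD_TARGETS`, a plain `make` builds it as well; the binary lands in the repository root, which is why `/passkey` is added to `.gitignore` above.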

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ else()
     add_subdirectory(quantize-stats)
     add_subdirectory(save-load-state)
     add_subdirectory(simple)
+    add_subdirectory(passkey)
     add_subdirectory(speculative)
     add_subdirectory(lookahead)
     add_subdirectory(lookup)

examples/batched/batched.cpp

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(model, params.prompt, true);
+
     const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
 
     // initialize the context

examples/passkey/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+set(TARGET passkey)
+add_executable(${TARGET} passkey.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/passkey/README.md

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# llama.cpp/example/passkey
+
+See the following PRs for more info:
+
+- https://github.com/ggerganov/llama.cpp/pull/3856
+- https://github.com/ggerganov/llama.cpp/pull/4810
+
+### Usage
+
+```bash
+make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250
+```
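
The README shows only the model path and `N_JUNK`; per the argument parsing in `passkey.cpp` below, the full positional order is `MODEL_PATH N_JUNK N_GRP I_POS SEED`, where `N_JUNK` is how many times the junk sentence is repeated (default 250), `N_GRP` enables the LongLM SelfExtend-style position compression when greater than 1, `I_POS` fixes where the passkey is inserted in the junk text (random when omitted), and `SEED` seeds the RNG (time-based when omitted).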

examples/passkey/passkey.cpp

Lines changed: 296 additions & 0 deletions
@@ -0,0 +1,296 @@
+#include "common.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
+        return 1 ;
+    }
+
+    int seed = -1;
+
+    int n_junk = 250; // number of times to repeat the junk text
+    int n_keep = 32;  // number of tokens in the prompt prefix
+    int n_grp  = 1;   // if more than 1 - perform LongLM SelfExtend
+    int i_pos  = -1;  // position of the passkey in the junk text
+
+    if (argc >= 2) {
+        params.model = argv[1];
+    }
+
+    if (argc >= 3) {
+        n_junk = std::stoi(argv[2]);
+    }
+
+    if (argc >= 4) {
+        n_grp = std::stoi(argv[3]);
+    }
+
+    if (argc >= 5) {
+        i_pos = std::stoi(argv[4]);
+    }
+
+    if (argc >= 6) {
+        seed = std::stoi(argv[5]);
+    }
+
+    if (seed == -1) {
+        seed = time(NULL);
+    }
+
+    srand(seed);
+
+    if (i_pos == -1) {
+        i_pos = rand() % n_junk;
+    }
+
+    const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
+    const std::string prompt_suffix = " What is the pass key? The pass key is";
+
+    // generate junk text
+    params.prompt = prompt_prefix;
+
+    const int passkey = rand() % 50000 + 1;
+
+    for (int i = 0; i < n_junk; i++) {
+        if (i % n_junk == i_pos) {
+            params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
+        }
+
+        params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
+    }
+
+    params.prompt += prompt_suffix;
+
+    // init LLM
+
+    llama_backend_init(params.numa);
+
+    // initialize the model
+
+    llama_model_params model_params = llama_model_default_params();
+
+    model_params.n_gpu_layers = 99; // offload all layers to the GPU
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return 1;
+    }
+
+    // initialize the context
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed    = seed;
+    ctx_params.n_ctx   = llama_n_ctx_train(model)*n_grp + n_keep;
+    ctx_params.n_batch = 512;
+    ctx_params.n_threads       = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
+    GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return 1;
+    }
+
+    // tokenize the prompt
+    std::vector<llama_token> tokens_list;
+    tokens_list = ::llama_tokenize(ctx, params.prompt, true);
+
+    // tokenize the prefix and use it as a sink
+    const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();
+
+    const int n_tokens_all = tokens_list.size();
+
+    // we leave a margin of 16 tokens for the generated text - it should contain just the passkey
+    const int n_predict = 16;
+
+    // total length of the sequences including the prompt
+    const int n_len = n_tokens_all + n_predict;
+
+    const int n_ctx       = llama_n_ctx(ctx) - n_keep;
+    const int n_kv_req    = llama_n_ctx(ctx);
+    const int n_batch     = ctx_params.n_batch;
+    const int n_batch_grp = ctx_params.n_batch/n_grp;
+
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
+
+    // print the prompt token-by-token
+
+    LOG_TEE("\n");
+    LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
+    LOG_TEE("prompt tokens: %d\n", n_tokens_all);
+    //LOG_TEE("prompt: %s\n", params.prompt.c_str());
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+
+    int n_past = 0;
+
+    // fill the KV cache
+    for (int i = 0; i < n_ctx; i += n_batch) {
+        if (i > 0 && n_grp > 1) {
+            // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
+            const int ib = i/n_batch - 1;
+            const int bd = n_batch_grp*(n_grp - 1);
+
+            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+
+            n_past -= bd;
+        }
+
+        llama_batch_clear(batch);
+
+        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
+            llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+        }
+
+        if (i + n_batch >= n_tokens_all) {
+            batch.logits[batch.n_tokens - 1] = true;
+        }
+
+        if (llama_decode(ctx, batch) != 0) {
+            LOG_TEE("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
+
+        if (i + n_batch >= n_tokens_all) {
+            break;
+        }
+    }
+
+    for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
+        const int n_discard = n_batch;
+
+        LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
+
+        llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
+        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+
+        n_past -= n_discard;
+
+        llama_batch_clear(batch);
+
+        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
+            llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
+        }
+
+        if (i + n_batch >= n_tokens_all) {
+            batch.logits[batch.n_tokens - 1] = true;
+        }
+
+        if (llama_decode(ctx, batch) != 0) {
+            LOG_TEE("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+
+        LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
+    }
+
+    {
+        const int n_discard = n_past - n_ctx + n_predict;
+
+        if (n_discard > 0) {
+            LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
+
+            llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+
+            n_past -= n_discard;
+        }
+    }
+
+    LOG_TEE("\n");
+    LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
+    LOG_TEE("\n");
+
+    // main loop
+
+    int n_cur    = n_tokens_all;
+    int n_decode = 0;
+
+    LOG_TEE("%s", prompt_suffix.c_str());
+    fflush(stdout);
+
+    const auto t_main_start = ggml_time_us();
+
+    while (n_cur <= n_len) {
+        // sample the next token
+        {
+            auto   n_vocab = llama_n_vocab(model);
+            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+            }
+
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+            // sample the most likely token
+            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of stream?
+            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+                LOG_TEE("\n");
+
+                break;
+            }
+
+            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            fflush(stdout);
+
+            n_decode += 1;
+
+            // prepare the next batch
+            llama_batch_clear(batch);
+
+            // push this new token for next evaluation
+            llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
+        }
+
+        n_cur += 1;
+
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+    }
+
+    LOG_TEE("\n");
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    return 0;
+}
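
The `n_grp > 1` branch in the KV-cache fill loop is where the "self-extend"-like extension happens: each batch is decoded at ordinary positions, and on the next iteration the previous batch's cache entries are shifted and divided by `n_grp`, so every `n_batch` prompt tokens end up occupying only `n_batch/n_grp` positions. The standalone sketch below is not part of the commit; it assumes `n_batch = 512`, `n_grp = 4`, and a prompt spanning four full batches, makes no llama.cpp calls, and only prints the position ranges implied by that arithmetic.

```cpp
// Sketch of the position bookkeeping in passkey.cpp's "self-extend"-like fill loop.
// Assumed parameters: n_batch = 512, n_grp = 4, prompt covering 4 full batches.
#include <cstdio>

int main() {
    const int n_batch     = 512;
    const int n_grp       = 4;
    const int n_batch_grp = n_batch/n_grp;           // 128
    const int bd          = n_batch_grp*(n_grp - 1); // 384
    const int n_batches   = 4;                       // pretend prompt length = 4*n_batch

    int n_past = 0;

    for (int i = 0; i < n_batches*n_batch; i += n_batch) {
        if (i > 0 && n_grp > 1) {
            const int ib = i/n_batch - 1;
            // In passkey.cpp, llama_kv_cache_seq_shift + llama_kv_cache_seq_div move the
            // previous batch from its decode-time positions into a block of n_batch_grp
            // compressed positions: [n_batch_grp*ib, n_batch_grp*(ib + 1))
            printf("batch %d compressed to positions [%4d, %4d)\n",
                   ib, n_batch_grp*ib, n_batch_grp*(ib + 1));
            n_past -= bd; // same rewind as in the example
        }

        // llama_batch_add assigns positions n_past, n_past+1, ... for the new batch
        printf("batch %d decoded   at positions [%4d, %4d)\n",
               i/n_batch, n_past, n_past + n_batch);
        n_past += n_batch;
    }

    return 0;
}
```

Run with these assumed numbers, it shows batch 1 decoded at positions [128, 640) and later compressed into [128, 256), and so on: the highest position ever used stays close to the model's training context even though roughly `n_grp` times as many tokens are cached, which is why the example sets `n_ctx` to `llama_n_ctx_train(model)*n_grp + n_keep` KV slots.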
