Skip to content

Commit 7082d24

Browse files
lookup : add prompt lookup decoding example (#4484)
* initial commit, going through initializations * main loop finished, starting to debug * BUG: generates gibberish/repeating tokens after a while * kv_cache management * Added colors to distinguish drafted tokens (--color). Updated README * lookup : fix token positions in the draft batch * lookup : use n_draft from CLI params * lookup : final touches --------- Co-authored-by: Leon Ericsson <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent ba66175 commit 7082d24

File tree

7 files changed

+256
-2
lines changed

7 files changed

+256
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ models-mnt
4848
/llama-bench
4949
/llava-cli
5050
/lookahead
51+
/lookup
5152
/main
5253
/metal
5354
/perplexity

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
BUILD_TARGETS = \
33
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
44
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
5-
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead tests/test-c.o
5+
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o
66

77
# Binaries only useful for tests
88
TEST_TARGETS = \
@@ -664,6 +664,9 @@ parallel: examples/parallel/parallel.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
664664
lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
665665
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
666666

667+
lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
668+
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
669+
667670
ifdef LLAMA_METAL
668671
metal: examples/metal/metal.cpp ggml.o $(OBJS)
669672
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)

common/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct gpt_params {
5151
int32_t n_ctx = 512; // context size
5252
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
5353
int32_t n_keep = 0; // number of tokens to keep from initial prompt
54-
int32_t n_draft = 16; // number of tokens to draft during speculative decoding
54+
int32_t n_draft = 8; // number of tokens to draft during speculative decoding
5555
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
5656
int32_t n_parallel = 1; // number of parallel sequences to decode
5757
int32_t n_sequences = 1; // number of sequences to decode
@@ -240,3 +240,4 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
240240

241241
// Dump the KV cache view showing individual sequences in each cell (long output).
242242
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
243+

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ else()
3333
add_subdirectory(simple)
3434
add_subdirectory(speculative)
3535
add_subdirectory(lookahead)
36+
add_subdirectory(lookup)
3637
add_subdirectory(train-text-from-scratch)
3738
if (LLAMA_METAL)
3839
add_subdirectory(metal)

examples/lookup/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(TARGET lookup)
2+
add_executable(${TARGET} lookup.cpp)
3+
install(TARGETS ${TARGET} RUNTIME)
4+
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
5+
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/lookup/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# llama.cpp/examples/lookup
2+
3+
Demonstration of Prompt Lookup Decoding
4+
5+
https://github.com/apoorvumang/prompt-lookup-decoding
6+
7+
The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. The latter specifies how many subsequent tokens to draft if a match is found.
8+
9+
More info:
10+
11+
https://github.com/ggerganov/llama.cpp/pull/4484
12+
https://github.com/ggerganov/llama.cpp/issues/4226
13+

examples/lookup/lookup.cpp

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
#include "common.h"
2+
#include "llama.h"
3+
4+
#include <cmath>
5+
#include <cstdio>
6+
#include <string>
7+
#include <vector>
8+
9+
int main(int argc, char ** argv){
10+
gpt_params params;
11+
12+
if (!gpt_params_parse(argc, argv, params)) {
13+
return 1;
14+
}
15+
16+
// max/min n-grams size to search for in prompt
17+
const int ngram_max = 4;
18+
const int ngram_min = 1;
19+
20+
// length of the candidate / draft sequence, if match is found
21+
const int n_draft = params.n_draft;
22+
23+
const bool dump_kv_cache = params.dump_kv_cache;
24+
25+
#ifndef LOG_DISABLE_LOGS
26+
log_set_target(log_filename_generator("lookup", "log"));
27+
LOG_TEE("Log start\n");
28+
log_dump_cmdline(argc, argv);
29+
#endif // LOG_DISABLE_LOGS
30+
31+
// init llama.cpp
32+
llama_backend_init(params.numa);
33+
34+
llama_model * model = NULL;
35+
llama_context * ctx = NULL;
36+
37+
// load the model
38+
std::tie(model, ctx) = llama_init_from_gpt_params(params);
39+
40+
// tokenize the prompt
41+
const bool add_bos = llama_should_add_bos_token(model);
42+
LOG("add_bos tgt: %d\n", add_bos);
43+
44+
std::vector<llama_token> inp;
45+
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
46+
47+
const int max_context_size = llama_n_ctx(ctx);
48+
const int max_tokens_list_size = max_context_size - 4;
49+
50+
if ((int) inp.size() > max_tokens_list_size) {
51+
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
52+
return 1;
53+
}
54+
55+
fprintf(stderr, "\n\n");
56+
57+
for (auto id : inp) {
58+
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
59+
}
60+
61+
fflush(stderr);
62+
63+
const int n_input = inp.size();
64+
65+
const auto t_enc_start = ggml_time_us();
66+
67+
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
68+
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
69+
70+
const auto t_enc_end = ggml_time_us();
71+
72+
int n_predict = 0;
73+
int n_drafted = 0;
74+
int n_accept = 0;
75+
76+
int n_past = inp.size();
77+
78+
bool has_eos = false;
79+
80+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
81+
82+
std::vector<llama_token> draft;
83+
84+
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
85+
86+
// debug
87+
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
88+
89+
const auto t_dec_start = ggml_time_us();
90+
91+
while (true) {
92+
// debug
93+
if (dump_kv_cache) {
94+
llama_kv_cache_view_update(ctx, &kvc_view);
95+
dump_kv_cache_view_seqs(kvc_view, 40);
96+
}
97+
98+
// print current draft sequence
99+
LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
100+
101+
int i_dft = 0;
102+
while (true) {
103+
// sample from the target model
104+
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
105+
106+
llama_sampling_accept(ctx_sampling, ctx, id, true);
107+
108+
const std::string token_str = llama_token_to_piece(ctx, id);
109+
110+
if (!params.use_color) {
111+
printf("%s", token_str.c_str());
112+
}
113+
114+
if (id == llama_token_eos(model)) {
115+
has_eos = true;
116+
}
117+
118+
++n_predict;
119+
120+
// check if the target token matches the draft
121+
if (i_dft < (int) draft.size() && id == draft[i_dft]) {
122+
LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
123+
++n_accept;
124+
++n_past;
125+
++i_dft;
126+
inp.push_back(id);
127+
128+
if (params.use_color) {
129+
// color accepted draft token
130+
printf("\033[34m%s\033[0m", token_str.c_str());
131+
fflush(stdout);
132+
}
133+
continue;
134+
}
135+
136+
if (params.use_color) {
137+
printf("%s", token_str.c_str());
138+
}
139+
fflush(stdout);
140+
141+
142+
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
143+
144+
draft.clear();
145+
draft.push_back(id);
146+
inp.push_back(id);
147+
break;
148+
}
149+
150+
if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
151+
break;
152+
}
153+
154+
// KV cache management
155+
// clean the cache of draft tokens that weren't accepted
156+
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
157+
158+
llama_batch_clear(batch_tgt);
159+
llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
160+
161+
// generate n_pred tokens through prompt lookup
162+
auto prompt_lookup = [&]() -> void {
163+
int inp_size = inp.size();
164+
for (int ngram_size = ngram_max ; ngram_size > ngram_min; --ngram_size){
165+
const llama_token * ngram = &inp[inp_size - ngram_size];
166+
167+
for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) {
168+
bool match = true;
169+
for (int j = 0; j < ngram_size; ++j) {
170+
if (inp[i + j] != ngram[j]) {
171+
match = false;
172+
break;
173+
}
174+
}
175+
176+
if (match) {
177+
const int startIdx = i + ngram_size;
178+
const int endIdx = startIdx + n_draft;
179+
if (endIdx < inp_size) {
180+
for (int j = startIdx; j < endIdx; ++j) {
181+
LOG(" - draft candidate %d: %d\n", j, inp[j]);
182+
draft.push_back(inp[j]);
183+
llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true);
184+
++n_drafted;
185+
}
186+
return;
187+
}
188+
}
189+
}
190+
}
191+
return;
192+
};
193+
194+
prompt_lookup();
195+
196+
llama_decode(ctx, batch_tgt);
197+
++n_past;
198+
199+
draft.erase(draft.begin());
200+
}
201+
202+
auto t_dec_end = ggml_time_us();
203+
204+
LOG_TEE("\n\n");
205+
206+
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
207+
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
208+
209+
LOG_TEE("\n");
210+
LOG_TEE("n_draft = %d\n", n_draft);
211+
LOG_TEE("n_predict = %d\n", n_predict);
212+
LOG_TEE("n_drafted = %d\n", n_drafted);
213+
LOG_TEE("n_accept = %d\n", n_accept);
214+
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
215+
216+
LOG_TEE("\ntarget:\n");
217+
llama_print_timings(ctx);
218+
219+
llama_sampling_free(ctx_sampling);
220+
llama_batch_free(batch_tgt);
221+
222+
llama_free(ctx);
223+
llama_free_model(model);
224+
225+
llama_backend_free();
226+
227+
fprintf(stderr, "\n\n");
228+
229+
return 0;
230+
}

0 commit comments

Comments
 (0)