examples : add passkey test #3856

Merged (7 commits) on Jan 8, 2024

Changes from all commits

1 change: 1 addition & 0 deletions .gitignore
@@ -51,6 +51,7 @@ models-mnt
/lookup
/main
/metal
/passkey
/perplexity
/q8dot
/quantize
5 changes: 4 additions & 1 deletion Makefile
@@ -2,7 +2,7 @@
BUILD_TARGETS = \
main quantize quantize-stats perplexity embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup tests/test-c.o
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o

# Binaries only useful for tests
TEST_TARGETS = \
@@ -665,6 +665,9 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

ifdef LLAMA_METAL
metal: examples/metal/metal.cpp ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -31,6 +31,7 @@ else()
add_subdirectory(quantize-stats)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(passkey)
add_subdirectory(speculative)
add_subdirectory(lookahead)
add_subdirectory(lookup)
1 change: 1 addition & 0 deletions examples/batched/batched.cpp
@@ -69,6 +69,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(model, params.prompt, true);

const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;

// initialize the context
5 changes: 5 additions & 0 deletions examples/passkey/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET passkey)
add_executable(${TARGET} passkey.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
12 changes: 12 additions & 0 deletions examples/passkey/README.md
@@ -0,0 +1,12 @@
# llama.cpp/example/passkey

See the following PRs for more info:

- https://github.com/ggerganov/llama.cpp/pull/3856
- https://github.com/ggerganov/llama.cpp/pull/4810

### Usage

```bash
make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250
```
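
The positional arguments accepted by `passkey.cpp` are `MODEL_PATH N_JUNK N_GRP I_POS SEED`: the number of times the junk text is repeated (default 250), the LongLM SelfExtend group factor (default 1, i.e. disabled), the position of the passkey inside the junk text (random by default) and the RNG seed (current time by default). A hypothetical invocation that enables SelfExtend with a group factor of 4 could look like:

```bash
make -j && ./passkey ./models/llama-7b-v2/ggml-model-f16.gguf 250 4
```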
296 changes: 296 additions & 0 deletions examples/passkey/passkey.cpp
@@ -0,0 +1,296 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char ** argv) {
gpt_params params;

if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
return 1 ;
}

int seed = -1;

int n_junk = 250; // number of times to repeat the junk text
int n_keep = 32; // number of tokens in the prompt prefix
int n_grp = 1; // if more than 1 - perform LongLM SelfExtend
int i_pos = -1; // position of the passkey in the junk text

if (argc >= 2) {
params.model = argv[1];
}

if (argc >= 3) {
n_junk = std::stoi(argv[2]);
}

if (argc >= 4) {
n_grp = std::stoi(argv[3]);
}

if (argc >= 5) {
i_pos = std::stoi(argv[4]);
}

if (argc >= 6) {
seed = std::stoi(argv[5]);
}

if (seed == -1) {
seed = time(NULL);
}

srand(seed);

if (i_pos == -1) {
i_pos = rand() % n_junk;
}

const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.";
const std::string prompt_suffix = " What is the pass key? The pass key is";

// generate junk text
params.prompt = prompt_prefix;

const int passkey = rand() % 50000 + 1;

for (int i = 0; i < n_junk; i++) {
if (i % n_junk == i_pos) {
params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key.";
}

params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.";
}

params.prompt += prompt_suffix;

// init LLM

llama_backend_init(params.numa);

// initialize the model

llama_model_params model_params = llama_model_default_params();

model_params.n_gpu_layers = 99; // offload all layers to the GPU

llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);

if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
return 1;
}

// initialize the context

llama_context_params ctx_params = llama_context_default_params();

ctx_params.seed = seed;
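    // size the KV cache for n_grp times the training context plus n_keep extra positions for the
    // prompt prefix; with SelfExtend (n_grp > 1) the positions are compressed by n_grp further
    // below, so they stay within the range the model was trained on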
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
ctx_params.n_batch = 512;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
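    // (each batch of n_batch tokens is later compressed to n_batch/n_grp positions, so the division must be exact)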

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

if (ctx == NULL) {
        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
return 1;
}

// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize(ctx, params.prompt, true);

// tokenize the prefix and use it as a sink
const int n_tokens_prefix = ::llama_tokenize(ctx, prompt_prefix, true).size();

const int n_tokens_all = tokens_list.size();

// we leave a margin of 16 tokens for the generated text - it should contain just the passkey
const int n_predict = 16;

// total length of the sequences including the prompt
const int n_len = n_tokens_all + n_predict;

const int n_ctx = llama_n_ctx(ctx) - n_keep;
const int n_kv_req = llama_n_ctx(ctx);
const int n_batch = ctx_params.n_batch;
const int n_batch_grp = ctx_params.n_batch/n_grp;

LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);

// print the prompt token-by-token

LOG_TEE("\n");
LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
LOG_TEE("prompt tokens: %d\n", n_tokens_all);
//LOG_TEE("prompt: %s\n", params.prompt.c_str());

llama_batch batch = llama_batch_init(512, 0, 1);

int n_past = 0;

// fill the KV cache
for (int i = 0; i < n_ctx; i += n_batch) {
if (i > 0 && n_grp > 1) {
// if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);
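            // ib*bd is the accumulated downshift from compressing the earlier batches: move the
            // last decoded batch back to its absolute positions, divide those positions by n_grp
            // (the SelfExtend grouping), then account below for the bd positions this saves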

llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);

n_past -= bd;
}

llama_batch_clear(batch);

for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
}

if (i + n_batch >= n_tokens_all) {
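            // request logits only for the final prompt token - that is the one we sample from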
batch.logits[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}

LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));

if (i + n_batch >= n_tokens_all) {
break;
}
}

for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
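        // the prompt does not fit into the window: before decoding each remaining batch, drop
        // n_discard tokens right after the n_keep prefix and slide the rest of the cache down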
const int n_discard = n_batch;

LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);

llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);

n_past -= n_discard;

llama_batch_clear(batch);

for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
}

if (i + n_batch >= n_tokens_all) {
batch.logits[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1;
}

LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
}

{
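        // make room for the n_predict answer tokens: if the cache is full, discard tokens after
        // the n_keep prefix and shift the remaining entries down, as above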
const int n_discard = n_past - n_ctx + n_predict;

if (n_discard > 0) {
LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);

llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);

n_past -= n_discard;
}
}

LOG_TEE("\n");
LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
LOG_TEE("\n");

// main loop

int n_cur = n_tokens_all;
int n_decode = 0;

LOG_TEE("%s", prompt_suffix.c_str());
fflush(stdout);

const auto t_main_start = ggml_time_us();

while (n_cur <= n_len) {
// sample the next token
{
auto n_vocab = llama_n_vocab(model);
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
}

llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// sample the most likely token
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);

// is it an end of stream?
if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
LOG_TEE("\n");

break;
}

LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
fflush(stdout);

n_decode += 1;

// prepare the next batch
llama_batch_clear(batch);

// push this new token for next evaluation
llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
}

n_cur += 1;

// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}

LOG_TEE("\n");

const auto t_main_end = ggml_time_us();

LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));

llama_print_timings(ctx);

fprintf(stderr, "\n");

llama_batch_free(batch);

llama_free(ctx);
llama_free_model(model);

llama_backend_free();

return 0;
}