Skip to content

Commit 2f40c9f

Browse files
committed
llama : "self-extend"-like context extension
1 parent f2c9800 commit 2f40c9f

File tree

3 files changed

+66
-7
lines changed

3 files changed

+66
-7
lines changed

examples/passkey/passkey.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ int main(int argc, char ** argv) {
1010
gpt_params params;
1111

1212
if (argc == 1 || argv[1][0] == '-') {
13-
printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]);
13+
printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
1414
return 1 ;
1515
}
1616

1717
int seed = -1;
1818

1919
int n_junk = 250; // number of times to repeat the junk text
2020
int n_keep = 32; // number of tokens in the prompt prefix
21+
int n_grp = 1; // if more than 1 - perform LongLM SelfExtend
2122
int i_pos = -1; // position of the passkey in the junk text
2223

2324
if (argc >= 2) {
@@ -29,11 +30,15 @@ int main(int argc, char ** argv) {
2930
}
3031

3132
if (argc >= 4) {
32-
i_pos = std::stoi(argv[3]);
33+
n_grp = std::stoi(argv[3]);
3334
}
3435

3536
if (argc >= 5) {
36-
seed = std::stoi(argv[4]);
37+
i_pos = std::stoi(argv[4]);
38+
}
39+
40+
if (argc >= 6) {
41+
seed = std::stoi(argv[5]);
3742
}
3843

3944
if (seed == -1) {
@@ -86,11 +91,13 @@ int main(int argc, char ** argv) {
8691
llama_context_params ctx_params = llama_context_default_params();
8792

8893
ctx_params.seed = seed;
89-
ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep;
94+
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
9095
ctx_params.n_batch = 512;
9196
ctx_params.n_threads = params.n_threads;
9297
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
9398

99+
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
100+
94101
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
95102

96103
if (ctx == NULL) {
@@ -113,9 +120,10 @@ int main(int argc, char ** argv) {
113120
// total length of the sequences including the prompt
114121
const int n_len = n_tokens_all + n_predict;
115122

116-
const int n_ctx = llama_n_ctx(ctx) - n_keep;
117-
const int n_kv_req = llama_n_ctx(ctx);
118-
const int n_batch = ctx_params.n_batch;
123+
const int n_ctx = llama_n_ctx(ctx) - n_keep;
124+
const int n_kv_req = llama_n_ctx(ctx);
125+
const int n_batch = ctx_params.n_batch;
126+
const int n_batch_grp = ctx_params.n_batch/n_grp;
119127

120128
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
121129

@@ -132,6 +140,16 @@ int main(int argc, char ** argv) {
132140

133141
// fill the KV cache
134142
for (int i = 0; i < n_ctx; i += n_batch) {
143+
if (i > 0 && n_grp > 1) {
144+
const int ib = i/n_batch - 1;
145+
const int bd = n_batch_grp*(n_grp - 1);
146+
147+
llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd);
148+
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
149+
150+
n_past -= bd;
151+
}
152+
135153
llama_batch_clear(batch);
136154

137155
for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {

llama.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,6 +1903,28 @@ static void llama_kv_cache_seq_shift(
19031903
cache.head = new_head != cache.size ? new_head : 0;
19041904
}
19051905

1906+
static void llama_kv_cache_seq_div(
1907+
struct llama_kv_cache & cache,
1908+
llama_seq_id seq_id,
1909+
llama_pos p0,
1910+
llama_pos p1,
1911+
int d) {
1912+
if (p0 < 0) p0 = 0;
1913+
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
1914+
1915+
for (uint32_t i = 0; i < cache.size; ++i) {
1916+
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
1917+
cache.has_shift = true;
1918+
1919+
{
1920+
llama_pos p_old = cache.cells[i].pos;
1921+
cache.cells[i].pos /= d;
1922+
cache.cells[i].delta += cache.cells[i].pos - p_old;
1923+
}
1924+
}
1925+
}
1926+
}
1927+
19061928
//
19071929
// model loading and saving
19081930
//
@@ -10140,9 +10162,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
1014010162
}
1014110163

1014210164
void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
10165+
if (delta == 0) {
10166+
return;
10167+
}
10168+
1014310169
llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
1014410170
}
1014510171

10172+
// Public API: integer-divide the positions of seq_id's cells in [p0, p1) by d
// (used for "self-extend"-style grouped attention).
// d == 1 is a no-op; d <= 0 is invalid (d == 0 would divide by zero in the
// internal implementation, d < 0 would produce negative positions), so such
// values are rejected as a no-op as well.
void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    if (d <= 1) {
        return;
    }

    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
}
10179+
1014610180
// Returns the *maximum* size of the state
1014710181
size_t llama_get_state_size(const struct llama_context * ctx) {
1014810182
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.

llama.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,13 @@ extern "C" {
484484
llama_pos p1,
485485
llama_pos delta);
486486

487+
LLAMA_API void llama_kv_cache_seq_div(
488+
struct llama_context * ctx,
489+
llama_seq_id seq_id,
490+
llama_pos p0,
491+
llama_pos p1,
492+
int d);
493+
487494
//
488495
// State / sessions
489496
//

0 commit comments

Comments
 (0)