Skip to content

Commit 3e1d293

Browse files
authored
kv-cache : simplify + fix warning for recurrent models (#12756)
ggml-ci
1 parent 1be76e4 commit 3e1d293

File tree

4 files changed

+80
-173
lines changed

4 files changed

+80
-173
lines changed

src/llama-context.cpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2474,7 +2474,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
24742474
}
24752475

24762476
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
2477-
return llama_kv_cache_n_tokens(ctx->get_kv_self());
2477+
const auto * kv = ctx->get_kv_self();
2478+
if (!kv) {
2479+
return 0;
2480+
}
2481+
2482+
return kv->get_n_tokens();
24782483
}
24792484

24802485
// deprecated
@@ -2483,7 +2488,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
24832488
}
24842489

24852490
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
2486-
return llama_kv_cache_used_cells(ctx->get_kv_self());
2491+
const auto * kv = ctx->get_kv_self();
2492+
if (!kv) {
2493+
return 0;
2494+
}
2495+
2496+
return kv->get_used_cells();
24872497
}
24882498

24892499
// deprecated
@@ -2492,7 +2502,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
24922502
}
24932503

24942504
void llama_kv_self_clear(llama_context * ctx) {
2495-
llama_kv_cache_clear(ctx->get_kv_self());
2505+
auto * kv = ctx->get_kv_self();
2506+
if (!kv) {
2507+
return;
2508+
}
2509+
2510+
kv->clear();
24962511
}
24972512

24982513
// deprecated
@@ -2509,7 +2524,12 @@ bool llama_kv_self_seq_rm(
25092524
llama_seq_id seq_id,
25102525
llama_pos p0,
25112526
llama_pos p1) {
2512-
return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
2527+
auto * kv = ctx->get_kv_self();
2528+
if (!kv) {
2529+
return true;
2530+
}
2531+
2532+
return kv->seq_rm(seq_id, p0, p1);
25132533
}
25142534

25152535
// deprecated
@@ -2528,7 +2548,12 @@ void llama_kv_self_seq_cp(
25282548
llama_seq_id seq_id_dst,
25292549
llama_pos p0,
25302550
llama_pos p1) {
2531-
return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
2551+
auto * kv = ctx->get_kv_self();
2552+
if (!kv) {
2553+
return;
2554+
}
2555+
2556+
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
25322557
}
25332558

25342559
// deprecated
@@ -2539,7 +2564,12 @@ void llama_kv_cache_seq_keep(
25392564
}
25402565

25412566
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
2542-
return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
2567+
auto * kv = ctx->get_kv_self();
2568+
if (!kv) {
2569+
return;
2570+
}
2571+
2572+
return kv->seq_keep(seq_id);
25432573
}
25442574

25452575
// deprecated
@@ -2558,7 +2588,12 @@ void llama_kv_self_seq_add(
25582588
llama_pos p0,
25592589
llama_pos p1,
25602590
llama_pos delta) {
2561-
return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
2591+
auto * kv = ctx->get_kv_self();
2592+
if (!kv) {
2593+
return;
2594+
}
2595+
2596+
return kv->seq_add(seq_id, p0, p1, delta);
25622597
}
25632598

25642599
// deprecated
@@ -2577,7 +2612,12 @@ void llama_kv_self_seq_div(
25772612
llama_pos p0,
25782613
llama_pos p1,
25792614
int d) {
2580-
return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
2615+
auto * kv = ctx->get_kv_self();
2616+
if (!kv) {
2617+
return;
2618+
}
2619+
2620+
return kv->seq_div(seq_id, p0, p1, d);
25812621
}
25822622

25832623
// deprecated
@@ -2586,7 +2626,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
25862626
}
25872627

25882628
llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
2589-
return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
2629+
const auto * kv = ctx->get_kv_self();
2630+
if (!kv) {
2631+
return 0;
2632+
}
2633+
2634+
return kv->seq_pos_max(seq_id);
25902635
}
25912636

25922637
// deprecated
@@ -2595,7 +2640,12 @@ void llama_kv_cache_defrag(llama_context * ctx) {
25952640
}
25962641

25972642
void llama_kv_self_defrag(llama_context * ctx) {
2598-
llama_kv_cache_defrag(ctx->get_kv_self());
2643+
auto * kv = ctx->get_kv_self();
2644+
if (!kv) {
2645+
return;
2646+
}
2647+
2648+
return kv->defrag();
25992649
}
26002650

26012651
// deprecated
@@ -2604,7 +2654,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
26042654
}
26052655

26062656
bool llama_kv_self_can_shift(const llama_context * ctx) {
2607-
return llama_kv_cache_can_shift(ctx->get_kv_self());
2657+
const auto * kv = ctx->get_kv_self();
2658+
if (!kv) {
2659+
return false;
2660+
}
2661+
2662+
return kv->get_can_shift();
26082663
}
26092664

26102665
// deprecated

src/llama-kv-cache.cpp

Lines changed: 8 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
131131
return result;
132132
}
133133

134-
uint32_t llama_kv_cache_unified::get_used_cells() const {
134+
int32_t llama_kv_cache_unified::get_used_cells() const {
135135
return used;
136136
}
137137

@@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
428428
}
429429
}
430430

431-
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
431+
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
432432
llama_pos result = 0;
433433

434434
for (uint32_t i = 0; i < size; ++i) {
@@ -481,6 +481,11 @@ void llama_kv_cache_unified::restore() {
481481
}
482482

483483
void llama_kv_cache_unified::commit() {
484+
// TODO: tmp - move to llama_kv_cache_recurrent
485+
if (recurrent) {
486+
return;
487+
}
488+
484489
if (pending.ranges.empty()) {
485490
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
486491
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
@@ -1273,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
12731278
return true;
12741279
}
12751280

1276-
//
1277-
// interface implementation
1278-
//
1279-
1280-
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
1281-
if (!kv) {
1282-
return 0;
1283-
}
1284-
1285-
return kv->get_n_tokens();
1286-
}
1287-
1288-
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
1289-
if (!kv) {
1290-
return 0;
1291-
}
1292-
1293-
return kv->get_used_cells();
1294-
}
1295-
1296-
void llama_kv_cache_clear(llama_kv_cache * kv) {
1297-
if (!kv) {
1298-
return;
1299-
}
1300-
1301-
kv->clear();
1302-
}
1303-
1304-
bool llama_kv_cache_seq_rm(
1305-
llama_kv_cache * kv,
1306-
llama_seq_id seq_id,
1307-
llama_pos p0,
1308-
llama_pos p1) {
1309-
if (!kv) {
1310-
return true;
1311-
}
1312-
1313-
return kv->seq_rm(seq_id, p0, p1);
1314-
}
1315-
1316-
void llama_kv_cache_seq_cp(
1317-
llama_kv_cache * kv,
1318-
llama_seq_id seq_id_src,
1319-
llama_seq_id seq_id_dst,
1320-
llama_pos p0,
1321-
llama_pos p1) {
1322-
if (!kv) {
1323-
return;
1324-
}
1325-
1326-
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
1327-
}
1328-
1329-
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
1330-
if (!kv) {
1331-
return;
1332-
}
1333-
1334-
kv->seq_keep(seq_id);
1335-
}
1336-
1337-
void llama_kv_cache_seq_add(
1338-
llama_kv_cache * kv,
1339-
llama_seq_id seq_id,
1340-
llama_pos p0,
1341-
llama_pos p1,
1342-
llama_pos delta) {
1343-
if (!kv) {
1344-
return;
1345-
}
1346-
1347-
kv->seq_add(seq_id, p0, p1, delta);
1348-
}
1349-
1350-
void llama_kv_cache_seq_div(
1351-
llama_kv_cache * kv,
1352-
llama_seq_id seq_id,
1353-
llama_pos p0,
1354-
llama_pos p1,
1355-
int d) {
1356-
if (!kv) {
1357-
return;
1358-
}
1359-
1360-
kv->seq_div(seq_id, p0, p1, d);
1361-
}
1362-
1363-
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
1364-
if (!kv) {
1365-
return 0;
1366-
}
1367-
1368-
return kv->seq_pos_max(seq_id);
1369-
}
1370-
1371-
void llama_kv_cache_defrag(llama_kv_cache * kv) {
1372-
if (!kv) {
1373-
return;
1374-
}
1375-
1376-
kv->defrag();
1377-
}
1378-
1379-
bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
1380-
if (!kv) {
1381-
return false;
1382-
}
1383-
1384-
return kv->get_can_shift();
1385-
}
1386-
13871281
//
13881282
// kv cache view
13891283
//
@@ -1393,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
13931287
/*.n_cells = */ 0,
13941288
/*.n_seq_max = */ n_seq_max,
13951289
/*.token_count = */ 0,
1396-
/*.used_cells = */ llama_kv_cache_used_cells(&kv),
1290+
/*.used_cells = */ kv.get_used_cells(),
13971291
/*.max_contiguous = */ 0,
13981292
/*.max_contiguous_idx = */ -1,
13991293
/*.cells = */ nullptr,

src/llama-kv-cache.h

Lines changed: 5 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ struct llama_kv_cache : public llama_memory_i {
2020
virtual void restore() = 0; // call if batch processing fails - restores the cache state
2121
virtual void commit() = 0; // call after successful batch processing - clears any pending state
2222

23-
virtual int32_t get_n_tokens() const = 0;
24-
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
23+
virtual int32_t get_n_tokens() const = 0;
24+
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
2525

2626
virtual bool get_can_shift() const = 0;
2727

@@ -89,8 +89,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
8989
uint32_t kv_size,
9090
bool offload);
9191

92-
int32_t get_n_tokens() const override;
93-
uint32_t get_used_cells() const override;
92+
int32_t get_n_tokens() const override;
93+
int32_t get_used_cells() const override;
9494

9595
size_t total_size() const;
9696

@@ -109,7 +109,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
109109
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
110110
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
111111

112-
llama_pos seq_pos_max(llama_seq_id seq_id) override;
112+
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
113113

114114
bool get_can_shift() const override;
115115

@@ -204,48 +204,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
204204
// using llama_kv_cache_unified::llama_kv_cache_unified;
205205
//};
206206

207-
// TODO: maybe become part of the public llama_kv_cache in the future
208-
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
209-
210-
int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);
211-
212-
void llama_kv_cache_clear(llama_kv_cache * kv);
213-
214-
bool llama_kv_cache_seq_rm(
215-
llama_kv_cache * kv,
216-
llama_seq_id seq_id,
217-
llama_pos p0,
218-
llama_pos p1);
219-
220-
void llama_kv_cache_seq_cp(
221-
llama_kv_cache * kv,
222-
llama_seq_id seq_id_src,
223-
llama_seq_id seq_id_dst,
224-
llama_pos p0,
225-
llama_pos p1);
226-
227-
void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);
228-
229-
void llama_kv_cache_seq_add(
230-
llama_kv_cache * kv,
231-
llama_seq_id seq_id,
232-
llama_pos p0,
233-
llama_pos p1,
234-
llama_pos delta);
235-
236-
void llama_kv_cache_seq_div(
237-
llama_kv_cache * kv,
238-
llama_seq_id seq_id,
239-
llama_pos p0,
240-
llama_pos p1,
241-
int d);
242-
243-
llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);
244-
245-
void llama_kv_cache_defrag(llama_kv_cache * kv);
246-
247-
bool llama_kv_cache_can_shift(const llama_kv_cache * kv);
248-
249207
//
250208
// kv cache view
251209
//

src/llama-memory.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class llama_memory_i {
1515
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
1616
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;
1717

18-
virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
18+
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
1919

2020
virtual bool get_can_edit() const = 0;
2121
};

0 commit comments

Comments
 (0)