Skip to content

kv-cache : simplify + fix warning for recurrent models #12756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 66 additions & 11 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2474,7 +2474,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
}

int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
return llama_kv_cache_n_tokens(ctx->get_kv_self());
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}

return kv->get_n_tokens();
}

// deprecated
Expand All @@ -2483,7 +2488,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
}

int32_t llama_kv_self_used_cells(const llama_context * ctx) {
return llama_kv_cache_used_cells(ctx->get_kv_self());
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}

return kv->get_used_cells();
}

// deprecated
Expand All @@ -2492,7 +2502,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
}

void llama_kv_self_clear(llama_context * ctx) {
llama_kv_cache_clear(ctx->get_kv_self());
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}

kv->clear();
}

// deprecated
Expand All @@ -2509,7 +2524,12 @@ bool llama_kv_self_seq_rm(
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_rm(ctx->get_kv_self(), seq_id, p0, p1);
auto * kv = ctx->get_kv_self();
if (!kv) {
return true;
}

return kv->seq_rm(seq_id, p0, p1);
}

// deprecated
Expand All @@ -2528,7 +2548,12 @@ void llama_kv_self_seq_cp(
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
return llama_kv_cache_seq_cp(ctx->get_kv_self(), seq_id_src, seq_id_dst, p0, p1);
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}

return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

// deprecated
Expand All @@ -2539,7 +2564,12 @@ void llama_kv_cache_seq_keep(
}

void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_keep(ctx->get_kv_self(), seq_id);
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}

return kv->seq_keep(seq_id);
}

// deprecated
Expand All @@ -2558,7 +2588,12 @@ void llama_kv_self_seq_add(
llama_pos p0,
llama_pos p1,
llama_pos delta) {
return llama_kv_cache_seq_add(ctx->get_kv_self(), seq_id, p0, p1, delta);
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}

return kv->seq_add(seq_id, p0, p1, delta);
}

// deprecated
Expand All @@ -2577,7 +2612,12 @@ void llama_kv_self_seq_div(
llama_pos p0,
llama_pos p1,
int d) {
return llama_kv_cache_seq_div(ctx->get_kv_self(), seq_id, p0, p1, d);
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}

return kv->seq_div(seq_id, p0, p1, d);
}

// deprecated
Expand All @@ -2586,7 +2626,12 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
}

llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
return llama_kv_cache_seq_pos_max(ctx->get_kv_self(), seq_id);
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}

return kv->seq_pos_max(seq_id);
}

// deprecated
Expand All @@ -2595,7 +2640,12 @@ void llama_kv_cache_defrag(llama_context * ctx) {
}

void llama_kv_self_defrag(llama_context * ctx) {
llama_kv_cache_defrag(ctx->get_kv_self());
auto * kv = ctx->get_kv_self();
if (!kv) {
return;
}

return kv->defrag();
}

// deprecated
Expand All @@ -2604,7 +2654,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
}

bool llama_kv_self_can_shift(const llama_context * ctx) {
return llama_kv_cache_can_shift(ctx->get_kv_self());
const auto * kv = ctx->get_kv_self();
if (!kv) {
return false;
}

return kv->get_can_shift();
}

// deprecated
Expand Down
122 changes: 8 additions & 114 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ int32_t llama_kv_cache_unified::get_n_tokens() const {
return result;
}

uint32_t llama_kv_cache_unified::get_used_cells() const {
int32_t llama_kv_cache_unified::get_used_cells() const {
return used;
}

Expand Down Expand Up @@ -428,7 +428,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
}
}

llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) {
llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
llama_pos result = 0;

for (uint32_t i = 0; i < size; ++i) {
Expand Down Expand Up @@ -481,6 +481,11 @@ void llama_kv_cache_unified::restore() {
}

void llama_kv_cache_unified::commit() {
// TODO: tmp - move to llama_kv_cache_recurrent
if (recurrent) {
return;
}

if (pending.ranges.empty()) {
LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n",
__func__, "https://github.com/ggml-org/llama.cpp/pull/12695");
Expand Down Expand Up @@ -1273,117 +1278,6 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
return true;
}

//
// interface implementation
//

int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv) {
if (!kv) {
return 0;
}

return kv->get_n_tokens();
}

int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv) {
if (!kv) {
return 0;
}

return kv->get_used_cells();
}

void llama_kv_cache_clear(llama_kv_cache * kv) {
if (!kv) {
return;
}

kv->clear();
}

bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
if (!kv) {
return true;
}

return kv->seq_rm(seq_id, p0, p1);
}

void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1) {
if (!kv) {
return;
}

kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id) {
if (!kv) {
return;
}

kv->seq_keep(seq_id);
}

void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta) {
if (!kv) {
return;
}

kv->seq_add(seq_id, p0, p1, delta);
}

void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d) {
if (!kv) {
return;
}

kv->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id) {
if (!kv) {
return 0;
}

return kv->seq_pos_max(seq_id);
}

void llama_kv_cache_defrag(llama_kv_cache * kv) {
if (!kv) {
return;
}

kv->defrag();
}

bool llama_kv_cache_can_shift(const llama_kv_cache * kv) {
if (!kv) {
return false;
}

return kv->get_can_shift();
}

//
// kv cache view
//
Expand All @@ -1393,7 +1287,7 @@ llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t
/*.n_cells = */ 0,
/*.n_seq_max = */ n_seq_max,
/*.token_count = */ 0,
/*.used_cells = */ llama_kv_cache_used_cells(&kv),
/*.used_cells = */ kv.get_used_cells(),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,
Expand Down
52 changes: 5 additions & 47 deletions src/llama-kv-cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ struct llama_kv_cache : public llama_memory_i {
virtual void restore() = 0; // call if batch processing fails - restores the cache state
virtual void commit() = 0; // call after successful batch processing - clears any pending state

virtual int32_t get_n_tokens() const = 0;
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
virtual int32_t get_n_tokens() const = 0;
virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache

virtual bool get_can_shift() const = 0;

Expand Down Expand Up @@ -89,8 +89,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
uint32_t kv_size,
bool offload);

int32_t get_n_tokens() const override;
uint32_t get_used_cells() const override;
int32_t get_n_tokens() const override;
int32_t get_used_cells() const override;

size_t total_size() const;

Expand All @@ -109,7 +109,7 @@ class llama_kv_cache_unified : public llama_kv_cache {
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

llama_pos seq_pos_max(llama_seq_id seq_id) override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;

bool get_can_shift() const override;

Expand Down Expand Up @@ -204,48 +204,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
// using llama_kv_cache_unified::llama_kv_cache_unified;
//};

// TODO: maybe become part of the public llama_kv_cache in the future
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);

int32_t llama_kv_cache_used_cells(const llama_kv_cache * kv);

void llama_kv_cache_clear(llama_kv_cache * kv);

bool llama_kv_cache_seq_rm(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1);

void llama_kv_cache_seq_cp(
llama_kv_cache * kv,
llama_seq_id seq_id_src,
llama_seq_id seq_id_dst,
llama_pos p0,
llama_pos p1);

void llama_kv_cache_seq_keep(llama_kv_cache * kv, llama_seq_id seq_id);

void llama_kv_cache_seq_add(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
llama_pos delta);

void llama_kv_cache_seq_div(
llama_kv_cache * kv,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1,
int d);

llama_pos llama_kv_cache_seq_pos_max(llama_kv_cache * kv, llama_seq_id seq_id);

void llama_kv_cache_defrag(llama_kv_cache * kv);

bool llama_kv_cache_can_shift(const llama_kv_cache * kv);

//
// kv cache view
//
Expand Down
2 changes: 1 addition & 1 deletion src/llama-memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class llama_memory_i {
virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0;
virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0;

virtual llama_pos seq_pos_max(llama_seq_id seq_id) = 0;
virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;

virtual bool get_can_edit() const = 0;
};
Loading