Skip to content

llama : refactor get / set state + remove redundant kv cache API #1143

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
319 changes: 179 additions & 140 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2072,35 +2072,191 @@ int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lor
}
}

// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
return ctx->model.kv_self.buf.addr;
int llama_get_kv_cache_token_count(struct llama_context * ctx) {
return ctx->model.kv_self.n;
}

// Returns the size of the KV cache
size_t llama_get_kv_cache_size(struct llama_context * ctx) {
return ctx->model.kv_self.buf.size;
#define LLAMA_MAX_RNG_STATE 64*1024

// Returns the size of the state
size_t llama_get_state_size(struct llama_context * ctx) {
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = LLAMA_MAX_RNG_STATE;
const size_t s_logits_capacity = sizeof(size_t);
const size_t s_logits_size = sizeof(size_t);
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
const size_t s_embedding_size = sizeof(size_t);
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
const size_t s_kv_size = sizeof(size_t);
const size_t s_kv_ntok = sizeof(int);
const size_t s_kv = ctx->model.kv_self.buf.size;

const size_t s_total = (
+ s_rng_size
+ s_rng
+ s_logits_capacity
+ s_logits_size
+ s_logits
+ s_embedding_size
+ s_embedding
+ s_kv_size
+ s_kv_ntok
+ s_kv
);

return s_total;
}

int llama_get_kv_cache_token_count(struct llama_context * ctx) {
return ctx->model.kv_self.n;
// Copies the state to the specified destination address
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
uint8_t * out = dest;

// copy rng
{
std::stringstream rng_ss;
rng_ss << ctx->rng;

const size_t rng_size = rng_ss.str().size();
char rng_buf[LLAMA_MAX_RNG_STATE];

memset(&rng_buf[0], 0, LLAMA_MAX_RNG_STATE);
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());

memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
memcpy(out, &rng_buf[0], LLAMA_MAX_RNG_STATE); out += LLAMA_MAX_RNG_STATE;
}

// copy logits
{
const size_t logits_cap = ctx->logits.capacity();
const size_t logits_size = ctx->logits.size();

memcpy(out, &logits_cap, sizeof(logits_cap)); out += sizeof(logits_cap);
memcpy(out, &logits_size, sizeof(logits_size)); out += sizeof(logits_size);

if (logits_size) {
memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
}

out += logits_cap * sizeof(float);
}

// copy embeddings
{
const size_t embedding_size = ctx->embedding.size();

memcpy(out, &embedding_size, sizeof(embedding_size)); out += sizeof(embedding_size);

if (embedding_size) {
memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float));
out += embedding_size * sizeof(float);
}
}

// copy kv cache
{
const size_t kv_size = ctx->model.kv_self.buf.size;
const int kv_ntok = llama_get_kv_cache_token_count(ctx);

memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);

if (kv_size) {
memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
}
}

const size_t written = out - dest;
const size_t expected = llama_get_state_size(ctx);

LLAMA_ASSERT(written == expected);

return written;
}

// Sets the KV cache containing the current context for the model
void llama_set_kv_cache(
struct llama_context * ctx,
const uint8_t * kv_cache,
size_t n_size,
int n_token_count) {
// Make sure we have the same kv cache setup
LLAMA_ASSERT(ctx->model.kv_self.buf.size == n_size);
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
memcpy(ctx->model.kv_self.buf.addr, kv_cache, n_size);
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
ctx->model.kv_self.v->data = v_data;
ctx->model.kv_self.n = n_token_count;
// Sets the state reading from the specified source address
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
const uint8_t * in = src;

// set rng
{
size_t rng_size;
char rng_buf[LLAMA_MAX_RNG_STATE];

memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;

std::stringstream rng_ss;
rng_ss.str(std::string(&rng_buf[0], rng_size));
rng_ss >> ctx->rng;

LLAMA_ASSERT(rng_ss.fail() == false);
}

// set logits
{
size_t logits_cap;
size_t logits_size;

memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap);
memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);

LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);

if (logits_size) {
ctx->logits.resize(logits_size);
memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
}

in += logits_cap * sizeof(float);
}

// set embeddings
{
size_t embedding_size;

memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);

LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);

if (embedding_size) {
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
in += embedding_size * sizeof(float);
}
}

// set kv cache
{
size_t kv_size;
int kv_ntok;

memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);

if (kv_size) {
LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);

void * k_data = ctx->model.kv_self.k->data; // remember data pointers
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy

memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;

ctx->model.kv_self.k->data = k_data; // restore correct data pointers
ctx->model.kv_self.v->data = v_data;

}

ctx->model.kv_self.n = kv_ntok;
}

const size_t nread = in - src;
const size_t expected = llama_get_state_size(ctx);

LLAMA_ASSERT(nread == expected);

return nread;
}

int llama_eval(
Expand Down Expand Up @@ -2256,120 +2412,3 @@ std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_te
return ctx->model.tensors_by_name;
}

// Returns the size of the state
size_t llama_get_state_size(struct llama_context * ctx) {
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = 64*1024;
const size_t s_logits_capacity = sizeof(size_t);
const size_t s_logits_size = sizeof(size_t);
const size_t s_logits = ctx->logits.capacity() * sizeof(float);
const size_t s_embedding_size = sizeof(size_t);
const size_t s_embedding = ctx->embedding.size() * sizeof(float);
const size_t s_kv_size = sizeof(size_t);
const size_t s_kv_ntok = sizeof(int);
const size_t s_kv = llama_get_kv_cache_size(ctx);
const size_t s_total = (
+ s_rng_size
+ s_rng
+ s_logits_capacity
+ s_logits_size
+ s_logits
+ s_embedding_size
+ s_embedding
+ s_kv_size
+ s_kv_ntok
+ s_kv
);
return s_total;
}

// Copies the state to the specified destination address
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
std::stringstream rng_ss;
rng_ss << ctx->rng;
const size_t rng_size = rng_ss.str().size();
char rng_buf[64*1024];
memset(&rng_buf[0], 0, 64*1024);
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
const size_t logits_capacity = ctx->logits.capacity();
const size_t logits_size = ctx->logits.size();
const size_t embedding_size = ctx->embedding.size();
const size_t kv_size = llama_get_kv_cache_size(ctx);
const int kv_ntok = llama_get_kv_cache_token_count(ctx);

uint8_t * out = dest;
memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t);
memcpy(out, &rng_buf[0], 64*1024); out += 64*1024;
memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t);
memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t);
if (logits_size) {
memcpy(out, ctx->logits.data(), logits_size * sizeof(float));
}
out += logits_capacity * sizeof(float);
memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t);
if (embedding_size) {
memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float);
}
memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t);
memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int);
if (kv_size) {
memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size;
}
const size_t written = out - dest;
const size_t expected = llama_get_state_size(ctx);
LLAMA_ASSERT(written == expected);
return written;
}

// Sets the state reading from the specified source address
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
size_t rng_size;
char rng_buf[64*1024];
std::stringstream rng_ss;

const uint8_t * in = src;
memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t);
memcpy(&rng_buf[0], in, 64*1024); in += 64*1024;
rng_ss.str(std::string(&rng_buf[0], rng_size));
rng_ss >> ctx->rng;
LLAMA_ASSERT(rng_ss.fail() == false);

size_t logits_capacity;
size_t logits_size;
size_t embedding_size;
size_t kv_size;
int kv_ntok;

memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t);
memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t);
LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity);
if (logits_size) {
ctx->logits.resize(logits_size);
memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
}
in += logits_capacity * sizeof(float);
memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t);
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
if (embedding_size) {
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
in += embedding_size * sizeof(float);
}
memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t);
memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int);
if (kv_size) {
LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
void * k_data = ctx->model.kv_self.k->data; // remember data pointers
void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
memcpy(ctx->model.kv_self.buf.addr, in, kv_size);
ctx->model.kv_self.k->data = k_data; // restore correct data pointers
ctx->model.kv_self.v->data = v_data;
in += kv_size;
}
ctx->model.kv_self.n = kv_ntok;
const size_t nread = in - src;
const size_t expected = llama_get_state_size(ctx);
LLAMA_ASSERT(nread == expected);
return nread;
}
14 changes: 0 additions & 14 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,9 @@ extern "C" {
const char * path_base_model,
int n_threads);

// Returns the KV cache that will contain the context for the
// ongoing prediction with the model.
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);

// Returns the size of the KV cache
LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);

// Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);

// Sets the KV cache containing the current context for the model
LLAMA_API void llama_set_kv_cache(
struct llama_context * ctx,
const uint8_t * kv_cache,
size_t n_size,
int n_token_count);

// Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
LLAMA_API size_t llama_get_state_size(struct llama_context * ctx);

Expand Down