
Commit 8e39037

llama : refactor session file management
* llama : saving and restoring state checks for overflow

  The size of the buffers should now be given to the functions working with
  them, otherwise a truncated file could cause out-of-bounds reads.

* llama : stream from session file instead of copying into a big buffer

  Loading session files should no longer cause a memory usage spike.

* llama : llama_state_get_size returns the actual size instead of max

  This is a breaking change, but it makes that function *much* easier to keep
  up to date, and it also makes it reflect the behavior of
  llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

  Both session file types now use a more similar format.

* llama : no longer store all hparams in session files

  Instead, the model arch name is stored. The layer count and the embedding
  dimensions of the KV cache are still verified when loading. Storing all
  the hparams is not necessary.
1 parent: de28008
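For callers, the visible change is that every state copy now takes an explicit buffer size, and llama_state_get_size reports the exact number of bytes a save will produce. A minimal sketch of the new save path, based on the signatures declared in include/llama.h below (an illustration, not code from this commit):

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Save the full context state (rng, logits, embedding, kv_cache) to a
    // file using the size-checked API. Since llama_state_get_size now
    // returns the actual size rather than an upper bound, the written
    // count should match the buffer size exactly.
    static bool save_state_to_file(llama_context * ctx, const char * path) {
        std::vector<uint8_t> buf(llama_state_get_size(ctx));
        const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());
        if (written != buf.size()) {
            return false;
        }
        FILE * f = fopen(path, "wb");
        if (f == NULL) {
            return false;
        }
        const bool ok = fwrite(buf.data(), 1, written, f) == written;
        fclose(f);
        return ok;
    }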

File tree: 3 files changed (+609, -702 lines)


examples/save-load-state/save-load-state.cpp: 13 additions & 7 deletions
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data());
+        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
 
         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
@@ -99,13 +99,16 @@
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -159,13 +162,16 @@
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx3, state_mem.data())) {
+        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
@@ -182,7 +188,7 @@
     {
         // save kv of seq 0
         std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
-        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
         if (ncopy != seq_store.size()) {
             fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
             llama_free(ctx3);
@@ -196,7 +202,7 @@
         fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
         // restore kv into seq 1
-        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
         if (nset != seq_store.size()) {
             fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
             llama_free(ctx3);
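Note that the example now sizes the read buffer from the file itself (fseek/ftell) and lets llama_state_set_data validate against that size; this is exactly the overflow check the commit adds, so a truncated file fails cleanly instead of causing out-of-bounds reads. A slightly more defensive reader (a sketch, not part of this commit) would also check the stdio calls before trusting ftell:

    #include <cstdio>
    #include <vector>

    // Hypothetical helper for illustration: read an entire file into a
    // vector, failing early if the file is missing or any stdio call
    // goes wrong.
    static bool read_whole_file(const char * path, std::vector<uint8_t> & out) {
        FILE * f = fopen(path, "rb");
        if (f == NULL) {
            return false;
        }
        bool ok = fseek(f, 0, SEEK_END) == 0;
        const long size = ok ? ftell(f) : -1;
        ok = ok && size >= 0 && fseek(f, 0, SEEK_SET) == 0;
        if (ok) {
            out.resize(size);
            ok = fread(out.data(), 1, out.size(), f) == out.size();
        }
        fclose(f);
        return ok;
    }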

include/llama.h: 13 additions & 8 deletions
@@ -40,10 +40,10 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 7
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 1
+#define LLAMA_STATE_SEQ_VERSION 2
 
 #ifdef __cplusplus
 extern "C" {
@@ -687,18 +687,20 @@ extern "C" {
     // State / sessions
     //
 
-    // Returns the maximum size in bytes of the state (rng, logits, embedding
-    // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
-    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+    // Returns the *actual* size in bytes of the state
+    // (rng, logits, embedding and kv_cache)
+    // Only use when saving the state, not when restoring it, otherwise the size may be too small.
+    LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
         "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
     LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
-            uint8_t * dst);
+            uint8_t * dst,
+            size_t size);
     LLAMA_API DEPRECATED(size_t llama_copy_state_data(
             struct llama_context * ctx,
             uint8_t * dst),
@@ -708,7 +710,8 @@ extern "C" {
     // Returns the number of bytes read
     LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
-            const uint8_t * src);
+            const uint8_t * src,
+            size_t size);
     LLAMA_API DEPRECATED(size_t llama_set_state_data(
             struct llama_context * ctx,
             const uint8_t * src),
@@ -750,6 +753,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
             uint8_t * dst,
+            size_t size,
             llama_seq_id seq_id);
 
     // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -759,6 +763,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_set_data(
             struct llama_context * ctx,
             const uint8_t * src,
+            size_t size,
             llama_seq_id dest_seq_id);
 
     LLAMA_API size_t llama_state_seq_save_file(
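The "stream from session file" part of the commit message does not change these signatures; it lands in the file-based helpers such as llama_state_seq_save_file above, which can now move state to and from disk without first staging the whole blob in memory. A hedged usage sketch, assuming the companion llama_state_seq_load_file declaration that sits just after this excerpt in llama.h:

    #include <vector>
    #include "llama.h"

    // Save the KV cache of sequence 0 to a session file, then restore it
    // into sequence 1. With this commit, the load path reads from the file
    // as it goes instead of copying it into one big buffer first. Both
    // calls return 0 on failure in this sketch's assumption.
    static bool roundtrip_seq_state(llama_context * ctx,
                                    const llama_token * tokens, size_t n_tokens) {
        if (llama_state_seq_save_file(ctx, "seq0.bin", 0, tokens, n_tokens) == 0) {
            return false;
        }
        std::vector<llama_token> tokens_out(n_tokens);
        size_t n_read = 0;
        return llama_state_seq_load_file(ctx, "seq0.bin", 1,
                                         tokens_out.data(), tokens_out.size(), &n_read) != 0;
    }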
