Commit cb65980

compilade authored and Neo Zhang committed

llama : refactor session file management (ggml-org#8699)

* llama : refactor session file management

* llama : saving and restoring state checks for overflow

  The size of the buffers should now be given to the functions working with them, otherwise a truncated file could cause out of bound reads.

* llama : stream from session file instead of copying into a big buffer

  Loading session files should no longer cause a memory usage spike.

* llama : llama_state_get_size returns the actual size instead of max

  This is a breaking change, but makes that function *much* easier to keep up to date, and it also makes it reflect the behavior of llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

  Both session file types now use a more similar format.

* llama : no longer store all hparams in session files

  Instead, the model arch name is stored. The layer count and the embedding dimensions of the KV cache are still verified when loading. Storing all the hparams is not necessary.

* llama : fix uint64_t format type

* llama : various integer type cast and format string fixes

  Some platforms use "%lu" and others "%llu" for uint64_t. Not sure how to handle that, so casting to size_t when displaying errors.

* llama : remove _context suffix for llama_data_context

* llama : fix session file loading

  llama_state_get_size cannot be used to get the max size anymore.

* llama : more graceful error handling of invalid session files

* llama : remove LLAMA_MAX_RNG_STATE

  It's no longer necessary to limit the size of the RNG state, because the max size of session files is not estimated anymore.

* llama : cast seq_id in comparison with unsigned n_seq_max

1 parent cb87e5c commit cb65980
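
A minimal sketch (not part of the commit) of what the size-aware save path looks like for a caller, assuming a valid llama_context * ctx and the llama.h from this commit; the helper name save_state_to_file is illustrative only:

#include <cstdio>
#include <cstdint>
#include <vector>
#include "llama.h"

// Save the full context state (rng, logits, embedding and kv_cache) to a file.
// llama_state_get_size() now reports the actual state size, so the buffer fits exactly,
// and llama_state_get_data() receives the buffer size so it cannot write out of bounds.
static bool save_state_to_file(llama_context * ctx, const char * path) {
    std::vector<uint8_t> buf(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, buf.data(), buf.size());

    FILE * fp = fopen(path, "wb");
    if (fp == NULL) {
        return false;
    }
    const bool ok = fwrite(buf.data(), 1, written, fp) == written;
    fclose(fp);
    return ok;
}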

File tree: 3 files changed, +658 −727 lines


examples/save-load-state/save-load-state.cpp (+13 −7)
@@ -47,7 +47,7 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data());
+        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
 
         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
@@ -99,13 +99,16 @@ int main(int argc, char ** argv) {
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -159,13 +162,16 @@ int main(int argc, char ** argv) {
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx3, state_mem.data())) {
+        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
@@ -182,7 +188,7 @@ int main(int argc, char ** argv) {
     {
         // save kv of seq 0
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
-        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
         if (ncopy != seq_store.size()) {
             fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
             llama_free(ctx3);
@@ -196,7 +202,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
         // restore kv into seq 1
-        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
         if (nset != seq_store.size()) {
             fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
             llama_free(ctx3);
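
The load blocks above size the buffer with fseek/ftell before reading, as described in the commit message. A slightly more defensive variant (an illustration, not part of the commit) that also checks the fopen and ftell results might look like this; the helper name load_state_from_file is hypothetical:

#include <cstdio>
#include <cstdint>
#include <vector>
#include "llama.h"

// Read a previously saved state file into memory and restore it into ctx.
// Passing buf.size() lets llama_state_set_data() detect truncated or mismatched
// data instead of reading past the end of the buffer.
static bool load_state_from_file(llama_context * ctx, const char * path) {
    FILE * fp = fopen(path, "rb");
    if (fp == NULL) {
        return false;
    }
    fseek(fp, 0, SEEK_END);
    const long file_size = ftell(fp);
    fseek(fp, 0, SEEK_SET);
    if (file_size <= 0) {
        fclose(fp);
        return false;
    }

    std::vector<uint8_t> buf(static_cast<size_t>(file_size));
    const size_t read = fread(buf.data(), 1, buf.size(), fp);
    fclose(fp);

    // llama_state_set_data returns the number of bytes it consumed
    return read == buf.size() && llama_state_set_data(ctx, buf.data(), buf.size()) == read;
}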

include/llama.h (+13 −10)
@@ -33,17 +33,15 @@
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
-#define LLAMA_MAX_RNG_STATE (64*1024)
-
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 7
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 1
+#define LLAMA_STATE_SEQ_VERSION 2
 
 #ifdef __cplusplus
 extern "C" {
@@ -691,18 +689,20 @@ extern "C" {
     // State / sessions
     //
 
-    // Returns the maximum size in bytes of the state (rng, logits, embedding
-    // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
-    LLAMA_API DEPRECATED(size_t llama_get_state_size(const struct llama_context * ctx),
+    // Returns the *actual* size in bytes of the state
+    // (rng, logits, embedding and kv_cache)
+    // Only use when saving the state, not when restoring it, otherwise the size may be too small.
+    LLAMA_API size_t llama_state_get_size(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx),
        "use llama_state_get_size instead");
 
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
     LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
-            uint8_t * dst);
+            uint8_t * dst,
+            size_t size);
     LLAMA_API DEPRECATED(size_t llama_copy_state_data(
             struct llama_context * ctx,
             uint8_t * dst),
@@ -712,7 +712,8 @@ extern "C" {
     // Returns the number of bytes read
     LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
-            const uint8_t * src);
+            const uint8_t * src,
+            size_t size);
     LLAMA_API DEPRECATED(size_t llama_set_state_data(
             struct llama_context * ctx,
             const uint8_t * src),
@@ -754,6 +755,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_get_data(
             struct llama_context * ctx,
             uint8_t * dst,
+            size_t size,
             llama_seq_id seq_id);
 
     // Copy the sequence data (originally copied with `llama_state_seq_get_data`) into the specified sequence
@@ -763,6 +765,7 @@ extern "C" {
     LLAMA_API size_t llama_state_seq_set_data(
             struct llama_context * ctx,
             const uint8_t * src,
+            size_t size,
             llama_seq_id dest_seq_id);
 
     LLAMA_API size_t llama_state_seq_save_file(
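
For the per-sequence variants, the new size parameter slots in between the buffer and the sequence id. A short sketch (an illustration, not from the commit) that copies the state of one sequence into another within the same context, mirroring the save-load-state example above; the helper name copy_seq_state is hypothetical:

#include <cstdint>
#include <vector>
#include "llama.h"

// Duplicate the state of seq_src into seq_dst using the size-aware sequence API.
static bool copy_seq_state(llama_context * ctx, llama_seq_id seq_src, llama_seq_id seq_dst) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, seq_src));

    // copy out the sequence state; the explicit size bounds the write into buf
    const size_t ncopy = llama_state_seq_get_data(ctx, buf.data(), buf.size(), seq_src);
    if (ncopy != buf.size()) {
        return false;
    }

    // restore it into the destination sequence; the size bounds the read from buf
    const size_t nset = llama_state_seq_set_data(ctx, buf.data(), buf.size(), seq_dst);
    return nset == buf.size();
}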
