Commit 0c56923

examples : add save_load_state example (#1150)
* add save_load_state example
* use <cstdio> instead of <iostream> and fprintf / printf instead of cout
* renamed save-load-state example files, replacing underscores with dashes
1 parent 957c8ae commit 0c56923
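
The example exercises the llama state round-trip API that appears in the diff below: llama_get_state_size to size a buffer, llama_copy_state_data to serialize the context (rng, logits, embedding and kv_cache, per the comments in the example), and llama_set_state_data to restore it into a fresh context. A minimal sketch of that pattern, assuming an already-initialized context; the save_state/load_state helper names and the std::vector buffer are ours, not part of the commit:

#include <cstdint>
#include <cstdio>
#include <vector>

#include "llama.h"

// Illustrative helper, not from the commit: serialize the full context
// state to a file, sized via llama_get_state_size.
static bool save_state(llama_context * ctx, const char * path) {
    const size_t state_size = llama_get_state_size(ctx);
    std::vector<uint8_t> state_mem(state_size);
    llama_copy_state_data(ctx, state_mem.data());

    FILE * fp = fopen(path, "wb");
    if (fp == NULL) {
        return false;
    }
    const size_t n_written = fwrite(state_mem.data(), 1, state_size, fp);
    fclose(fp);
    return n_written == state_size;
}

// Illustrative helper, not from the commit: read a dump back and restore
// it into a (possibly different) context of the same model/configuration.
static bool load_state(llama_context * ctx, const char * path) {
    const size_t state_size = llama_get_state_size(ctx);
    std::vector<uint8_t> state_mem(state_size);

    FILE * fp = fopen(path, "rb");
    if (fp == NULL) {
        return false;
    }
    const size_t n_read = fread(state_mem.data(), 1, state_size, fp);
    fclose(fp);
    if (n_read != state_size) {
        return false; // dump does not match this context's state size
    }
    llama_set_state_data(ctx, state_mem.data());
    return true;
}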

File tree

examples/CMakeLists.txt
examples/save-load-state/CMakeLists.txt
examples/save-load-state/save-load-state.cpp

3 files changed: +133 −0 lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -34,4 +34,5 @@ else()
     add_subdirectory(quantize-stats)
     add_subdirectory(perplexity)
     add_subdirectory(embedding)
+    add_subdirectory(save-load-state)
 endif()
examples/save-load-state/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+set(TARGET save-load-state)
+add_executable(${TARGET} save-load-state.cpp)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
examples/save-load-state/save-load-state.cpp

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+#include <vector>
+#include <cstdio>
+#include <chrono>
+
+#include "common.h"
+#include "llama.h"
+#include "llama.cpp"
+
+using namespace std;
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+    params.model = "models/llama-7B/ggml-model.bin";
+    params.seed = 42;
+    params.n_threads = 4;
+    params.repeat_last_n = 64;
+    params.prompt = "The quick brown fox";
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    auto lparams = llama_context_default_params();
+
+    lparams.n_ctx     = params.n_ctx;
+    lparams.n_parts   = params.n_parts;
+    lparams.seed      = params.seed;
+    lparams.f16_kv    = params.memory_f16;
+    lparams.use_mmap  = params.use_mmap;
+    lparams.use_mlock = params.use_mlock;
+
+    auto n_past = 0;
+    auto last_n_tokens_data = vector<llama_token>(params.repeat_last_n, 0);
+
+    // init
+    auto ctx = llama_init_from_file(params.model.c_str(), lparams);
+    auto tokens = vector<llama_token>(params.n_ctx);
+    auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true);
+
+    if (n_prompt_tokens < 1) {
+        fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
+        return 1;
+    }
+
+    // evaluate prompt
+
+    llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+
+    last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
+    n_past += n_prompt_tokens;
+
+    // Save state (rng, logits, embedding and kv_cache) to file
+    FILE *fp_write = fopen("dump_state.bin", "wb");
+    auto state_size = llama_get_state_size(ctx);
+    auto state_mem = new uint8_t[state_size];
+    llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file
+    fwrite(state_mem, 1, state_size, fp_write);
+    fclose(fp_write);
+
+    // save state (last tokens)
+    auto last_n_tokens_data_saved = vector<llama_token>(last_n_tokens_data);
+    auto n_past_saved = n_past;
+
+    // first run
+    printf("\n%s", params.prompt.c_str());
+    for (auto i = 0; i < params.n_predict; i++) {
+        auto next_token = llama_sample_top_p_top_k(
+            ctx,
+            &last_n_tokens_data.back() - params.repeat_last_n,
+            params.repeat_last_n,
+            40,
+            1.0,
+            1.0,
+            1.1);
+        auto next_token_str = llama_token_to_str(ctx, next_token);
+        last_n_tokens_data.push_back(next_token);
+        printf("%s", next_token_str);
+        if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            return 1;
+        }
+        n_past += 1;
+    }
+    printf("\n\n");
+
+    // free old model
+    llama_free(ctx);
+
+    // load new model
+
+    auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
+
+    // Load state (rng, logits, embedding and kv_cache) from file
+    FILE *fp_read = fopen("dump_state.bin", "rb");
+    auto state_size2 = llama_get_state_size(ctx2);
+    if (state_size != state_size2) {
+        fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
+    }
+    fread(state_mem, 1, state_size, fp_read);
+    llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file
+    fclose(fp_read);
+
+    // restore state (last tokens)
+    last_n_tokens_data = last_n_tokens_data_saved;
+    n_past = n_past_saved;
+
+    // second run
+    for (auto i = 0; i < params.n_predict; i++) {
+        auto next_token = llama_sample_top_p_top_k(
+            ctx2,
+            &last_n_tokens_data.back() - params.repeat_last_n,
+            params.repeat_last_n,
+            40,
+            1.0,
+            1.0,
+            1.1);
+        auto next_token_str = llama_token_to_str(ctx2, next_token);
+        last_n_tokens_data.push_back(next_token);
+        printf("%s", next_token_str);
+        if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+            fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            return 1;
+        }
+        n_past += 1;
+    }
+    printf("\n\n");
+    return 0;
+}
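
Because the state dump includes the RNG and the seed is pinned (params.seed = 42), the second run should reproduce the first run's continuation token for token. The commit prints both generations rather than comparing them; below is a hypothetical refactor of the example's sampling loop that captures the output so the two runs can be checked programmatically. The generate_text name is ours, and the sampling arguments mirror the example: top_k 40, top_p 1.0, temp 1.0, repeat penalty 1.1.

#include <string>
#include <vector>

#include "llama.h"

// Hypothetical helper, not part of the commit: run n_predict sampling
// steps against ctx and return the generated text for comparison.
static std::string generate_text(
        llama_context * ctx,
        std::vector<llama_token> & last_n_tokens_data,
        int & n_past,
        int n_predict,
        int repeat_last_n,
        int n_threads) {
    std::string out;
    for (int i = 0; i < n_predict; i++) {
        // same sampling call as the example
        auto next_token = llama_sample_top_p_top_k(
            ctx,
            &last_n_tokens_data.back() - repeat_last_n,
            repeat_last_n,
            40, 1.0, 1.0, 1.1);
        out += llama_token_to_str(ctx, next_token);
        last_n_tokens_data.push_back(next_token);
        if (llama_eval(ctx, &next_token, 1, n_past, n_threads)) {
            break; // evaluation failed; return what we have so far
        }
        n_past += 1;
    }
    return out;
}

// Usage sketch: with the state saved before the first run and restored
// before the second, the two strings should be identical.
//
//   auto first  = generate_text(ctx,  last_n_tokens_data, n_past, ...);
//   // ... save/free/init/load as in the example ...
//   auto second = generate_text(ctx2, last_n_tokens_data, n_past, ...);
//   printf("%s\n", first == second ? "state restored OK" : "MISMATCH");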
