Commit 66874d4

Some improvements to loading the session with --prompt-cache (#1550)
Improvements to loading the session with `--prompt-cache` in the `main` example:

1. Fix an issue where the `--seed` parameter was ignored when loading a cached prompt.
2. When loading a cached prompt, you previously had to specify the saved prompt (or a prefix of it) again. This pull changes that behavior to default to the prompt that was cached if a prompt wasn't specified by the user.
1 parent 1fcdcc2 commit 66874d4
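
For reference, the restore path after this change boils down to the sketch below. It is a minimal, self-contained illustration, assuming the `llama.h`/`common.h` API at this revision (`llama_load_session_file`, `llama_set_rng_seed`, `::llama_tokenize`); the helper name `restore_or_tokenize_prompt` is illustrative and not part of the codebase.

```cpp
// Minimal sketch (not the actual main.cpp): restore a cached prompt, re-apply
// the user's seed, and fall back to the cached tokens when no prompt is given.
#include <string>
#include <vector>

#include "common.h"   // gpt_params, ::llama_tokenize
#include "llama.h"

// Illustrative helper, not part of main.cpp.
static std::vector<llama_token> restore_or_tokenize_prompt(llama_context * ctx, gpt_params & params) {
    std::vector<llama_token> session_tokens;

    if (!params.path_prompt_cache.empty()) {
        session_tokens.resize(params.n_ctx);
        size_t n_token_count_out = 0;
        if (llama_load_session_file(ctx, params.path_prompt_cache.c_str(),
                                    session_tokens.data(), session_tokens.capacity(),
                                    &n_token_count_out)) {
            session_tokens.resize(n_token_count_out);
            llama_set_rng_seed(ctx, params.seed);   // change 1: --seed is no longer ignored
        } else {
            session_tokens.clear();                 // no usable cache; behave as before
        }
    }

    // change 2: only re-tokenize when the user supplied a prompt (or forced it
    // via interactive/instruct mode); otherwise reuse the cached prompt as-is.
    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        params.prompt.insert(0, 1, ' ');            // match OG llama tokenizer behavior
        return ::llama_tokenize(ctx, params.prompt, true);
    }
    return session_tokens;
}
```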

File tree: 2 files changed, +15 -5 lines changed


examples/main/README.md (+1 -1)

@@ -272,7 +272,7 @@ These options help improve the performance and memory usage of the LLaMA models.
 
 ### Prompt Caching
 
-- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
+- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs. **Note**: Restoring a cached prompt does not imply restoring the exact state of the session at the point it was saved. So even when specifying a specific seed, you are not guaranteed to get the same sequence of tokens as the original generation.
 
 ### Quantization
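
The cache file the README refers to is produced on the first run, which persists the evaluated prompt tokens together with the model state. A rough sketch of that save step, assuming the `llama_save_session_file` API from `llama.h` at this revision (the helper name `save_prompt_cache` is illustrative):

```cpp
// Rough sketch of the first-run save step (not the actual main.cpp code).
#include <string>
#include <vector>

#include "llama.h"

// Illustrative helper: persist the evaluated prompt so a later run started
// with --prompt-cache can skip most of the initial prompt evaluation.
static bool save_prompt_cache(llama_context * ctx,
                              const std::string & path_session,
                              const std::vector<llama_token> & session_tokens) {
    if (path_session.empty() || session_tokens.empty()) {
        return false;   // nothing to save
    }
    // Writes the token list plus the serialized model state to path_session.
    return llama_save_session_file(ctx, path_session.c_str(),
                                   session_tokens.data(), session_tokens.size());
}
```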

examples/main/main.cpp (+14 -4)

@@ -134,8 +134,6 @@ int main(int argc, char ** argv) {
         return 0;
     }
 
-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
 
     std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
@@ -155,6 +153,7 @@
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
+            llama_set_rng_seed(ctx, params.seed);
 
             fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
         } else {
@@ -163,7 +162,16 @@
     }
 
     // tokenize the prompt
-    auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    std::vector<llama_token> embd_inp;
+
+    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+        // Add a space in front of the first character to match OG llama tokenizer behavior
+        params.prompt.insert(0, 1, ' ');
+
+        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+    } else {
+        embd_inp = session_tokens;
+    }
 
     const int n_ctx = llama_n_ctx(ctx);
 
@@ -181,7 +189,9 @@
             }
             n_matching_session_tokens++;
         }
-        if (n_matching_session_tokens >= embd_inp.size()) {
+        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
+            fprintf(stderr, "%s: using full prompt from session file\n", __func__);
+        } else if (n_matching_session_tokens >= embd_inp.size()) {
             fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
         } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
             fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",

0 commit comments
