
Commit 8fed512

split behavior into --session and --prompt-cache
1 parent: 8947b1b

File tree

examples/common.cpp
examples/common.h
examples/main/README.md
examples/main/main.cpp

4 files changed, +25 -14 lines changed


examples/common.cpp (+13 -4)
@@ -119,14 +119,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.prompt = argv[i];
         } else if (arg == "-e") {
             escape_prompt = true;
+        } else if (arg == "--prompt-cache") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.path_prompt_cache = argv[i];
         } else if (arg == "--session") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params.path_session = argv[i];
-        } else if (arg == "--session-full") {
-            params.session_full = true;
         } else if (arg == "-f" || arg == "--file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -343,6 +347,11 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         gpt_print_usage(argc, argv, default_params);
         exit(1);
     }
+    if (!params.path_session.empty() && !params.path_prompt_cache.empty()) {
+        fprintf(stderr, "error: only one of --prompt-cache or --session may be specified\n");
+        gpt_print_usage(argc, argv, default_params);
+        exit(1);
+    }
     if (escape_prompt) {
         process_escapes(params.prompt);
     }
@@ -367,8 +376,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
     fprintf(stderr, "                        prompt to start generation with (default: empty)\n");
     fprintf(stderr, "  -e                    process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
-    fprintf(stderr, "  --session FNAME       file to cache model state in (may be large!) (default: none)\n");
-    fprintf(stderr, "  --session-full        if specified, saves output to the session file in addition to prompt\n");
+    fprintf(stderr, "  --prompt-cache FNAME  file to cache prompt state for faster startup (default: none)\n");
+    fprintf(stderr, "  --session FNAME       file to store prompt and generations, allowing continuation (default: none)\n");
     fprintf(stderr, "  --random-prompt       start with a randomized prompt.\n");
     fprintf(stderr, "  --in-prefix STRING    string to prefix user inputs with (default: empty)\n");
     fprintf(stderr, "  --in-suffix STRING    string to suffix after user inputs with (default: empty)\n");

examples/common.h (+4 -4)
@@ -41,9 +41,10 @@ struct gpt_params {
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string path_session = ""; // path to file for saving/loading model eval state
-    std::string input_prefix = ""; // string to prefix user inputs with
-    std::string input_suffix = ""; // string to suffix user inputs with
+    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+    std::string path_session = ""; // file for saving/loading prompt and generations
+    std::string input_prefix = ""; // string to prefix user inputs with
+    std::string input_suffix = ""; // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string lora_adapter = ""; // lora adapter path
@@ -53,7 +54,6 @@ struct gpt_params {
     bool random_prompt = false; // do not randomize prompt if none provided
     bool use_color = false; // use color to distinguish generations and inputs
     bool interactive = false; // interactive mode
-    bool session_full = false; // save the output to the session file in addition to prompt
 
     bool embedding = false; // get only sentence embedding
     bool interactive_first = false; // wait for user input immediately
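
The removed `session_full` flag is not lost information: under the new scheme it is implied by which path field is set. A minimal sketch of that mapping, assuming the mutual-exclusion check from common.cpp has already run; `gpt_params_sketch` and `old_session_full` are hypothetical names:

```cpp
#include <cassert>
#include <string>

// Hypothetical reduction of gpt_params to the two fields touched here.
struct gpt_params_sketch {
    std::string path_prompt_cache = ""; // "" means --prompt-cache was not given
    std::string path_session      = ""; // "" means --session was not given
};

// --session now implies "save generations too", which is exactly what the
// old --session-full flag used to opt into.
static bool old_session_full(const gpt_params_sketch & p) {
    // gpt_params_parse rejects setting both, so this is unambiguous.
    assert(p.path_prompt_cache.empty() || p.path_session.empty());
    return !p.path_session.empty();
}

int main() {
    gpt_params_sketch p;
    p.path_session = "chat.bin"; // as if --session chat.bin had been passed
    assert(old_session_full(p)); // behaves like old --session plus --session-full
    return 0;
}
```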

examples/main/README.md (+2 -2)
@@ -270,9 +270,9 @@ These options help improve the performance and memory usage of the LLaMA models.
 
 - `-b N, --batch_size N`: Set the batch size for prompt processing (default: 512). This large batch size benefits users who have BLAS installed and enabled it during the build. If you don't have BLAS enabled ("BLAS=0"), you can use a smaller number, such as 8, to see the prompt progress as it's evaluated in some situations.
 
-### Session Caching
+### Prompt Caching
 
-- `--session FNAME`: Specify a file to load/save the session, which caches the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The session file is created during the first run and is reused in subsequent runs. If you change your prompt such that 75% or less of the session is reusable, the existing session file will be overwritten with a new, updated version to maintain optimal performance.
+- `--prompt-cache FNAME`: Specify a file to cache the model state after the initial prompt. This can significantly speed up the startup time when you're using longer prompts. The file is created during the first run and is reused and updated in subsequent runs.
 
 ### Quantization
 
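For example, a run like `./main -m models/7B/ggml-model.bin --prompt-cache prompt.bin -p "Once upon a time"` (model path and prompt are placeholders) evaluates the prompt and writes prompt.bin on the first run; later runs with the same prompt load the cached state instead of re-evaluating it. Using `--session chat.bin` instead would additionally persist the generated tokens, so a later run can continue where the previous one left off.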

examples/main/main.cpp (+6 -4)
@@ -140,8 +140,10 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
 
-    std::string path_session = params.path_session;
+    std::string path_session =
+        !params.path_session.empty() ? params.path_session : params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
+    bool resume_session = !params.path_session.empty();
 
     if (!path_session.empty()) {
         fprintf(stderr, "%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str());
@@ -314,8 +316,8 @@ int main(int argc, char ** argv) {
                // insert n_left/2 tokens at the start of embd from last_n_tokens
                embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());

-               // stop saving session if we run out of context
-               if (!path_session.empty() && params.session_full) {
+               // stop saving session if we run out of context, saving whatever was evaled
+               if (!path_session.empty() && resume_session) {
                    llama_save_session_file(ctx, path_session.c_str(),
                        session_tokens.data(), session_tokens.size());
                }
@@ -619,7 +621,7 @@ int main(int argc, char ** argv) {
        }
    }

-   if (!path_session.empty() && params.session_full) {
+   if (!path_session.empty() && resume_session) {
        fprintf(stderr, "\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
        llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
    }
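
Taken together, the two flags collapse onto one code path: either flag selects the state file, and only `--session` keeps saving after the prompt has been evaluated. A minimal sketch of that selection logic, assuming the parser's guarantee that at most one path is set; only the gpt_params field names come from common.h, the rest is illustrative:

```cpp
#include <cstdio>
#include <string>

// Just the two fields of gpt_params that matter here (see examples/common.h).
struct gpt_params {
    std::string path_prompt_cache; // set by --prompt-cache FNAME
    std::string path_session;      // set by --session FNAME
};

int main() {
    gpt_params params;
    params.path_prompt_cache = "prompt.bin"; // as if --prompt-cache prompt.bin had been passed

    // Same selection as main.cpp: whichever flag was given names the state
    // file; the parser guarantees at most one of the two is non-empty.
    std::string path_session =
        !params.path_session.empty() ? params.path_session : params.path_prompt_cache;

    // Only --session keeps writing state after the prompt, which is what
    // lets a later run resume from the saved generations.
    bool resume_session = !params.path_session.empty();

    printf("state file: '%s'  save generations after prompt: %s\n",
           path_session.c_str(), resume_session ? "yes" : "no");
    return 0;
}
```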
