
Commit 5a5d552

Refactor interactive mode in main.cpp
1 parent 70e72fc commit 5a5d552

File tree

1 file changed: +92 -70 lines changed


main.cpp

Lines changed: 92 additions & 70 deletions
@@ -27,7 +27,6 @@
 #define ANSI_COLOR_RESET "\x1b[0m"
 #define ANSI_BOLD "\x1b[1m"
 
-static const int EOS_TOKEN_ID = 2;
 
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
@@ -55,6 +54,8 @@ void sigint_handler(int signo) {
 #endif
 
 
+void process_interactive_input(llama_context& ctx, const gpt_params& params);
+
 int main(int argc, char ** argv) {
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
@@ -85,15 +86,18 @@ int main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
 //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
     // load the model
-    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context* ctx_ptr = nullptr;
+    {
+        ctx_ptr = llama_init_from_params(params);
+        if (!ctx_ptr) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+
     llama_context & ctx = *ctx_ptr;
-    gpt_vocab & vocab = llama_context_get_vocab(ctx);
-
-    // print system information
-    llama_print_context_info(ctx);
+    const gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
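
The guarded load above fixes a potential null dereference, but cleanup still relies on the single llama_free_context call near the end of this diff; the later early-exit path (the `return 1` after llama_prepare_context fails) skips it. A minimal RAII sketch, not part of this commit, assuming only the llama_init_from_params and llama_free_context signatures visible in this diff and that llama_free_context returns void:

    // Sketch only: an owning wrapper so llama_free_context runs on
    // every exit path, including early returns.
    #include <memory>

    using context_ptr = std::unique_ptr<llama_context, void (*)(llama_context*)>;

    context_ptr make_context(const gpt_params & params) {
        // deleter is invoked automatically when the pointer goes out of scope
        return context_ptr(llama_init_from_params(params), llama_free_context);
    }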
@@ -109,8 +113,9 @@ int main(int argc, char ** argv) {
     }
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
 
+    // Setup interactive mode
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -146,50 +151,56 @@ int main(int argc, char ** argv) {
         is_interacting = true;
     }
 
-    bool input_noecho = false;
-
-    int remaining_tokens = params.n_predict;
+    // prompt user immediately after the starting prompt has been loaded
+    if (params.interactive_start) {
+        is_interacting = true;
+    }
 
     // set the color for the prompt which will be output initially
     if (params.use_color) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    if(!llama_ingest_input(ctx, params.prompt))
+    // Prepare the context with input
+    // Send "beginning of string"
+    llama_add_bos(ctx);
+
+    // load the input
+    llama_update_input(ctx, params.prompt);
+
+    llama_print_startup_stats(ctx);
+
+    if(!llama_prepare_context(ctx))
     {
-        fprintf(stderr, "Failed to ingest prompt\n");
+        fprintf(stderr, "%s: failed to prepare context\n", __func__);
         return 1;
-    };
-
-    // display text
-    input_noecho = false;
-    const std::vector<gpt_vocab::id>& embd = llama_context_get_embedding(ctx);
-    if (!input_noecho) {
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
-        }
-        fflush(stdout);
     }
 
-    if (!input_noecho && params.use_color) {
-        printf(ANSI_COLOR_RESET);
-    }
-
-    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
-
-    while (llama_context_is_finished(ctx) != true) {
-        gpt_vocab::id model_output = 0;
-        bool response = llama_infer(ctx, model_output);
-        if (response) {
-            printf("%s", vocab.id_to_token[model_output].c_str());
-            fflush(stdout);
+    bool input_noecho = false;
+    bool is_end_of_text = false;
+    while (llama_context_is_finished(ctx) == false) {
+        std::string model_output{};
+
+        if (llama_has_unconsumed_input(ctx)) {
+            llama_ingest_all_pending_input(ctx, !input_noecho);
+            // reset color to default if there is no pending user input
+            if (!input_noecho && params.use_color) {
+                printf(ANSI_COLOR_RESET);
+            }
+        } else {
+            // Run inference if we don't have any pending input
+            llama_infer(ctx, model_output, is_end_of_text);
+            // print the single token output
+            printf("%s", model_output.c_str());
+            input_noecho = false;
         }
 
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive) {
+        if (params.interactive && !llama_has_unconsumed_input(ctx)) {
+            const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
             // check for reverse prompt
-            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+            if (antiprompt_inp.size() && llama_is_anti_prompt_present(ctx, antiprompt_inp)) {
                 // reverse prompt found
                 is_interacting = true;
             }
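
Distilled from the lines added in this hunk, the new generation loop alternates between draining pending input and sampling one token at a time, instead of the old ingest-once-then-infer flow. A paraphrase for readability, not a literal excerpt (error handling and color codes elided):

    bool input_noecho   = false;  // set after user input so it isn't re-echoed
    bool is_end_of_text = false;  // set by llama_infer when EOS is produced

    while (!llama_context_is_finished(ctx)) {
        if (llama_has_unconsumed_input(ctx)) {
            // feed queued prompt/user text first, echoing unless suppressed
            llama_ingest_all_pending_input(ctx, !input_noecho);
        } else {
            // no pending input: sample one token and print it
            std::string model_output;
            llama_infer(ctx, model_output, is_end_of_text);
            printf("%s", model_output.c_str());
            input_noecho = false;
        }
        // in interactive mode, drop back to the user once input is drained
        // and the reverse prompt (if any) has been seen
    }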
@@ -202,38 +213,14 @@ int main(int argc, char ** argv) {
                 }
 
                 // currently being interactive
-                bool another_line = true;
-                while (another_line) {
-                    fflush(stdout);
-                    char buf[256] = {0};
-                    int n_read;
-                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
-                    if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
-                        // presumably empty line, consume the newline
-                        std::ignore = scanf("%*c");
-                        n_read=0;
-                    }
-                    if (params.use_color) printf(ANSI_COLOR_RESET);
-
-                    if (n_read > 0 && buf[n_read-1]=='\\') {
-                        another_line = true;
-                        buf[n_read-1] = '\n';
-                        buf[n_read] = 0;
-                    } else {
-                        another_line = false;
-                        buf[n_read] = '\n';
-                        buf[n_read+1] = 0;
-                    }
-                    // Do not clear existing context in interactive mode
-                    llama_update_context_with_prompt(ctx, buf, false);
-                }
-
+                process_interactive_input(ctx, params);
+                input_noecho = true; // do not echo this input again
                 is_interacting = false;
             }
         }
 
         // end of text token
-        if (embd.back() == EOS_TOKEN_ID) {
+        if (is_end_of_text) {
            if (params.interactive) {
                is_interacting = true;
            } else {
@@ -243,23 +230,58 @@ int main(int argc, char ** argv) {
        }
 
        // In interactive mode, respect the maximum number of tokens and drop back to user input when reached.
-        if (params.interactive && remaining_tokens <= 0) {
-            remaining_tokens = params.n_predict;
+        if (params.interactive && llama_context_is_finished(ctx)) {
+            llama_context_reset_remaining_tokens(ctx);
            is_interacting = true;
        }
    }
 
-    // report timing from context
+
+#if defined (_WIN32)
+    signal(SIGINT, SIG_DFL);
+#endif
+
+    // report timing
    {
        const int64_t t_main_end_us = ggml_time_us();
        llama_print_end_stats(ctx);
        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
    }
-    llama_free_context(ctx_ptr);
+
+    llama_free_context(ctx_ptr);
 
    if (params.use_color) {
        printf(ANSI_COLOR_RESET);
    }
-
    return 0;
 }
+
+void process_interactive_input(llama_context& ctx, const gpt_params& params)
+{
+    bool another_line = true;
+    while (another_line) {
+        fflush(stdout);
+        char buf[256] = {0};
+        int n_read;
+        if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
+            // presumably empty line, consume the newline
+            std::ignore = scanf("%*c");
+            n_read=0;
+        }
+        if (params.use_color) printf(ANSI_COLOR_RESET);
+
+        if (n_read > 0 && buf[n_read-1]=='\\') {
+            another_line = true;
+            buf[n_read-1] = '\n';
+            buf[n_read] = 0;
+        } else {
+            another_line = false;
+            buf[n_read] = '\n';
+            buf[n_read+1] = 0;
+        }
+
+        // Do not clear existing context in interactive mode
+        llama_update_input(ctx, buf);
+    }
+}
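
For readers unfamiliar with the scanf idiom that process_interactive_input inherits from the old inline code: "%255[^\n]%n%*c" reads up to 255 non-newline characters into buf, records the count via %n, and discards the trailing newline via %*c; a line ending in '\' requests a continuation line, and an empty line ends input. A self-contained sketch of just that input handling, with no llama calls; the buffer is widened by one byte here (my addition, not in the commit) so the appended '\n' cannot land out of bounds when a full 255 characters are read:

    #include <cstdio>
    #include <string>
    #include <tuple>

    int main() {
        std::string input;
        bool another_line = true;
        while (another_line) {
            char buf[257] = {0};  // 255 chars + appended '\n' + terminator
            int n_read;
            // %255[^\n] stops at the newline, %n stores how many chars were
            // read, %*c consumes the newline without storing it
            if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
                std::ignore = scanf("%*c");  // empty line: just eat the newline
                n_read = 0;
            }
            if (n_read > 0 && buf[n_read-1] == '\\') {
                buf[n_read-1] = '\n';        // replace '\' and keep reading
            } else {
                buf[n_read] = '\n';          // close out the final line
                another_line = false;
            }
            input += buf;
        }
        printf("read %zu bytes:\n%s", input.size(), input.c_str());
        return 0;
    }

Compile and pipe in a few lines (some ending in '\') to see how continuations are joined into a single multi-line string before being handed to llama_update_input in the real code.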
