
Commit 6ea19fe: Refactor main.cpp

1 parent cfc5502

File tree

1 file changed: +54, -40 lines

main.cpp (54 additions, 40 deletions)
@@ -76,21 +76,25 @@ int main(int argc, char ** argv) {
     // params.prompt = R"(// this function checks if the number n is prime
     //bool is_prime(int n) {)";
 
-    int64_t t_load_us = 0;
-
     // load the model
-    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context* ctx_ptr = nullptr;
+    {
+        ctx_ptr = llama_init_from_params(params);
+        if (!ctx_ptr) {
+            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
+            return 1;
+        }
+    }
+
     llama_context & ctx = *ctx_ptr;
-    gpt_vocab & vocab = llama_context_get_vocab(ctx);
-
-    // print system information
-    llama_print_context_info(ctx);
+    const gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.antiprompt);
+
 
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@@ -126,8 +130,6 @@ int main(int argc, char ** argv) {
                " - If you want to submit another line, end your input in '\\'.\n");
     }
 
-    bool input_noecho = false;
-
     // prompt user immediately after the starting prompt has been loaded
     if (params.interactive_start) {
         is_interacting = true;
@@ -138,39 +140,44 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    if(!llama_ingest_input(ctx, params.prompt))
+    // Prepare the context with input
+    // Send "beginning of string"
+    llama_add_bos(ctx);
+
+    // load the input
+    llama_update_input(ctx, params.prompt);
+
+    llama_print_startup_stats(ctx);
+
+    if(!llama_prepare_context(ctx))
     {
-        fprintf(stderr, "Failed to ingest prompt\n");
+        fprintf(stderr, "%s: failed to prepare context\n", __func__);
         return 1;
-    };
-
-    // display text
-    input_noecho = false;
-    const std::vector<gpt_vocab::id>& embd = llama_context_get_embedding(ctx);
-    if (!input_noecho) {
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
-        }
-        fflush(stdout);
     }
 
-    if (!input_noecho && params.use_color) {
-        printf(ANSI_COLOR_RESET);
-    }
-
-    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
-
-    while (llama_context_is_finished(ctx) != true) {
-        gpt_vocab::id model_output = 0;
-        bool response = llama_infer(ctx, model_output);
-        if (response) {
-            printf("%s", vocab.id_to_token[model_output].c_str());
-            fflush(stdout);
+    bool input_noecho = false;
+    bool is_end_of_text = false;
+    while (llama_context_is_finished(ctx) == false) {
+        std::string model_output{};
+
+        if (llama_has_unconsumed_input(ctx)) {
+            llama_ingest_all_pending_input(ctx, !input_noecho);
+            // reset color to default if we there is no pending user input
+            if (!input_noecho && params.use_color) {
+                printf(ANSI_COLOR_RESET);
+            }
+        }else{
+            // Run inference if we don't have any pending input
+            llama_infer(ctx, model_output, is_end_of_text);
+            // print the single token output
+            printf("%s", model_output.c_str());
+            input_noecho = false;
         }
 
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive) {
+        if (params.interactive && !llama_has_unconsumed_input(ctx)) {
+            const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
             // check for reverse prompt
             if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
                 // reverse prompt found
@@ -200,32 +207,39 @@ int main(int argc, char ** argv) {
                        buf[n_read] = '\n';
                        buf[n_read+1] = 0;
                    }
+
                    // Do not clear existing context in interactive mode
-                    llama_update_context_with_prompt(ctx, buf, false);
+                    llama_update_input(ctx, buf);
+                    input_noecho = true; // do not echo this again
                }
 
                is_interacting = false;
            }
        }
 
        // end of text token
-        if (embd.back() == 2) {
+        if (is_end_of_text) {
            fprintf(stderr, " [end of text]\n");
            break;
        }
    }
 
-    // report timing from context
+
+#if defined (_WIN32)
+    signal(SIGINT, SIG_DFL);
+#endif
+
+    // report timing
    {
        const int64_t t_main_end_us = ggml_time_us();
        llama_print_end_stats(ctx);
        fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
    }
-    llama_free_context(ctx_ptr);
+
+    llama_free_context(ctx_ptr);
 
    if (params.use_color) {
        printf(ANSI_COLOR_RESET);
    }
-
    return 0;
 }

0 commit comments
