@@ -47,25 +47,6 @@ void sigint_handler(int signo) {
 }
 #endif
 
-const char * llama_print_system_info(void) {
-    static std::string s;
-
-    s  = "";
-    s += "AVX = "       + std::to_string(ggml_cpu_has_avx())       + " | ";
-    s += "AVX2 = "      + std::to_string(ggml_cpu_has_avx2())      + " | ";
-    s += "AVX512 = "    + std::to_string(ggml_cpu_has_avx512())    + " | ";
-    s += "FMA = "       + std::to_string(ggml_cpu_has_fma())       + " | ";
-    s += "NEON = "      + std::to_string(ggml_cpu_has_neon())      + " | ";
-    s += "ARM_FMA = "   + std::to_string(ggml_cpu_has_arm_fma())   + " | ";
-    s += "F16C = "      + std::to_string(ggml_cpu_has_f16c())      + " | ";
-    s += "FP16_VA = "   + std::to_string(ggml_cpu_has_fp16_va())   + " | ";
-    s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
-    s += "BLAS = "      + std::to_string(ggml_cpu_has_blas())      + " | ";
-    s += "SSE3 = "      + std::to_string(ggml_cpu_has_sse3())      + " | ";
-    s += "VSX = "       + std::to_string(ggml_cpu_has_vsx())       + " | ";
-
-    return s.c_str();
-}
 
 int main(int argc, char ** argv) {
     ggml_time_init();
@@ -94,50 +75,24 @@ int main(int argc, char ** argv) {
 
     int64_t t_load_us = 0;
 
-    gpt_vocab vocab;
-    llama_model model;
-
     // load the model
-    {
-        const int64_t t_start_us = ggml_time_us();
+    const int64_t t_start_us = ggml_time_us();
+    // TODO: FIXME: this is a hack
+    llama_context* ctx_ptr = llama_init_from_params(params);
+    llama_context & ctx = *ctx_ptr;
+    gpt_vocab & vocab = llama_context_get_vocab(ctx);
 
-        if (!llama_model_load(params.model, model, vocab, 512)) {  // TODO: set context from user input ??
-            fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
-            return 1;
-        }
-
-        t_load_us = ggml_time_us() - t_start_us;
-    }
+    t_load_us = ggml_time_us() - t_start_us;
 
     // print system information
-    {
-        fprintf(stderr, "\n");
-        fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
-    }
-
-    int n_past = 0;
-
-    int64_t t_sample_us  = 0;
-    int64_t t_predict_us = 0;
-
-    std::vector<float> logits;
+    llama_print_context_info(ctx);
 
     // tokenize the prompt
-    std::vector<gpt_vocab::id> embd_inp = ::llama_tokenize(vocab, params.prompt, true);
-
-    params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
+    std::vector<gpt_vocab::id> embd_inp = llama_tokenize_text(ctx, params.prompt);
 
     // tokenize the reverse prompt
-    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+    std::vector<gpt_vocab::id> antiprompt_inp = llama_tokenize_text(ctx, params.prompt);
 
-    fprintf(stderr, "\n");
-    fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
-    for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
-    }
-    fprintf(stderr, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -161,17 +116,6 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     fprintf(stderr, "\n\n");
 
-    std::vector<gpt_vocab::id> embd;
-
-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
-    int last_n_size = params.repeat_last_n;
-    std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
-    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
-
-
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@@ -181,8 +125,6 @@ int main(int argc, char ** argv) {
                " - If you want to submit another line, end your input in '\\'.\n");
     }
 
-    int remaining_tokens = params.n_predict;
-    int input_consumed = 0;
     bool input_noecho = false;
 
     // prompt user immediately after the starting prompt has been loaded
@@ -195,81 +137,39 @@ int main(int argc, char ** argv) {
         printf(ANSI_COLOR_YELLOW);
     }
 
-    while (remaining_tokens > 0) {
-        // predict
-        if (embd.size() > 0) {
-            const int64_t t_start_us = ggml_time_us();
-
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(stderr, "Failed to predict\n");
-                return 1;
-            }
-
-            t_predict_us += ggml_time_us() - t_start_us;
-        }
-
-        n_past += embd.size();
-        embd.clear();
-
-        if (embd_inp.size() <= input_consumed) {
-            // out of user input, sample next token
-            const float top_k = params.top_k;
-            const float top_p = params.top_p;
-            const float temp  = params.temp;
-            const float repeat_penalty = params.repeat_penalty;
-
-            const int n_vocab = model.hparams.n_vocab;
-
-            gpt_vocab::id id = 0;
-
-            {
-                const int64_t t_start_sample_us = ggml_time_us();
-
-                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
-
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-
-                t_sample_us += ggml_time_us() - t_start_sample_us;
-            }
+    if (!llama_injest_input(ctx, params.prompt))
+    {
+        fprintf(stderr, "Failed to injest prompt\n");
+        return 1;
+    };
 
-            // add it to the context
-            embd.push_back(id);
+    // display text
+    input_noecho = false;
+    const std::vector<gpt_vocab::id>& embd = llama_context_get_embd(ctx);
+    if (!input_noecho) {
+        for (auto id : embd) {
+            printf("%s", vocab.id_to_token[id].c_str());
+        }
+        fflush(stdout);
+    }
 
-            // echo this to console
-            input_noecho = false;
+    if (!input_noecho && params.use_color) {
+        printf(ANSI_COLOR_RESET);
+    }
 
-            // decrement remaining sampling budget
-            --remaining_tokens;
-        } else {
-            // some user input remains from prompt or interaction, forward it to processing
-            // Copy at most n_batch elements from embd_inp to embd
-            size_t num_copied = std::min((size_t) params.n_batch, embd_inp.size() - input_consumed);
-            std::copy(embd_inp.begin() + input_consumed, embd_inp.begin() + input_consumed + num_copied, std::back_inserter(embd));
-            input_consumed += num_copied;
-
-            // Copy the last `last_n_size` elements copied into embd to last_n_tokens
-            size_t num_copied_last_n = std::min(num_copied, (size_t) last_n_size);
-            last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin()+num_copied_last_n);
-            last_n_tokens.insert(last_n_tokens.end(), embd.end() - num_copied_last_n, embd.end());
-
-            // reset color to default if we there is no pending user input
-            if (!input_noecho && params.use_color && embd_inp.size() == input_consumed) {
-                printf(ANSI_COLOR_RESET);
-            }
-        }
+    const std::vector<gpt_vocab::id>& last_n_tokens = llama_context_get_last_n_tokens(ctx);
 
-        // display text
-        if (!input_noecho) {
-            for (auto id : embd) {
-                printf("%s", vocab.id_to_token[id].c_str());
-            }
+    while (llama_context_not_finished(ctx) > 0) {
+        std::optional<gpt_vocab::id> model_output = llama_inference(ctx);
+        if (model_output.has_value()) {
+            printf("%s", vocab.id_to_token[model_output.value()].c_str());
             fflush(stdout);
         }
 
+
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
-        if (params.interactive && embd_inp.size() <= input_consumed) {
+        if (params.interactive) {
             // check for reverse prompt
             if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
                 // reverse prompt found
@@ -299,13 +199,8 @@ int main(int argc, char ** argv) {
                         buf[n_read] = '\n';
                         buf[n_read+1] = 0;
                     }
-
-                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
-                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
-
-                    remaining_tokens -= line_inp.size();
-
-                    input_noecho = true; // do not echo this again
+                    // Do not clear existing context in interactive mode
+                    llama_init_context_with_prompt(ctx, buf, false);
                 }
 
                 is_interacting = false;
@@ -318,21 +213,14 @@ int main(int argc, char ** argv) {
             break;
         }
     }
-
-
-    // report timing
+
+    // report timing from context
     {
         const int64_t t_main_end_us = ggml_time_us();
-
-        fprintf(stderr, "\n\n");
-        fprintf(stderr, "%s: mem per token = %8zu bytes\n", __func__, mem_per_token);
-        fprintf(stderr, "%s:     load time = %8.2f ms\n", __func__, t_load_us/1000.0f);
-        fprintf(stderr, "%s:   sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f);
-        fprintf(stderr, "%s:  predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past);
+        llama_print_end_stats(ctx);
         fprintf(stderr, "%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
     }
-
-    ggml_free(model.ctx);
+    llama_free_context(ctx_ptr);
 
     return 0;
 }
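
For readers skimming the hunks above, here is a minimal sketch of the call sequence this diff moves main() toward, with interactive mode left out. The llama_context functions and their argument lists are inferred from the usage in the hunks, not taken from the actual header; the include names and the use of gpt_params_parse are assumptions, since neither appears in this diff.

```cpp
// Sketch only: call flow implied by the refactored main(); headers and
// gpt_params_parse are assumed, signatures inferred from the diff above.
#include "utils.h"   // assumed: gpt_params, gpt_params_parse, gpt_vocab
#include "llama.h"   // assumed: llama_context API introduced by this PR

#include <cstdio>
#include <optional>

int main(int argc, char ** argv) {
    gpt_params params;
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // One context object now owns the model, vocab, timing and sampling state.
    llama_context * ctx_ptr = llama_init_from_params(params);
    llama_context & ctx = *ctx_ptr;
    gpt_vocab & vocab = llama_context_get_vocab(ctx);

    // Feed the prompt, then pull sampled tokens until the context says it is done.
    if (!llama_injest_input(ctx, params.prompt)) {
        return 1;
    }
    while (llama_context_not_finished(ctx) > 0) {
        std::optional<gpt_vocab::id> token = llama_inference(ctx);
        if (token.has_value()) {
            printf("%s", vocab.id_to_token[token.value()].c_str());
            fflush(stdout);
        }
    }

    llama_print_end_stats(ctx);
    llama_free_context(ctx_ptr);
    return 0;
}
```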