@@ -14,6 +14,12 @@ struct server_params
     bool verbose = false;
 };
 
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
+    return i;
+}
+
 struct llama_server_context
 {
     bool stream = false;
@@ -28,10 +34,7 @@ struct llama_server_context
 
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
-    std::vector<llama_token> processed_tokens;
-    std::vector<llama_token> embd_inp;
 
-    std::vector<llama_token> last_prompt_tokens;
     llama_context *ctx = nullptr;
     gpt_params params;
 
@@ -55,11 +58,10 @@ struct llama_server_context
         generated_text.reserve(params.n_ctx);
         stopping_word = "";
 
-        // processed_tokens.clear();
-        embd_inp.clear();
         n_remain = 0;
         n_past = 0;
         n_consumed = 0;
+        last_n_tokens.clear();
     }
 
     bool loadModel(const gpt_params &params_)
@@ -80,177 +82,159 @@ struct llama_server_context
     bool loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
-        if (prompt_tokens == last_prompt_tokens)
-        {
-            embd.clear();
+
+        if (params.n_keep < 0) {
+            params.n_keep = (int)prompt_tokens.size();
         }
-        // compare the evaluated prompt with the new prompt
-        for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
-            if (prompt_tokens[n_past] != processed_tokens[n_past]) {
-                break;
-            }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            prompt_tokens = new_tokens;
         }
-        processed_tokens.resize(n_past);
-        if (prompt_tokens.size() > n_past) {
-            embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
+
+        // compare the evaluated prompt with the new prompt
+        n_past = common_part(embd, prompt_tokens);
+        embd = prompt_tokens;
+        if (n_past == prompt_tokens.size()) {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past--;
         }
-        last_prompt_tokens = prompt_tokens;
         has_next_token = true;
         return true;
     }
 
     void beginCompletion()
     {
-        if (n_remain == 0) {
-            // number of tokens to keep when resetting context
-            if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
-            {
-                params.n_keep = (int)embd_inp.size();
-            }
-        }
+        // number of tokens to keep when resetting context
+
+
         n_remain = params.n_predict;
         llama_set_rng_seed(ctx, params.seed);
     }
 
     llama_token nextToken() {
         llama_token result = -1;
-        if (embd.size() > 0)
+
+        if (embd.size() >= (size_t)params.n_ctx) {
+            // Reset context
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+
+            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
+            embd = new_tokens;
+            n_past = params.n_keep;
+        }
+
+        while (n_past < embd.size())
         {
-            if (n_past + embd.size() > (size_t)params.n_ctx)
+            int n_eval = (int)embd.size() - n_past;
+            if (n_eval > params.n_batch)
             {
-                // Reset context
-                const int n_left = n_past - params.n_keep;
-                n_past = std::max(1, params.n_keep);
-                // processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
-                embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
+                n_eval = params.n_batch;
             }
-            for (int i = 0; i < (int)embd.size(); i += params.n_batch)
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
             {
-                int n_eval = (int)embd.size() - i;
-                if (n_eval > params.n_batch)
-                {
-                    n_eval = params.n_batch;
-                }
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
-                {
-                    fprintf(stderr, "%s : failed to eval\n", __func__);
-                    has_next_token = false;
-                    return result;
-                }
-                n_past += n_eval;
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                has_next_token = false;
+                return result;
             }
+            n_past += n_eval;
         }
-        embd.clear();
-        if (embd_inp.size() <= n_consumed)
+
+        // out of user input, sample next token
+        const float temp = params.temp;
+        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+        const float top_p = params.top_p;
+        const float tfs_z = params.tfs_z;
+        const float typical_p = params.typical_p;
+        const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
+        const float repeat_penalty = params.repeat_penalty;
+        const float alpha_presence = params.presence_penalty;
+        const float alpha_frequency = params.frequency_penalty;
+        const int mirostat = params.mirostat;
+        const float mirostat_tau = params.mirostat_tau;
+        const float mirostat_eta = params.mirostat_eta;
+        const bool penalize_nl = params.penalize_nl;
+        llama_token id = 0;
         {
-            // out of user input, sample next token
-            const float temp = params.temp;
-            const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
-            const float top_p = params.top_p;
-            const float tfs_z = params.tfs_z;
-            const float typical_p = params.typical_p;
-            const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
-            const float repeat_penalty = params.repeat_penalty;
-            const float alpha_presence = params.presence_penalty;
-            const float alpha_frequency = params.frequency_penalty;
-            const int mirostat = params.mirostat;
-            const float mirostat_tau = params.mirostat_tau;
-            const float mirostat_eta = params.mirostat_eta;
-            const bool penalize_nl = params.penalize_nl;
-            llama_token id = 0;
+            auto logits = llama_get_logits(ctx);
+            auto n_vocab = llama_n_vocab(ctx);
+
+            // Apply params.logit_bias map
+            for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
             {
-                auto logits = llama_get_logits(ctx);
-                auto n_vocab = llama_n_vocab(ctx);
+                logits[it->first] += it->second;
+            }
 
-                // Apply params.logit_bias map
-                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
-                {
-                    logits[it->first] += it->second;
-                }
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+            {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            }
 
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++)
-                {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
+            llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+            // Apply penalties
+            float nl_logit = logits[llama_token_nl()];
+            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+            llama_sample_repetition_penalty(ctx, &candidates_p,
+                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+            llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+            if (!penalize_nl)
+            {
+                logits[llama_token_nl()] = nl_logit;
+            }
 
-                llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
-
-                // Apply penalties
-                float nl_logit = logits[llama_token_nl()];
-                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
-                llama_sample_repetition_penalty(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
-                if (!penalize_nl)
+            if (temp <= 0)
+            {
+                // Greedy sampling
+                id = llama_sample_token_greedy(ctx, &candidates_p);
+            }
+            else
+            {
+                if (mirostat == 1)
                 {
-                    logits[llama_token_nl()] = nl_logit;
+                    static float mirostat_mu = 2.0f * mirostat_tau;
+                    const int mirostat_m = 100;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 }
-
-                if (temp <= 0)
+                else if (mirostat == 2)
                 {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                    static float mirostat_mu = 2.0f * mirostat_tau;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 }
                 else
                 {
-                    if (mirostat == 1)
-                    {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                    }
-                    else if (mirostat == 2)
-                    {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                    }
-                    else
-                    {
-                        // Temperature sampling
-                        llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                        llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-                        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token(ctx, &candidates_p);
-                    }
-                }
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-                processed_tokens.push_back(id);
-                num_tokens_predicted++;
-            }
-
-            // add it to the context
-            embd.push_back(id);
-            result = id;
-            // decrement remaining sampling budget
-            --n_remain;
-        }
-        else
-        {
-            // some user input remains from prompt or interaction, forward it to processing
-            while (embd_inp.size() > n_consumed)
-            {
-                embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
-                processed_tokens.push_back(embd_inp[n_consumed]);
-                ++n_consumed;
-                if ((int)embd.size() >= params.n_batch)
-                {
-                    break;
+                    // Temperature sampling
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token(ctx, &candidates_p);
                 }
             }
+            last_n_tokens.erase(last_n_tokens.begin());
+            last_n_tokens.push_back(id);
+            num_tokens_predicted++;
         }
 
+        // add it to the context
+        embd.push_back(id);
+        result = id;
+        // decrement remaining sampling budget
+        --n_remain;
+
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;