
Commit 8478e59

Merge pull request ggml-org#8 from SlyEcho/server_refactor
Change how the token buffers work.
2 parents f2e1130 + 9104fe5 commit 8478e59
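
In outline, the refactor drops the separate embd_inp / processed_tokens / last_prompt_tokens buffers and keeps everything in a single embd vector: on each request the server measures how much of the already-evaluated context matches the new prompt and only evaluates the rest. The sketch below is a paraphrase of that idea for illustration, not the committed code; llama_token is stubbed as int32_t here so the snippet stands alone.

// Standalone sketch of the prompt-prefix reuse introduced by this commit.
// Assumption: llama_token is the int32 token id type from llama.h (stubbed here).
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

// Length of the shared prefix of two token sequences (mirrors the helper added in the diff below).
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    std::vector<llama_token> embd          = {1, 10, 20, 30, 40};     // tokens already evaluated
    std::vector<llama_token> prompt_tokens = {1, 10, 20, 99, 98, 97}; // tokens of the new request

    // Only the tokens after the shared prefix need to go through llama_eval().
    size_t n_past = common_part(embd, prompt_tokens); // 3
    embd = prompt_tokens;                             // embd is now the single token buffer
    printf("reusing %zu tokens, evaluating %zu new ones\n", n_past, embd.size() - n_past);
    return 0;
}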

File tree

1 file changed: +124 -140 lines changed

examples/server/server.cpp

Lines changed: 124 additions & 140 deletions
@@ -14,6 +14,12 @@ struct server_params
     bool verbose = false;
 };
 
+static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++);
+    return i;
+}
+
 struct llama_server_context
 {
     bool stream = false;
@@ -28,10 +34,7 @@ struct llama_server_context
 
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
-    std::vector<llama_token> processed_tokens;
-    std::vector<llama_token> embd_inp;
 
-    std::vector<llama_token> last_prompt_tokens;
     llama_context *ctx = nullptr;
     gpt_params params;
 
@@ -55,11 +58,10 @@ struct llama_server_context
         generated_text.reserve(params.n_ctx);
         stopping_word = "";
 
-        //processed_tokens.clear();
-        embd_inp.clear();
         n_remain = 0;
         n_past = 0;
         n_consumed = 0;
+        last_n_tokens.clear();
     }
 
     bool loadModel(const gpt_params &params_)
@@ -80,177 +82,159 @@ struct llama_server_context
     bool loadPrompt() {
         params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
-        if (prompt_tokens == last_prompt_tokens)
-        {
-            embd.clear();
+
+        if (params.n_keep < 0) {
+            params.n_keep = (int)prompt_tokens.size();
         }
-        // compare the evaluated prompt with the new prompt
-        for (n_past = 0; n_past < prompt_tokens.size() - 1 && n_past < processed_tokens.size(); n_past++) {
-            if (prompt_tokens[n_past] != processed_tokens[n_past]) {
-                break;
-            }
+        params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+        // if input prompt is too big, truncate like normal
+        if (prompt_tokens.size() >= (size_t)params.n_ctx) {
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left, prompt_tokens.end());
+            prompt_tokens = new_tokens;
         }
-        processed_tokens.resize(n_past);
-        if (prompt_tokens.size() > n_past) {
-            embd_inp.insert(embd_inp.end(), prompt_tokens.begin() + n_past, prompt_tokens.end());
+
+        // compare the evaluated prompt with the new prompt
+        n_past = common_part(embd, prompt_tokens);
+        embd = prompt_tokens;
+        if (n_past == prompt_tokens.size()) {
+            // we have to evaluate at least 1 token to generate logits.
+            n_past--;
         }
-        last_prompt_tokens = prompt_tokens;
         has_next_token = true;
         return true;
     }
 
     void beginCompletion()
     {
-        if(n_remain == 0) {
-            // number of tokens to keep when resetting context
-            if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size())
-            {
-                params.n_keep = (int)embd_inp.size();
-            }
-        }
+        // number of tokens to keep when resetting context
+
+
         n_remain = params.n_predict;
         llama_set_rng_seed(ctx, params.seed);
     }
 
     llama_token nextToken() {
         llama_token result = -1;
-        if (embd.size() > 0)
+
+        if (embd.size() >= (size_t)params.n_ctx) {
+            // Reset context
+            const int n_left = (params.n_ctx - params.n_keep)/2;
+
+            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
+            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
+            embd = new_tokens;
+            n_past = params.n_keep;
+        }
+
+        while (n_past < embd.size())
         {
-            if (n_past + embd.size() > (size_t)params.n_ctx)
+            int n_eval = (int)embd.size() - n_past;
+            if (n_eval > params.n_batch)
             {
-                // Reset context
-                const int n_left = n_past - params.n_keep;
-                n_past = std::max(1, params.n_keep);
-                //processed_tokens.erase(processed_tokens.begin() + n_past, processed_tokens.end());
-                embd.insert(embd.begin(), last_n_tokens.begin() + params.n_ctx - n_left / 2 - embd.size(), last_n_tokens.end() - embd.size());
+                n_eval = params.n_batch;
             }
-            for (int i = 0; i < (int)embd.size(); i += params.n_batch)
+            if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
             {
-                int n_eval = (int)embd.size() - i;
-                if (n_eval > params.n_batch)
-                {
-                    n_eval = params.n_batch;
-                }
-                if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads))
-                {
-                    fprintf(stderr, "%s : failed to eval\n", __func__);
-                    has_next_token = false;
-                    return result;
-                }
-                n_past += n_eval;
+                fprintf(stderr, "%s : failed to eval\n", __func__);
+                has_next_token = false;
+                return result;
             }
+            n_past += n_eval;
         }
-        embd.clear();
-        if (embd_inp.size() <= n_consumed)
+
+        // out of user input, sample next token
+        const float temp = params.temp;
+        const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
+        const float top_p = params.top_p;
+        const float tfs_z = params.tfs_z;
+        const float typical_p = params.typical_p;
+        const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
+        const float repeat_penalty = params.repeat_penalty;
+        const float alpha_presence = params.presence_penalty;
+        const float alpha_frequency = params.frequency_penalty;
+        const int mirostat = params.mirostat;
+        const float mirostat_tau = params.mirostat_tau;
+        const float mirostat_eta = params.mirostat_eta;
+        const bool penalize_nl = params.penalize_nl;
+        llama_token id = 0;
         {
-            // out of user input, sample next token
-            const float temp = params.temp;
-            const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
-            const float top_p = params.top_p;
-            const float tfs_z = params.tfs_z;
-            const float typical_p = params.typical_p;
-            const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_ctx : params.repeat_last_n;
-            const float repeat_penalty = params.repeat_penalty;
-            const float alpha_presence = params.presence_penalty;
-            const float alpha_frequency = params.frequency_penalty;
-            const int mirostat = params.mirostat;
-            const float mirostat_tau = params.mirostat_tau;
-            const float mirostat_eta = params.mirostat_eta;
-            const bool penalize_nl = params.penalize_nl;
-            llama_token id = 0;
+            auto logits = llama_get_logits(ctx);
+            auto n_vocab = llama_n_vocab(ctx);
+
+            // Apply params.logit_bias map
+            for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
             {
-                auto logits = llama_get_logits(ctx);
-                auto n_vocab = llama_n_vocab(ctx);
+                logits[it->first] += it->second;
+            }
 
-                // Apply params.logit_bias map
-                for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++)
-                {
-                    logits[it->first] += it->second;
-                }
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++)
+            {
+                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            }
 
-                std::vector<llama_token_data> candidates;
-                candidates.reserve(n_vocab);
-                for (llama_token token_id = 0; token_id < n_vocab; token_id++)
-                {
-                    candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-                }
+            llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
+
+            // Apply penalties
+            float nl_logit = logits[llama_token_nl()];
+            auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
+            llama_sample_repetition_penalty(ctx, &candidates_p,
+                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                last_n_repeat, repeat_penalty);
+            llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
+                last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
+                last_n_repeat, alpha_frequency, alpha_presence);
+            if (!penalize_nl)
+            {
+                logits[llama_token_nl()] = nl_logit;
+            }
 
-                llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
-
-                // Apply penalties
-                float nl_logit = logits[llama_token_nl()];
-                auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
-                llama_sample_repetition_penalty(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
-                if (!penalize_nl)
+            if (temp <= 0)
+            {
+                // Greedy sampling
+                id = llama_sample_token_greedy(ctx, &candidates_p);
+            }
+            else
+            {
+                if (mirostat == 1)
                 {
-                    logits[llama_token_nl()] = nl_logit;
+                    static float mirostat_mu = 2.0f * mirostat_tau;
+                    const int mirostat_m = 100;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
                 }
-
-                if (temp <= 0)
+                else if (mirostat == 2)
                 {
-                    // Greedy sampling
-                    id = llama_sample_token_greedy(ctx, &candidates_p);
+                    static float mirostat_mu = 2.0f * mirostat_tau;
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
                 }
                 else
                 {
-                    if (mirostat == 1)
-                    {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        const int mirostat_m = 100;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
-                    }
-                    else if (mirostat == 2)
-                    {
-                        static float mirostat_mu = 2.0f * mirostat_tau;
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
-                    }
-                    else
-                    {
-                        // Temperature sampling
-                        llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
-                        llama_sample_typical(ctx, &candidates_p, typical_p, 1);
-                        llama_sample_top_p(ctx, &candidates_p, top_p, 1);
-                        llama_sample_top_k(ctx, &candidates_p, top_k, 1);
-                        llama_sample_temperature(ctx, &candidates_p, temp);
-                        id = llama_sample_token(ctx, &candidates_p);
-                    }
-                }
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(id);
-                processed_tokens.push_back(id);
-                num_tokens_predicted++;
-            }
-
-            // add it to the context
-            embd.push_back(id);
-            result = id;
-            // decrement remaining sampling budget
-            --n_remain;
-        }
-        else
-        {
-            // some user input remains from prompt or interaction, forward it to processing
-            while (embd_inp.size() > n_consumed)
-            {
-                embd.push_back(embd_inp[n_consumed]);
-                last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[n_consumed]);
-                processed_tokens.push_back(embd_inp[n_consumed]);
-                ++n_consumed;
-                if ((int)embd.size() >= params.n_batch)
-                {
-                    break;
+                    // Temperature sampling
+                    llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
+                    llama_sample_typical(ctx, &candidates_p, typical_p, 1);
+                    llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+                    llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+                    llama_sample_temperature(ctx, &candidates_p, temp);
+                    id = llama_sample_token(ctx, &candidates_p);
                 }
             }
+            last_n_tokens.erase(last_n_tokens.begin());
+            last_n_tokens.push_back(id);
+            num_tokens_predicted++;
         }
 
+        // add it to the context
+        embd.push_back(id);
+        result = id;
+        // decrement remaining sampling budget
+        --n_remain;
+
         if (!embd.empty() && embd.back() == llama_token_eos()) {
             stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
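
Both the new loadPrompt() and nextToken() above apply the same rule when the token buffer would overflow the context window: keep the first params.n_keep tokens and append only the last (n_ctx - n_keep)/2 tokens. Below is a toy, self-contained illustration of just that slicing; the numbers are made up for the example, and only the vector arithmetic mirrors the diff.

// Toy illustration of the truncation rule used in both loadPrompt() and nextToken():
// keep the first n_keep tokens, then append the last (n_ctx - n_keep)/2 tokens.
// Assumption: llama_token is stubbed as int32_t so this compiles on its own.
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

int main() {
    const int n_ctx  = 8;
    const int n_keep = 2;
    std::vector<llama_token> tokens = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; // too long for n_ctx

    if (tokens.size() >= (size_t) n_ctx) {
        const int n_left = (n_ctx - n_keep) / 2; // 3
        std::vector<llama_token> new_tokens(tokens.begin(), tokens.begin() + n_keep);
        new_tokens.insert(new_tokens.end(), tokens.end() - n_left, tokens.end());
        tokens = new_tokens; // {0, 1, 7, 8, 9}
    }

    for (llama_token t : tokens) {
        printf("%d ", t);
    }
    printf("\n"); // prints: 0 1 7 8 9
    return 0;
}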
