Commit 7196c4e

new sampling API
1 parent 84b8f2b commit 7196c4e


examples/server/server.cpp

Lines changed: 77 additions & 89 deletions
@@ -125,7 +125,7 @@ enum slot_command {
 struct slot_params {
     bool stream = true;
     uint32_t seed = -1; // RNG seed
-    int n_keep = 0; // RNG seed
+    int n_keep = 0; // number of tokens to keep from initial prompt
     int32_t n_predict = -1; // new tokens to predict
     std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     bool cache_prompt = false; // remember a the prompt to avoid reprocessing all prompt
@@ -262,6 +262,34 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }
 
+struct llama_sampling_context * llama_sampling_init_srv(const struct llama_sampling_params sparams, std::string grammar, int n_ctx) {
+    struct llama_sampling_context * result = new llama_sampling_context();
+
+    result->params = sparams;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+            grammar_rules.data(),
+            grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(n_ctx);
+
+    return result;
+}
+
 struct slot_image {
     clip_image_u8 img_data;
     bool request_encode_image = false;
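
Unlike `llama_sampling_init` from `common/sampling.h`, the new `llama_sampling_init_srv` helper takes the grammar as a plain string plus an explicit context size, and returns `nullptr` on a grammar parse error instead of aborting. A minimal caller sketch, assuming the `common/sampling.h` API of this revision (the request-handling flow around it is not part of this diff):

    // sketch: how a caller is expected to consume llama_sampling_init_srv
    llama_sampling_context * ctx_sampling = llama_sampling_init_srv(sparams, grammar_str, n_ctx);
    if (ctx_sampling == nullptr) {
        // grammar did not parse -> reject the request instead of sampling unconstrained
        return false;
    }
    // ctx_sampling->prev is pre-sized to n_ctx and later used for repetition penalties
    // ... sample tokens ...
    llama_sampling_free(ctx_sampling); // also releases the compiled grammar
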
@@ -287,7 +315,6 @@ struct llama_client_slot
     int num_tokens_predicted = 0;
     llama_token sampled;
     std::vector<llama_token> cache_tokens;
-    std::vector<llama_token> last_n_tokens;
     std::vector<completion_token_output> generated_token_probs;
     int sent_tokens = 0;
     slot_state state = IDLE;
@@ -307,13 +334,12 @@ struct llama_client_slot
     double t_token_generation; // ms
 
     struct slot_params params;
+
+    // sampling
     struct llama_sampling_params sparams;
-    llama_sampling_context ctx_sampling;
+    llama_sampling_context* ctx_sampling = nullptr;
     bool has_next_token = true;
-
-    // grammar props
-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
+    int max_context_size = 0;
 
     // multimodal
     std::vector<slot_image> images;
@@ -332,47 +358,26 @@ struct llama_client_slot
         infill = false;
         clean_tokens();
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling.params = sparams;
-            ctx_sampling.grammar = NULL;
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
 
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+
         for(slot_image img : images) {
             free(img.image_embedding);
             delete[] img.img_data.data;
             img.prefix_prompt = "";
         }
+
         images.clear();
         // llama_set_rng_seed(ctx, params.seed); in batched the seed matter???????
     }
 
     bool loadGrammar(llama_token eos)
     {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-
-            {
-                auto it = sparams.logit_bias.find(eos);
-                if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling.params = sparams;
-        ctx_sampling.grammar = grammar;
-        return true;
+        ctx_sampling = llama_sampling_init_srv(sparams, params.grammar, max_context_size);
+        return ctx_sampling != nullptr;
     }
 
     bool hasBudget(gpt_params &global_params) {
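
With this change both `reset()` and `loadGrammar()` rebuild the per-slot sampling state through the same helper, and the old `parsed_grammar`/`grammar` members disappear. The intended per-request lifecycle, sketched with the names used above (the surrounding control flow is assumed, not shown in this diff):

    // sketch of the slot lifecycle as wired in this diff
    slot.reset();                       // frees any previous ctx_sampling and builds a fresh one
    if (!slot.loadGrammar(eos_token)) { // eos_token: the model's end-of-sequence id (illustrative name)
        // slot.params.grammar failed to parse; ctx_sampling is nullptr,
        // so the request should be rejected before decoding starts
    }
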
@@ -448,7 +453,6 @@ struct llama_server_context
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     llama_batch batch;
-    std::vector<llama_token_data> candidates;
     bool all_slots_are_idle = false;
     gpt_params params;
     int n_ctx;
@@ -468,11 +472,6 @@ struct llama_server_context
             llama_free_model(model);
             model = nullptr;
         }
-        for(auto &slot : slots) {
-            if(slot.grammar) {
-                llama_grammar_free(slot.grammar);
-            }
-        }
     }
 
     bool loadModel(const gpt_params &params_)
@@ -510,7 +509,6 @@ struct llama_server_context
         }
         n_ctx = llama_n_ctx(ctx);
         n_vocab = llama_n_vocab(model);
-        candidates.reserve(n_vocab);
         return true;
     }
 
@@ -529,13 +527,12 @@ struct llama_server_context
         {
            llama_client_slot slot;
            slot.id = i;
-           slot.last_n_tokens.resize(max_ctx_per_slot);
-           std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+           slot.max_context_size = max_ctx_per_slot;
            slot.reset();
            LOG_TEE(" -> Slot %i - max context: %i\n", slot.id, max_ctx_per_slot);
            slots.push_back(slot);
         }
-        batch = llama_batch_init(n_ctx, 0);
+        batch = llama_batch_init(n_ctx, 0, 1);
         // empty system prompt
         system_prompt = "";
         num_tokens_system = 0;
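
`llama_batch_init` now takes a third argument; in the `llama.h` this commit builds against it is the maximum number of sequence ids attached to a single token (1 here, since each server token belongs to exactly one slot sequence). A hedged allocation/teardown sketch under that assumption:

    // assuming: llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max);
    llama_batch batch = llama_batch_init(n_ctx, 0 /*token ids, not embeddings*/, 1 /*one seq id per token*/);
    // ... fill via llama_batch_add(...) and run llama_decode(...) ...
    llama_batch_free(batch);
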
@@ -626,10 +623,7 @@ struct llama_server_context
 
         for (int32_t i = 0; i < batch.n_tokens; ++i)
         {
-            batch.token[i] = tokens_system[i];
-            batch.pos[i] = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+            llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
         }
 
         if (llama_decode(ctx, batch) != 0)
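
The manual `batch.token[i] = ...` bookkeeping is replaced by the `llama_batch_add` helper from `common/common.h`. Roughly, the helper appends one token together with its position and sequence ids and advances `n_tokens`; an approximate sketch of its behaviour (not a verbatim copy of the helper in the tree):

    // approximate behaviour of llama_batch_add at this revision
    void llama_batch_add(struct llama_batch & batch, llama_token id, llama_pos pos,
                         const std::vector<llama_seq_id> & seq_ids, bool logits) {
        batch.token   [batch.n_tokens] = id;
        batch.pos     [batch.n_tokens] = pos;
        batch.n_seq_id[batch.n_tokens] = seq_ids.size();
        for (size_t j = 0; j < seq_ids.size(); ++j) {
            batch.seq_id[batch.n_tokens][j] = seq_ids[j];
        }
        batch.logits  [batch.n_tokens] = logits;
        batch.n_tokens++;
    }
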
@@ -726,8 +720,6 @@ struct llama_server_context
 
     bool processToken(completion_token_output & result, llama_client_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        slot.last_n_tokens.erase(slot.last_n_tokens.begin());
-        slot.last_n_tokens.push_back(result.tok);
         const std::string token_str = llama_token_to_piece(ctx, result.tok);
         slot.sampled = result.tok;
 
@@ -859,11 +851,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
             if (llama_decode(ctx, batch_view)) {
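
The extra `batch.n_seq_id + i` entry exists because `llama_batch` gained a per-token count of sequence ids and `seq_id` became a per-token list in this API revision; the view must follow the struct's field order exactly. For orientation, the approximate layout this initializer assumes (paraphrased from `llama.h` of this era, not quoted):

    // approximate llama_batch layout the batch_view initializer follows
    typedef struct llama_batch {
        int32_t         n_tokens;
        llama_token  *  token;     // token ids, or nullptr when embd is used
        float        *  embd;      // embeddings, or nullptr when token is used
        llama_pos    *  pos;       // position of each token in its sequence
        int32_t      *  n_seq_id;  // number of sequence ids per token
        llama_seq_id ** seq_id;    // per-token list of sequence ids
        int8_t       *  logits;    // whether logits are needed for this token
        llama_pos       all_pos_0; // used only when pos == nullptr ("unused" here)
        llama_pos       all_pos_1;
        llama_seq_id    all_seq_id;
    } llama_batch;
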
@@ -878,8 +871,8 @@ struct llama_server_context
                 if (n_eval > n_batch) {
                     n_eval = n_batch;
                 }
-                llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
-                if (llama_decode(ctx, batch)) {
+                llama_batch batch_img = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+                if (llama_decode(ctx, batch_img)) {
                     LOG_TEE("%s : failed to eval image\n", __func__);
                     return false;
                 }
@@ -894,10 +887,7 @@ struct llama_server_context
                 (json)(slot.images[image_idx].prefix_prompt);
             std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
             for (int i = 0; i < append_tokens.size(); ++i) {
-                batch.token [batch.n_tokens] = append_tokens[i];
-                batch.pos   [batch.n_tokens] = slot.n_past;
-                batch.seq_id[batch.n_tokens] = slot.id;
-                batch.logits[batch.n_tokens] = false;
+                llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
                 slot.n_past += 1;
                 batch.n_tokens += 1;
             }
@@ -922,7 +912,6 @@ struct llama_server_context
                 std::this_thread::sleep_for(std::chrono::milliseconds(5));
             }
 
-            // context shift takes effect only when there is a single slot
             for(llama_client_slot &slot : slots) {
                 if (slot.isProcessing() && slot.cache_tokens.size() >= (size_t)max_ctx_per_slot)
                 {
@@ -976,16 +965,12 @@ struct llama_server_context
                 continue;
             }
 
-            batch.token [batch.n_tokens] = slot.sampled;
-            batch.pos   [batch.n_tokens] = num_tokens_system + slot.n_past;
-            batch.seq_id[batch.n_tokens] = slot.id;
-            batch.logits[batch.n_tokens] = true;
+            slot.i_batch = batch.n_tokens;
+
+            llama_batch_add(batch, slot.sampled, num_tokens_system + slot.n_past, { slot.id }, true);
 
             slot.n_decoded += 1;
-            slot.i_batch = batch.n_tokens;
             slot.n_past += 1;
-
-            batch.n_tokens += 1;
         }
         // process in chunks of params.n_batch
         int32_t n_batch = params.n_batch;
@@ -1026,7 +1011,7 @@ struct llama_server_context
                 slot.num_prompt_tokens = prompt_tokens.size();
 
                 if(!slot.params.cache_prompt) {
-                    std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end(), 0);
+                    std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end(), 0);
                     slot.n_past = 0;
                     slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
                 } else {
@@ -1038,23 +1023,27 @@ struct llama_server_context
                 //if input prompt is too big, truncate like normal
                 if (slot.num_prompt_tokens >= (size_t)max_ctx_per_slot)
                 {
+                    // applied bug of #3661
                     const int n_left = max_ctx_per_slot - slot.params.n_keep;
+                    const int n_block_size = n_left / 2;
+                    const int erased_blocks = (slot.num_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
                     std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + slot.params.n_keep);
                     // Use half the left-over space in the context for the prompt
-                    new_tokens.insert(new_tokens.end(), prompt_tokens.end() - n_left / 2, prompt_tokens.end());
+                    new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
                     LOG_VERBOSE("input truncated", {
-                        {"n_ctx", n_ctx},
-                        {"n_keep", params.n_keep},
+                        {"n_ctx", max_ctx_per_slot},
+                        {"n_keep", slot.params.n_keep},
                         {"n_left", n_left},
                         {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
                     });
                     slot.truncated = true;
                     prompt_tokens = new_tokens;
                     slot.num_prompt_tokens = prompt_tokens.size();
+                    GGML_ASSERT(slot.num_prompt_tokens < (size_t)max_ctx_per_slot);
                 }
                 const size_t ps = slot.num_prompt_tokens;
-                std::fill(slot.last_n_tokens.begin(), slot.last_n_tokens.end() - ps, 0);
-                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.last_n_tokens.end() - ps);
+                std::fill(slot.ctx_sampling->prev.begin(), slot.ctx_sampling->prev.end() - ps, 0);
+                std::copy(prompt_tokens.begin(), prompt_tokens.end(), slot.ctx_sampling->prev.end() - ps);
                 slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
                 slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
                 LOG_TEE("slot %i - in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
@@ -1081,11 +1070,7 @@ struct llama_server_context
                 // process the prefix of first image
                 std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
                 for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
-                    batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
-                    batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
-                    batch.seq_id[batch.n_tokens] = slot.id;
-                    batch.logits[batch.n_tokens] = false;
-                    batch.n_tokens += 1;
+                    llama_batch_add(batch, prefix_tokens[slot.n_past], num_tokens_system + slot.n_past, { slot.id }, false);
                 }
 
                 if(ingest_images && !ingestImages(slot, n_batch)) {
@@ -1113,11 +1098,12 @@ struct llama_server_context
             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
             llama_batch batch_view = {
                 n_tokens,
-                batch.token + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
 
@@ -1150,25 +1136,27 @@ struct llama_server_context
                 }
 
                 completion_token_output result;
-                const llama_token id = llama_sampling_sample(ctx, NULL, slot.ctx_sampling, slot.last_n_tokens, candidates, slot.i_batch - i);
+                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+
+                llama_sampling_accept(slot.ctx_sampling, ctx, id);
 
                 if (slot.n_decoded == 1) {
                     slot.t_start_genereration = ggml_time_us();
                     slot.t_prompt_processing = (slot.t_start_genereration - slot.t_start_process_prompt) / 1e3;
                 }
 
-                llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+                llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                 result.tok = id;
                 const int32_t n_probs = slot.sparams.n_probs;
                 if (slot.sparams.temp <= 0 && n_probs > 0)
                 {
                     // For llama_sample_token_greedy we need to sort candidates
-                    llama_sample_softmax(ctx, &candidates_p);
+                    llama_sample_softmax(ctx, &cur_p);
                 }
 
-                for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+                for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
                 {
-                    result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                    result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
                 }
 
                 if (!processToken(result, slot)) {
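
Taken together, the per-token flow now follows the `common/sampling.h` pattern this commit adopts: the sampling context owns both the candidate list (`cur`) and the recent-token ring (`prev`), and sampling and acceptance are split into two calls instead of the server passing `last_n_tokens` and `candidates` around. A condensed sketch of the loop body as wired above (slot and batch plumbing omitted):

    // condensed per-token sampling flow
    const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, /*ctx_guidance=*/NULL, slot.i_batch - i);
    llama_sampling_accept(slot.ctx_sampling, ctx, id); // records id in prev and advances the grammar, if any

    // top-n probabilities are read back from the candidates the sampler just built
    llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
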
