Commit bbc0b7c

sampling : hide prev behind API and apply #3661
ggml-ci
1 parent 7e2b5fb commit bbc0b7c

File tree

9 files changed, +107 -83 lines changed

common/sampling.cpp

Lines changed: 21 additions & 2 deletions
@@ -66,6 +66,24 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
     dst->prev = src->prev;
 }
 
+llama_token llama_sampling_last(llama_sampling_context * ctx) {
+    return ctx->prev.back();
+}
+
+std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
+    const int size = ctx_sampling->prev.size();
+
+    n = std::min(n, size);
+
+    std::string result;
+
+    for (int i = size - n; i < size; i++) {
+        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
+    }
+
+    return result;
+}
+
 std::string llama_sampling_print(const llama_sampling_params & params) {
     char result[1024];
 
@@ -193,11 +211,12 @@ llama_token llama_sampling_sample(
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        llama_token id) {
+        llama_token id,
+        bool apply_grammar) {
     ctx_sampling->prev.erase(ctx_sampling->prev.begin());
     ctx_sampling->prev.push_back(id);
 
-    if (ctx_sampling->grammar != NULL) {
+    if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
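
For reference, a minimal usage sketch of the two new accessors (illustrative only, not part of the commit; it assumes an already initialized llama_context * ctx and llama_sampling_context * ctx_sampling from common/sampling.h):

// illustrative sketch: read recently sampled tokens through the new accessors
// instead of touching ctx_sampling->prev directly
const llama_token last = llama_sampling_last(ctx_sampling);             // most recent token
const std::string tail = llama_sampling_prev_str(ctx_sampling, ctx, 8); // last 8 tokens as text
LOG("last token: %d, recent text: '%s'\n", last, tail.c_str());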

common/sampling.h

Lines changed: 8 additions & 1 deletion
@@ -70,6 +70,12 @@ void llama_sampling_reset(llama_sampling_context * ctx);
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 
+// Get the last sampled token
+llama_token llama_sampling_last(llama_sampling_context * ctx);
+
+// Get a string representation of the last sampled tokens
+std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
+
 // Print sampling parameters into a string
 std::string llama_sampling_print(const llama_sampling_params & params);
 
@@ -99,4 +105,5 @@ llama_token llama_sampling_sample(
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        llama_token id);
+        llama_token id,
+        bool apply_grammar);
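
A sketch of how the extended llama_sampling_accept signature is meant to be used, mirroring the call sites changed in this commit (the prompt_tokens vector and the initialized ctx / ctx_sampling are assumed to exist):

// prompt tokens: recorded for repetition penalties, but the grammar is not applied
for (const llama_token tok : prompt_tokens) {
    llama_sampling_accept(ctx_sampling, ctx, tok, false);
}

// generated tokens: recorded and also fed to the grammar, if one is active
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);
llama_sampling_accept(ctx_sampling, ctx, id, true);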

examples/CMakeLists.txt

Lines changed: 14 additions & 13 deletions
@@ -12,25 +12,26 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR})
 
 if (EMSCRIPTEN)
 else()
+    add_subdirectory(baby-llama)
+    add_subdirectory(batched)
+    add_subdirectory(batched-bench)
+    add_subdirectory(beam-search)
+    add_subdirectory(benchmark)
+    add_subdirectory(convert-llama2c-to-ggml)
+    add_subdirectory(embedding)
+    add_subdirectory(finetune)
+    add_subdirectory(infill)
+    add_subdirectory(llama-bench)
+    add_subdirectory(llava)
     add_subdirectory(main)
+    add_subdirectory(parallel)
+    add_subdirectory(perplexity)
     add_subdirectory(quantize)
     add_subdirectory(quantize-stats)
-    add_subdirectory(perplexity)
-    add_subdirectory(embedding)
     add_subdirectory(save-load-state)
-    add_subdirectory(benchmark)
-    add_subdirectory(baby-llama)
-    add_subdirectory(train-text-from-scratch)
-    add_subdirectory(finetune)
-    add_subdirectory(convert-llama2c-to-ggml)
     add_subdirectory(simple)
-    add_subdirectory(batched)
-    add_subdirectory(batched-bench)
     add_subdirectory(speculative)
-    add_subdirectory(parallel)
-    add_subdirectory(llava)
-    add_subdirectory(llama-bench)
-    add_subdirectory(beam-search)
+    add_subdirectory(train-text-from-scratch)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()

examples/infill/CMakeLists.txt

Lines changed: 1 addition & 1 deletion (whitespace-only change to the add_dependencies line)

@@ -4,5 +4,5 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
 if(TARGET BUILD_INFO)
-    add_dependencies(${TARGET} BUILD_INFO)
+  add_dependencies(${TARGET} BUILD_INFO)
 endif()

examples/infill/infill.cpp

Lines changed: 11 additions & 8 deletions
@@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id);
+            llama_sampling_accept(ctx_sampling, ctx, id, true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -541,8 +541,11 @@ int main(int argc, char ** argv) {
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+
+                // push the prompt in the sampling context in order to apply repetition penalties later
+                // for the prompt, we don't apply grammar rules
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
         if ((int) embd_inp.size() <= n_consumed) {
 
             // deal with eot token in infill mode
-            if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
                 if(is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -591,7 +594,7 @@ int main(int argc, char ** argv) {
                     buffer += line;
                 } while (another_line);
                 // check if we got an empty line, if so we use the old input
-                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                     params.input_prefix = buffer;
                 }
                 buffer.clear();
@@ -601,7 +604,7 @@ int main(int argc, char ** argv) {
                     buffer += line;
                 } while (another_line);
                 // check if we got an empty line
-                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                     params.input_suffix = buffer;
                 }
                 buffer.clear();
@@ -614,7 +617,7 @@ int main(int argc, char ** argv) {
                     process_escapes(params.input_suffix);
                 }
                 suff_rm_leading_spc = params.escape;
-                if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+                if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                     params.input_suffix.erase(0, 1);
                     suff_rm_leading_spc = false;
                 }
@@ -641,7 +644,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of text token in interactive mode
-            else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {

examples/main/main.cpp

Lines changed: 8 additions & 13 deletions
@@ -611,7 +611,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id);
+            llama_sampling_accept(ctx_sampling, ctx, id, true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -630,12 +630,9 @@ int main(int argc, char ** argv) {
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
 
-                // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
-                // Most likely will remove this in the future to avoid exposing "prev"
-                // Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
-                // penalty will be applied only based on the tokens generated by the model.
-                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+                // push the prompt in the sampling context in order to apply repetition penalties later
+                // for the prompt, we don't apply grammar rules
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
@@ -666,12 +663,10 @@ int main(int argc, char ** argv) {
 
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
+            // check for reverse prompt in the last n_prev tokens
            if (!params.antiprompt.empty()) {
-                std::string last_output;
-                for (auto id : ctx_sampling->prev) {
-                    last_output += llama_token_to_piece(ctx, id);
-                }
+                const int n_prev = 32;
+                const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
 
                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
@@ -698,7 +693,7 @@ int main(int argc, char ** argv) {
             }
 
             // deal with end of text token in interactive mode
-            if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+            if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
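
The "Check if each of the reverse prompts appears at the end of the output" step referenced above is unchanged by this commit. For context, a simplified sketch of that check built on the new helper (illustrative only; the actual loop in main.cpp also handles partial matches for streaming output):

const int n_prev = 32;
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);

is_antiprompt = false;
for (const std::string & antiprompt : params.antiprompt) {
    // does the recently generated text end with this reverse prompt?
    if (last_output.length() >= antiprompt.length() &&
        last_output.compare(last_output.length() - antiprompt.length(), antiprompt.length(), antiprompt) == 0) {
        is_antiprompt = true;
        break;
    }
}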

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
@@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
 
-            llama_sampling_accept(client.ctx_sampling, ctx, id);
+            llama_sampling_accept(client.ctx_sampling, ctx, id, true);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients

examples/server/server.cpp

Lines changed: 41 additions & 42 deletions
@@ -317,10 +317,32 @@ struct llama_server_context
         return true;
     }
 
+    void truncatePrompt(std::vector<llama_token> &prompt_tokens) {
+        const int n_left = n_ctx - params.n_keep;
+        const int n_block_size = n_left / 2;
+        const int erased_blocks = (prompt_tokens.size() - params.n_keep - n_block_size) / n_block_size;
+
+        // Keep n_keep tokens at start of prompt (at most n_ctx - 4)
+        std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+
+        new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_block_size, prompt_tokens.end());
+
+        LOG_VERBOSE("input truncated", {
+            {"n_ctx", n_ctx},
+            {"n_keep", params.n_keep},
+            {"n_left", n_left},
+            {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+            {"num_prompt_tokens", new_tokens.size()}
+        });
+
+        truncated = true;
+        prompt_tokens = new_tokens;
+    }
+
     void loadInfill()
     {
         bool suff_rm_leading_spc = true;
-        if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
             params.input_suffix.erase(0, 1);
             suff_rm_leading_spc = false;
         }
@@ -336,6 +358,7 @@ struct llama_server_context
         prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
         prefix_tokens.push_back(llama_token_middle(ctx));
+
         auto prompt_tokens = prefix_tokens;
 
         num_prompt_tokens = prompt_tokens.size();
@@ -347,31 +370,18 @@ struct llama_server_context
         params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)params.n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
         {
-            printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
-            // todo we probably want to cut from both sides
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
 
-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", params.n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -409,29 +419,18 @@ struct llama_server_context
         params.n_keep = std::min(n_ctx - 4, params.n_keep);
 
         // if input prompt is too big, truncate like normal
-        if (num_prompt_tokens >= (size_t)n_ctx)
+        if (num_prompt_tokens >= (size_t) n_ctx)
        {
-            const int n_left = (n_ctx - params.n_keep) / 2;
-            std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
-            const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
-            new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
+            truncatePrompt(prompt_tokens);
+            num_prompt_tokens = prompt_tokens.size();
 
-            LOG_VERBOSE("input truncated", {
-                {"n_ctx", n_ctx},
-                {"n_keep", params.n_keep},
-                {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
-            });
-
-            truncated = true;
-            prompt_tokens = new_tokens;
+            GGML_ASSERT(num_prompt_tokens < (size_t)n_ctx);
         }
-        else
+
+        // push the prompt into the sampling context (do not apply grammar)
+        for (auto & token : prompt_tokens)
         {
-            const size_t ps = num_prompt_tokens;
-            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
+            llama_sampling_accept(ctx_sampling, ctx, token, false);
        }
 
         // compare the evaluated prompt with the new prompt
@@ -542,7 +541,7 @@ struct llama_server_context
             result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
         }
 
-        llama_sampling_accept(ctx_sampling, ctx, result.tok);
+        llama_sampling_accept(ctx_sampling, ctx, result.tok, true);
 
         if (tg) {
            num_tokens_predicted++;
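
To illustrate the arithmetic in the new truncatePrompt() helper, a self-contained example with made-up numbers (n_ctx, n_keep and the prompt length here are assumptions, not values from the commit):

#include <cassert>
#include <cstdio>

int main() {
    const int n_ctx    = 2048;   // hypothetical context size
    const int n_keep   = 4;      // tokens kept verbatim at the start of the prompt
    const int n_prompt = 3000;   // hypothetical oversized prompt

    const int n_left        = n_ctx - n_keep;                                    // 2044
    const int n_block_size  = n_left / 2;                                        // 1022
    const int erased_blocks = (n_prompt - n_keep - n_block_size) / n_block_size; // 1

    // kept tokens: the first n_keep plus everything after the erased blocks
    const int n_new = n_keep + (n_prompt - n_keep - erased_blocks * n_block_size); // 1978

    assert(n_new < n_ctx); // mirrors the GGML_ASSERT added after the truncatePrompt() calls
    printf("prompt truncated from %d to %d tokens\n", n_prompt, n_new);
    return 0;
}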

examples/speculative/speculative.cpp

Lines changed: 2 additions & 2 deletions
@@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
         // sample from the target model
         llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
 
-        llama_sampling_accept(ctx_sampling, ctx_tgt, id);
+        llama_sampling_accept(ctx_sampling, ctx_tgt, id, true);
 
         //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
 
@@ -328,7 +328,7 @@ int main(int argc, char ** argv) {
 
             const int s = sa[is];
 
-            llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
+            llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
 
             drafts[s].tokens.push_back(id);
 
0 commit comments
