logit_bias: apply configurable escalating EOG bias at low n_remain #14229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: master
30 changes: 30 additions & 0 deletions common/arg.cpp
@@ -1205,6 +1205,15 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
exit(1); // for other exceptions, we exit with status code 1
}

float &pafter = params.sampling.start_eog_after;
float &premain = params.sampling.start_eog_at_remain;
float const premain0 = premain;
float remain = params.n_predict - pafter;
if (premain < remain)
premain = remain;
if (params.sampling.eog_bias_per_tok)
LOG_INF("%s: n_predict=%d (first of start_eog_at_remain=%0.3g start_eog_after=%0.3g) => (remain=%0.3g) eog-bias-per-tok=%0.3g\n", __func__, (int) params.n_predict,
(double) premain0, (double) pafter, (double)premain, (double) params.sampling.eog_bias_per_tok);
return true;
}
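For illustration, a minimal standalone sketch (not part of the patch) of how the clamping above combines -n, --start-eog-after and --start-eog-at-remain, using hypothetical flag values equivalent to -n 256 --start-eog-after 200 --start-eog-at-remain 32:

#include <algorithm>
#include <cstdio>

int main() {
    const float n_predict           = 256; // -n 256
    const float start_eog_after     = 200; // --start-eog-after 200
    float       start_eog_at_remain = 32;  // --start-eog-at-remain 32

    // same clamp as in the hunk above: whichever threshold is reached first wins
    start_eog_at_remain = std::max(start_eog_at_remain, n_predict - start_eog_after);

    // prints 56: the bias kicks in once 56 tokens remain, i.e. after 200 generated
    std::printf("effective start_eog_at_remain = %g\n", start_eog_at_remain);
    return 0;
}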

@@ -1937,6 +1946,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_sparam());
add_opt(common_arg(
{"-eog", "--eog-bias-per-tok"}, "N",
string_format("when fewer than -start-eog-at-remain tokens are left to generate after -n, add this bias eog for each subsequent token (default: %.1f)", (double)params.sampling.eog_bias_per_tok),
[](common_params & params, const std::string & value) {
params.sampling.eog_bias_per_tok = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"-remain", "--start-eog-at-remain"}, "N",
string_format("start applying -eog bias when this many tokens remain of the -n max (default: %.1f)", (double)params.sampling.start_eog_at_remain),
[](common_params & params, const std::string & value) {
params.sampling.start_eog_at_remain = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"-after", "--start-eog-after"}, "N",
string_format("start applying -eog bias after this many tokens generated (default: %.1f); whichever happens first between -remain and -after applies", (double)params.sampling.start_eog_after),
[](common_params & params, const std::string & value) {
params.sampling.start_eog_after = std::stof(value);
}
).set_sparam());
add_opt(common_arg(
{"--grammar"}, "GRAMMAR",
string_format("BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '%s')", params.sampling.grammar.c_str()),
7 changes: 7 additions & 0 deletions common/common.h
@@ -178,6 +178,13 @@ struct common_params_sampling {

std::vector<llama_logit_bias> logit_bias; // logit biases to apply

float eog_bias_per_tok = 0; // escalating bias added to the EOG token for each token generated past the threshold:
// ... once this many tokens remain of the -n budget ...
float start_eog_at_remain = 0;
// ... or (whichever comes first) once this many tokens have been generated
// (i.e. EOG logit bias = max(0, start_eog_at_remain - n_remain) * eog_bias_per_tok,
//  after start_eog_at_remain has been raised to at least n_predict - start_eog_after)
float start_eog_after = 1e9;

// print the parameters into a string
std::string print() const;
};
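Reading the fields above together, a small sketch (assumed semantics, not part of the patch) of the bias a logit-bias sampler would add to the EOG token for a given remaining-token count:

// Hypothetical helper illustrating the intended formula; assumes
// start_eog_at_remain has already been clamped as in common/arg.cpp above.
static float eog_escalating_bias(const common_params_sampling & sparams, float n_remain) {
    const float tokens_past_threshold = sparams.start_eog_at_remain - n_remain;
    return tokens_past_threshold > 0.0f ? tokens_past_threshold * sparams.eog_bias_per_tok
                                        : 0.0f; // no bias while enough tokens remain
}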
27 changes: 15 additions & 12 deletions common/sampling.cpp
@@ -226,7 +226,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
params.logit_bias.data(),
params.eog_bias_per_tok,
params.start_eog_at_remain,
vocab));

if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
@@ -335,18 +338,18 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first, float n_remain) {
gsmpl->set_logits(ctx, idx);

auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits

if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(grmr, &cur_p, n_remain);
}

llama_sampler_apply(chain, &cur_p);
llama_sampler_apply(chain, &cur_p, n_remain);

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

@@ -361,7 +364,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

llama_sampler_apply(grmr, &single_token_data_array);
llama_sampler_apply(grmr, &single_token_data_array, n_remain);

const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
@@ -373,23 +376,23 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);

llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
llama_sampler_apply(grmr, &cur_p, n_remain);
llama_sampler_apply(chain, &cur_p, n_remain);

GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

return cur_p.data[cur_p.selected].id;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first, float n_remain) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

std::vector<llama_token> result;
result.reserve(idxs.size());

size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);

common_sampler_accept(gsmpl, id, true);

@@ -401,7 +404,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
}

if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first, n_remain);

common_sampler_accept(gsmpl, id, true);

@@ -411,13 +414,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first, float n_remain) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}

return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first, n_remain);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
6 changes: 3 additions & 3 deletions common/sampling.h
@@ -58,7 +58,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false, float n_remain = 0);

// generalized version of common_sampler_sample
//
@@ -76,10 +76,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false, float n_remain = 0);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
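As a usage sketch (assumed caller, not from this patch; gsmpl, ctx, vocab and params are taken as existing variables), a generation loop would now thread the remaining-token budget through the updated sampling call, mirroring the example changes below:

int n_remain = params.n_predict;
while (n_remain > 0) {
    const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1, /*grammar_first =*/ false, (float) --n_remain);
    common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);
    if (llama_vocab_is_eog(vocab, id)) {
        break; // the escalating bias makes this increasingly likely near the end of the budget
    }
}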

4 changes: 2 additions & 2 deletions common/speculative.cpp
@@ -238,12 +238,12 @@ llama_tokens common_speculative_gen_draft(
llama_decode(ctx, batch);

common_sampler_reset(smpl);

int n_remain = params.n_draft;
// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);

common_sampler_sample(smpl, ctx, 0, true);
common_sampler_sample(smpl, ctx, 0, true, --n_remain);

const auto * cur_p = common_sampler_get_candidates(smpl);

4 changes: 3 additions & 1 deletion examples/batched/batched.cpp
@@ -162,7 +162,9 @@ int main(int argc, char ** argv) {

const auto t_main_start = ggml_time_us();

int n_remain = n_predict;
while (n_cur <= n_predict) {
--n_remain;
// prepare the next batch
common_batch_clear(batch);

@@ -173,7 +175,7 @@ int main(int argc, char ** argv) {
continue;
}

const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i], n_remain);

// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
4 changes: 2 additions & 2 deletions examples/gritlm/gritlm.cpp
@@ -108,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0;

int n_remain = 32;
while (true) {
common_batch_clear(bat);
{
@@ -122,7 +122,7 @@

llama_decode(ctx, bat);

llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1, --n_remain);

if (token == eos_token) {
break;
7 changes: 5 additions & 2 deletions examples/lookahead/lookahead.cpp
@@ -253,6 +253,7 @@ int main(int argc, char ** argv) {

int seq_id_best = 0;

int n_remain = N;
for (int v = 0; v < N; ++v) {
int i_batch = 0;

@@ -274,8 +275,9 @@
}
}

--n_remain;
// sample the next token
id = common_sampler_sample(smpl, ctx, i_batch);
id = common_sampler_sample(smpl, ctx, i_batch, n_remain);

common_sampler_accept(smpl, id, true);

@@ -349,10 +351,11 @@ int main(int argc, char ** argv) {
tokens_j[j] = tokens_j[j + 1];
}

unsigned constexpr NA = (unsigned)-1; // effectively "no limit": the escalating EOG bias never activates for these samples
if (v == 0) {
// sample from the last level
for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i, NA);
}
} else {
for (int i = 0; i < W; i++) {
3 changes: 2 additions & 1 deletion examples/lookup/lookup.cpp
@@ -117,7 +117,8 @@ int main(int argc, char ** argv){
int i_dft = 0;
while (true) {
// sample from the target model
llama_token id = common_sampler_sample(smpl, ctx, i_dft);
unsigned const n_remain = params.n_predict - n_predict;
llama_token id = common_sampler_sample(smpl, ctx, i_dft, n_remain);

common_sampler_accept(smpl, id, true);

4 changes: 3 additions & 1 deletion examples/passkey/passkey.cpp
@@ -217,10 +217,12 @@ int main(int argc, char ** argv) {

const auto t_main_start = ggml_time_us();

int n_remain = n_len - n_cur;
while (n_cur <= n_len) {
--n_remain;
// sample the next token
{
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1, n_remain);

// is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
12 changes: 9 additions & 3 deletions examples/save-load-state/save-load-state.cpp
@@ -76,8 +76,10 @@ int main(int argc, char ** argv) {
// first run
printf("\nfirst run: %s", params.prompt.c_str());

int n_remain = params.n_predict;
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
--n_remain;
auto next_token = llama_sampler_sample(smpl, ctx, -1, n_remain);
auto next_token_str = common_token_to_piece(ctx, next_token);

printf("%s", next_token_str.c_str());
@@ -128,8 +130,10 @@ int main(int argc, char ** argv) {
n_past = n_past_saved;

// second run
n_remain = params.n_predict;
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
--n_remain;
auto next_token = llama_sampler_sample(smpl2, ctx2, -1, n_remain);
auto next_token_str = common_token_to_piece(ctx2, next_token);

printf("%s", next_token_str.c_str());
@@ -209,8 +213,10 @@ int main(int argc, char ** argv) {
}

// third run with seq 1 instead of 0
n_remain = params.n_predict;
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
--n_remain;
auto next_token = llama_sampler_sample(smpl3, ctx3, -1, n_remain);
auto next_token_str = common_token_to_piece(ctx3, next_token);

printf("%s", next_token_str.c_str());
4 changes: 3 additions & 1 deletion examples/simple-chat/simple-chat.cpp
@@ -110,7 +110,9 @@ int main(int argc, char ** argv) {
// prepare a batch for the prompt
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
llama_token new_token_id;
int n_remain = batch.n_tokens;
while (true) {
--n_remain;
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0);
@@ -125,7 +127,7 @@
}

// sample the next token
new_token_id = llama_sampler_sample(smpl, ctx, -1);
new_token_id = llama_sampler_sample(smpl, ctx, -1, n_remain);

// is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id)) {
4 changes: 3 additions & 1 deletion examples/simple/simple.cpp
@@ -151,6 +151,8 @@ int main(int argc, char ** argv) {
int n_decode = 0;
llama_token new_token_id;

int n_remain = n_predict;

for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) {
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
@@ -162,7 +164,7 @@

// sample the next token
{
new_token_id = llama_sampler_sample(smpl, ctx, -1);
new_token_id = llama_sampler_sample(smpl, ctx, -1, --n_remain);

// is it an end of generation?
if (llama_vocab_is_eog(vocab, new_token_id)) {
11 changes: 7 additions & 4 deletions include/llama.h
@@ -1197,7 +1197,7 @@ extern "C" {
struct llama_sampler_i {
const char * (*name) (const struct llama_sampler * smpl); // can be NULL
void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL
void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required
void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain); // required
void (*reset) ( struct llama_sampler * smpl); // can be NULL
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
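Note for third-party samplers (a hedged sketch, not part of the patch): anything implementing llama_sampler_i must add the new n_remain parameter to its apply callback, even if it ignores it:

static void my_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain) {
    (void) smpl;
    (void) n_remain; // this example sampler does not use the remaining-token budget
    for (size_t i = 0; i < cur_p->size; ++i) {
        cur_p->data[i].logit *= 0.5f; // e.g. flatten the distribution (equivalent to temperature 2.0)
    }
}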
@@ -1215,7 +1215,7 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p, float n_remain);
LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
@@ -1346,7 +1346,10 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
const llama_logit_bias * logit_bias,
float eog_bias_per_tok,
float start_eog_at_remain,
const struct llama_vocab *vocab);
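A construction sketch for the extended initializer (hypothetical values, mirroring the call in common/sampling.cpp above):

struct llama_sampler * smpl = llama_sampler_init_logit_bias(
    llama_vocab_n_tokens(vocab),
    /*n_logit_bias =*/ 0,
    /*logit_bias   =*/ NULL,
    /*eog_bias_per_tok    =*/ 0.5f,  // hypothetical: +0.5 logit per token past the threshold
    /*start_eog_at_remain =*/ 32.0f, // hypothetical: start once 32 tokens remain
    vocab);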

// this sampler is meant to be used for fill-in-the-middle infilling
// it's supposed to be used after top_k + top_p sampling
@@ -1384,7 +1387,7 @@ extern "C" {
// llama_sampler_accept(smpl, token);
// return token;
// Returns the sampled token
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx, float n_remain);

// TODO: extend in the future
//LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...);