Commit 9d44c4a

infill : remove cfg support
1 parent d89006b commit 9d44c4a
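Here "cfg" refers to the example's classifier-free guidance path: when sparams.cfg_scale > 1.f, infill.cpp created a second llama_context for the negative prompt and handed it to the sampler. A condensed view of what is being deleted, using only identifiers that appear in the removed lines of the diff below (this is a fragment, not a standalone program; model, params, sparams, ctx and ctx_sampling come from the surrounding example):

// second context over the same model, used only for guidance
llama_context * ctx_guidance = NULL;
if (sparams.cfg_scale > 1.f) {
    struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
    ctx_guidance = llama_new_context_with_model(model, lparams);
}

// the negative prompt was tokenized and decoded on ctx_guidance each iteration,
// and the sampler combined the two contexts:
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);

// teardown that disappears with this commit:
if (ctx_guidance) { llama_free(ctx_guidance); }

After the change the sampler is always called with a null guidance context, and all of the guidance bookkeeping (guidance_inp, guidance_offset, n_past_guidance, embd_guidance) goes away.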

1 file changed: examples/infill/infill.cpp (+8, −90)


examples/infill/infill.cpp

Lines changed: 8 additions & 90 deletions
@@ -173,17 +173,13 @@ int main(int argc, char ** argv) {
 
     llama_model * model;
     llama_context * ctx;
-    llama_context * ctx_guidance = NULL;
+
     g_model = &model;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    if (sparams.cfg_scale > 1.f) {
-        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
-        ctx_guidance = llama_new_context_with_model(model, lparams);
-    }
 
     if (model == NULL) {
         LOG_TEE("%s: error: unable to load model\n", __func__);
@@ -239,25 +235,6 @@ int main(int argc, char ** argv) {
         LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
-    // Tokenize negative prompt
-    std::vector<llama_token> guidance_inp;
-    int guidance_offset = 0;
-    int original_prompt_len = 0;
-    if (ctx_guidance) {
-        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
-
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
-        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
-
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
-        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-
-        original_prompt_len = original_inp.size();
-        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
-        LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
-        LOG("guidance_offset: %s", log_tostr(guidance_offset));
-    }
-
     if ((int) embd_inp.size() > n_ctx - 4) {
         LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
@@ -285,15 +262,6 @@ int main(int argc, char ** argv) {
             LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
 
-        if (ctx_guidance) {
-            LOG_TEE("\n");
-            LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
-            LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
-            for (int i = 0; i < (int) guidance_inp.size(); i++) {
-                LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
-            }
-        }
-
         if (params.n_keep > 0) {
             LOG_TEE("%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
@@ -361,12 +329,11 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }
 
-    bool input_echo = true;
+    bool input_echo = true;
 
-    int n_past = 0;
-    int n_remain = params.n_predict;
-    int n_consumed = 0;
-    int n_past_guidance = 0;
+    int n_past = 0;
+    int n_remain = params.n_predict;
+    int n_consumed = 0;
 
     std::vector<int> input_tokens; g_input_tokens = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -376,7 +343,6 @@ int main(int argc, char ** argv) {
     console::set_display(console::prompt);
 
     std::vector<llama_token> embd;
-    std::vector<llama_token> embd_guidance;
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 
@@ -402,7 +368,7 @@ int main(int argc, char ** argv) {
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+            if (n_past + (int) embd.size() > n_ctx) {
                 if (params.n_predict == -2) {
                     LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                     break;
@@ -419,57 +385,14 @@ int main(int argc, char ** argv) {
 
                 n_past -= n_discard;
 
-                if (ctx_guidance) {
-                    n_past_guidance -= n_discard;
-                }
-
-                LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+                LOG("after swap: n_past = %d\n", n_past);
 
                 LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
             }
 
             // evaluate tokens in batches
             // embd is typically prepared beforehand to fit within a batch, but not always
-
-            if (ctx_guidance) {
-                int input_size = 0;
-                llama_token * input_buf = NULL;
-
-                if (n_past_guidance < (int) guidance_inp.size()) {
-                    // Guidance context should have the same data with these modifications:
-                    //
-                    // * Replace the initial prompt
-                    // * Shift everything by guidance_offset
-                    embd_guidance = guidance_inp;
-                    if (embd.begin() + original_prompt_len < embd.end()) {
-                        embd_guidance.insert(
-                            embd_guidance.end(),
-                            embd.begin() + original_prompt_len,
-                            embd.end()
-                        );
-                    }
-
-                    input_buf = embd_guidance.data();
-                    input_size = embd_guidance.size();
-
-                    LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
-                } else {
-                    input_buf = embd.data();
-                    input_size = embd.size();
-                }
-
-                for (int i = 0; i < input_size; i += params.n_batch) {
-                    int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
-                        LOG_TEE("%s : failed to eval\n", __func__);
-                        return 1;
-                    }
-
-                    n_past_guidance += n_eval;
-                }
-            }
-
             for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                 int n_eval = (int) embd.size() - i;
                 if (n_eval > params.n_batch) {
@@ -491,11 +414,9 @@ int main(int argc, char ** argv) {
         }
 
         embd.clear();
-        embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -549,7 +470,6 @@ int main(int argc, char ** argv) {
 
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-
            // deal with eot token in infill mode
            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
                if (is_interacting && !params.interactive_first) {
@@ -610,7 +530,6 @@ int main(int argc, char ** argv) {
                 embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                 embd_inp.push_back(llama_token_middle(model));
                 embd.clear();
-                embd_guidance.clear();
                 n_remain = params.n_predict;
                 n_past = 0;
                 n_consumed = 0;
@@ -717,7 +636,6 @@ int main(int argc, char ** argv) {
     llama_print_timings(ctx);
    write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
-    if (ctx_guidance) { llama_free(ctx_guidance); }
     llama_free(ctx);
     llama_free_model(model);
 
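With the guidance context gone, the work left in the generation loop is: decode the pending tokens on the single context in batches of params.n_batch, then sample with a null guidance argument. A minimal sketch condensing the retained lines above (the helper name decode_and_sample and its signature are mine, not part of the example; error handling, context swapping and interactive input are left out):

#include <vector>

#include "common.h"
#include "llama.h"

// Simplified decode + sample step after CFG removal (sketch, not the example's code).
static bool decode_and_sample(llama_context * ctx,
                              llama_sampling_context * ctx_sampling,
                              const gpt_params & params,
                              std::vector<llama_token> & embd,
                              int & n_past,
                              llama_token & id) {
    // evaluate pending tokens on the single remaining context, n_batch at a time
    for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
        int n_eval = (int) embd.size() - i;
        if (n_eval > params.n_batch) {
            n_eval = params.n_batch;
        }
        if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
            return false; // failed to eval
        }
        n_past += n_eval;
    }
    embd.clear();

    // the guidance context argument is now always null
    id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
    llama_sampling_accept(ctx_sampling, ctx, id, true);
    return true;
}

In infill.cpp itself this logic stays inline in the main loop; the net effect of the commit is that a single llama_decode path and a nullptr third argument to llama_sampling_sample replace the dual-context bookkeeping.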