Implement classifier-free guidance #2135
examples/main/main.cpp
@@ -109,10 +109,16 @@ int main(int argc, char ** argv) {

    llama_model * model;
    llama_context * ctx;
    llama_context * guidance_ctx = NULL;

    g_ctx = &ctx;

    // load the model and apply lora adapter, if any
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    if (params.cfg_scale > 1.f) {
        struct llama_context_params lparams = llama_get_context_params_from_gpt_params(params);
        guidance_ctx = llama_new_context_with_model(model, lparams);
    }

    if (model == NULL) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;

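The helper called above packages the command-line gpt_params into a llama_context_params so the guidance context can be created with the same settings as the main one. Its body is not part of this excerpt; the sketch below is only a rough guess at what such a helper could look like, with field names assumed from the common.h/llama.h of that period rather than taken from the PR.

// Hypothetical sketch only: copy a few gpt_params fields onto
// llama_context_params so a second context mirrors the first one.
// The real helper lives in the examples' common code and may copy more fields.
#include "common.h"   // gpt_params (examples)
#include "llama.h"    // llama_context_params, llama_context_default_params

struct llama_context_params llama_get_context_params_from_gpt_params(const gpt_params & params) {
    struct llama_context_params lparams = llama_context_default_params();

    lparams.n_ctx     = params.n_ctx;       // context length
    lparams.seed      = params.seed;        // RNG seed
    lparams.f16_kv    = params.memory_f16;  // KV cache precision
    lparams.use_mmap  = params.use_mmap;
    lparams.use_mlock = params.use_mlock;

    return lparams;
}
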
@@ -183,15 +189,28 @@ int main(int argc, char ** argv) {
    // tokenize the prompt
    std::vector<llama_token> embd_inp;

-   if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
-       // Add a space in front of the first character to match OG llama tokenizer behavior
-       params.prompt.insert(0, 1, ' ');
+   // Add a space in front of the first character to match OG llama tokenizer behavior
+   params.prompt.insert(0, 1, ' ');
+
+   if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
        embd_inp = ::llama_tokenize(ctx, params.prompt, true);
    } else {
        embd_inp = session_tokens;
    }

    // Tokenize negative prompt
    std::vector<llama_token> guidance_inp;
    int guidance_offset = 0;
    int original_prompt_len = 0;
    if (guidance_ctx) {
        params.cfg_negative_prompt.insert(0, 1, ' ');
        guidance_inp = ::llama_tokenize(guidance_ctx, params.cfg_negative_prompt, true);

        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
        original_prompt_len = original_inp.size();
        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
    }

    const int n_ctx = llama_n_ctx(ctx);

    if ((int) embd_inp.size() > n_ctx - 4) {

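The guidance_offset bookkeeping exists because the negative prompt usually tokenizes to a different length than the real prompt, so the two contexts end up at different positions once generation starts. The generation-loop changes are not shown in this excerpt; the sketch below only illustrates the general idea, assuming the llama_eval(ctx, tokens, n_tokens, n_past, n_threads) API of that era, and is not the PR's actual code.

// Illustration only: feed each newly sampled token to both contexts.
// The invariant maintained by the caller would be that, once both prompts
// are consumed, n_past_guidance stays ahead of n_past by guidance_offset.
#include "llama.h"

static void eval_token_in_both(
        llama_context * ctx,
        llama_context * guidance_ctx,
        llama_token     tok,
        int           & n_past,
        int           & n_past_guidance,
        int             n_threads) {
    llama_eval(ctx, &tok, 1, n_past, n_threads);
    n_past += 1;

    if (guidance_ctx) {
        // same token, evaluated at the position shifted by the negative prompt
        llama_eval(guidance_ctx, &tok, 1, n_past_guidance, n_threads);
        n_past_guidance += 1;
    }
}
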
@@ -258,6 +277,16 @@ int main(int argc, char ** argv) {
        for (int i = 0; i < (int) embd_inp.size(); i++) {
            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
        }

        if (guidance_ctx) {
            fprintf(stderr, "\n");
            fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
            fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
            for (int i = 0; i < (int) guidance_inp.size(); i++) {
                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
            }
        }

        if (params.n_keep > 0) {
            fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
            for (int i = 0; i < params.n_keep; i++) {

@@ -334,11 +363,13 @@ int main(int argc, char ** argv) {
    int n_remain = params.n_predict;
    int n_consumed = 0;
    int n_session_consumed = 0;
    int guidance_n_past = 0;

Suggested change (outdated):
-    int guidance_n_past = 0;
+    int n_past_guidance = 0;
Suggested change:
-    std::vector<llama_token> guidance_embd;
+    std::vector<llama_token> embd_guidance;
llama.cpp
@@ -2141,6 +2141,75 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
    }
}

template<typename T, typename LogitAccessor>
void llama_log_softmax(T * array, int size, LogitAccessor logit_accessor) {
    T* element = std::max_element(
        array, array + size,
        [&logit_accessor](T& lhs, T& rhs) {
            return logit_accessor(lhs) < logit_accessor(rhs);
        }
    );

    float max_l = logit_accessor(*element);
    float sum = 0.f;
    for (int i = 0; i < size; ++i) {
        float& logit = logit_accessor(array[i]);
        float p = expf(logit - max_l);
        sum += p;
        logit = p;
    }

    for (int i = 0; i < size; ++i) {
        float& logit = logit_accessor(array[i]);
        logit = logf(logit / sum);
    }
}

void llama_sample_classifier_free_guidance(
        struct llama_context * ctx,
        llama_token_data_array * candidates,
        struct llama_context * guidance_ctx,
        float scale,
        float smooth_factor) {
    int64_t t_start_sample_us = ggml_time_us();

    assert(ctx);
    auto n_vocab = llama_n_vocab(ctx);
    assert(n_vocab == (int)candidates->size);
    assert(!candidates->sorted);

    auto logit_from_token_data = [](llama_token_data& data) -> float& {
        return data.logit;
    };

    auto logit_from_float = [](float& item) -> float& {
        return item;
    };

    llama_log_softmax(candidates->data, candidates->size, logit_from_token_data);

    auto* guidance_logits = llama_get_logits(guidance_ctx);
    llama_log_softmax(guidance_logits, n_vocab, logit_from_float);

    for (int i = 0; i < n_vocab; ++i) {
        float guidance_logit = guidance_logits[i];
        float base_logit = candidates->data[i].logit;

Suggested change:
-        float guidance_logit = guidance_logits[i];
-        float base_logit = candidates->data[i].logit;
+        float logit_guidance = guidance_logits[i];
+        float logit_base = candidates->data[i].logit;
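The excerpt is truncated here and does not show how the two log-distributions are combined. For orientation, the sketch below shows how the rest of this loop plausibly looks given the scale and smooth_factor parameters; it uses the variable names from the suggestion above and the standard classifier-free-guidance formula, and is an assumption rather than the PR's verbatim code.

    // Assumed continuation (not taken from the diff): standard CFG mixes the
    // two log-distributions as guidance + scale * (base - guidance), so
    // scale == 1 reproduces the base distribution and scale > 1 pushes the
    // sampling away from the negative prompt; smooth_factor then blends the
    // guided logits back with the unguided ones.
    for (int i = 0; i < n_vocab; ++i) {
        float logit_guidance = guidance_logits[i];
        float logit_base     = candidates->data[i].logit;

        float logit_cfg = logit_guidance + scale * (logit_base - logit_guidance);

        candidates->data[i].logit = smooth_factor * logit_cfg + (1.f - smooth_factor) * logit_base;
    }

Because the caller guards the whole feature with cfg_scale > 1.f (see the context creation above), this sampling path is only reached when a stronger-than-neutral guidance scale is actually requested.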
I guess llama_context_params_from_gpt_params() should fit better. We tend to use get and set to access properties, while here we construct context_params.