Skip to content

Commit 57fb835

Browse files
committed
cont : no need for special "greedy" logic
top-k == 1 is the same
1 parent cb75beb commit 57fb835

File tree

1 file changed

+37
-50
lines changed

1 file changed

+37
-50
lines changed

common/sampling.cpp

Lines changed: 37 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -171,59 +171,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
171171
params.penalize_nl,
172172
params.ignore_eos));
173173

174-
if (params.temp >= 0.0f) {
175-
if (params.mirostat == 0) {
176-
for (const auto & cnstr : params.samplers) {
177-
switch (cnstr) {
178-
case COMMON_SAMPLER_TYPE_TOP_K:
179-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
180-
break;
181-
case COMMON_SAMPLER_TYPE_TOP_P:
182-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
183-
break;
184-
case COMMON_SAMPLER_TYPE_MIN_P:
185-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
186-
break;
187-
case COMMON_SAMPLER_TYPE_XTC:
188-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
189-
break;
190-
case COMMON_SAMPLER_TYPE_TFS_Z:
191-
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
192-
break;
193-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
194-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
195-
break;
196-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
197-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
198-
break;
199-
case COMMON_SAMPLER_TYPE_INFILL:
200-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
201-
break;
202-
default:
203-
GGML_ASSERT(false && "unknown sampler type");
204-
}
174+
if (params.mirostat == 0) {
175+
for (const auto & cnstr : params.samplers) {
176+
switch (cnstr) {
177+
case COMMON_SAMPLER_TYPE_TOP_K:
178+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
179+
break;
180+
case COMMON_SAMPLER_TYPE_TOP_P:
181+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
182+
break;
183+
case COMMON_SAMPLER_TYPE_MIN_P:
184+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
185+
break;
186+
case COMMON_SAMPLER_TYPE_XTC:
187+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
188+
break;
189+
case COMMON_SAMPLER_TYPE_TFS_Z:
190+
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
191+
break;
192+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
193+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
194+
break;
195+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
196+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
197+
break;
198+
case COMMON_SAMPLER_TYPE_INFILL:
199+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
200+
break;
201+
default:
202+
GGML_ASSERT(false && "unknown sampler type");
205203
}
206-
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
207-
} else if (params.mirostat == 1) {
208-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
209-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
210-
} else if (params.mirostat == 2) {
211-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
212-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
213-
} else {
214-
GGML_ASSERT(false && "unknown mirostat version");
215204
}
205+
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
206+
} else if (params.mirostat == 1) {
207+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
208+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
209+
} else if (params.mirostat == 2) {
210+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
211+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
216212
} else {
217-
// negative temperatures will trigger "greedy" sampling: simply take the most likely token each time
218-
if (params.n_probs > 0) {
219-
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
220-
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
221-
//
222-
// the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
223-
// it is much faster, since we avoid sorting all tokens and should give a good approximation
224-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
225-
}
226-
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
213+
GGML_ASSERT(false && "unknown mirostat version");
227214
}
228215

229216
return result;

0 commit comments

Comments
 (0)