Skip to content

Commit e5150b1

Browse files
committed
cont : no need for special "greedy" logic
top-k == 1 is the same
1 parent fc66b20 commit e5150b1

File tree

1 file changed

+34
-47
lines changed

1 file changed

+34
-47
lines changed

common/sampling.cpp

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -171,56 +171,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
171171
params.penalize_nl,
172172
params.ignore_eos));
173173

174-
if (params.temp >= 0.0f) {
175-
if (params.mirostat == 0) {
176-
for (const auto & cnstr : params.samplers) {
177-
switch (cnstr) {
178-
case COMMON_SAMPLER_TYPE_TOP_K:
179-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
180-
break;
181-
case COMMON_SAMPLER_TYPE_TOP_P:
182-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
183-
break;
184-
case COMMON_SAMPLER_TYPE_MIN_P:
185-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
186-
break;
187-
case COMMON_SAMPLER_TYPE_XTC:
188-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
189-
break;
190-
case COMMON_SAMPLER_TYPE_TFS_Z:
191-
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
192-
break;
193-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
194-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
195-
break;
196-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
197-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
198-
break;
199-
default:
200-
GGML_ASSERT(false && "unknown sampler type");
201-
}
174+
if (params.mirostat == 0) {
175+
for (const auto & cnstr : params.samplers) {
176+
switch (cnstr) {
177+
case COMMON_SAMPLER_TYPE_TOP_K:
178+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
179+
break;
180+
case COMMON_SAMPLER_TYPE_TOP_P:
181+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
182+
break;
183+
case COMMON_SAMPLER_TYPE_MIN_P:
184+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
185+
break;
186+
case COMMON_SAMPLER_TYPE_XTC:
187+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
188+
break;
189+
case COMMON_SAMPLER_TYPE_TFS_Z:
190+
llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
191+
break;
192+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
193+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
194+
break;
195+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
196+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
197+
break;
198+
default:
199+
GGML_ASSERT(false && "unknown sampler type");
202200
}
203-
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
204-
} else if (params.mirostat == 1) {
205-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
206-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
207-
} else if (params.mirostat == 2) {
208-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
209-
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
210-
} else {
211-
GGML_ASSERT(false && "unknown mirostat version");
212201
}
202+
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
203+
} else if (params.mirostat == 1) {
204+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
205+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
206+
} else if (params.mirostat == 2) {
207+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
208+
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
213209
} else {
214-
// negative temperatures will trigger "greedy" sampling: simply take the most likely token each time
215-
if (params.n_probs > 0) {
216-
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
217-
// ref: https://github.com/ggerganov/llama.cpp/pull/9605
218-
//
219-
// the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
220-
// it is much faster, since we avoid sorting all tokens and should give a good approximation
221-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
222-
}
223-
llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
210+
GGML_ASSERT(false && "unknown mirostat version");
224211
}
225212

226213
return result;

0 commit comments

Comments
 (0)