Skip to content

Commit cb75beb

Browse files
committed
sampling : change temperature sampler logic
For t <= 0.0f, keep the max logit intact and set the rest to -inf
1 parent 33a69ec commit cb75beb

File tree

4 files changed

+32
-3
lines changed

4 files changed

+32
-3
lines changed

common/sampling.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
171171
params.penalize_nl,
172172
params.ignore_eos));
173173

174-
if (params.temp > 0.0f) {
174+
if (params.temp >= 0.0f) {
175175
if (params.mirostat == 0) {
176176
for (const auto & cnstr : params.samplers) {
177177
switch (cnstr) {
@@ -214,6 +214,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
214214
GGML_ASSERT(false && "unknown mirostat version");
215215
}
216216
} else {
217+
// negative temperatures will trigger "greedy" sampling: simply take the most likely token each time
217218
if (params.n_probs > 0) {
218219
// some use cases require to sample greedily, but still obtain the probabilities of the top tokens
219220
// ref: https://github.com/ggerganov/llama.cpp/pull/9605

include/llama.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,8 +1082,8 @@ extern "C" {
10821082

10831083
// available samplers:
10841084

1085-
LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void);
1086-
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
1085+
LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
1086+
LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed);
10871087

10881088
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
10891089
/// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
@@ -1104,6 +1104,8 @@ extern "C" {
11041104

11051105
/// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
11061106
LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep);
1107+
1108+
/// #details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at it's original value, the rest are set to -inf
11071109
LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t);
11081110

11091111
/// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.

src/llama-sampling.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,28 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
915915

916916
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
917917
const auto * ctx = (llama_sampler_temp *) smpl->ctx;
918+
919+
if (ctx->temp <= 0.0f) {
920+
// find the token with the highest logit and set the rest to -inf
921+
llama_token max_id = cur_p->data[0].id;
922+
float max_logit = cur_p->data[0].logit;
923+
924+
for (size_t i = 1; i < cur_p->size; ++i) {
925+
if (cur_p->data[i].logit > max_logit) {
926+
max_id = cur_p->data[i].id;
927+
max_logit = cur_p->data[i].logit;
928+
}
929+
}
930+
931+
for (size_t i = 0; i < cur_p->size; ++i) {
932+
if (cur_p->data[i].id != max_id) {
933+
cur_p->data[i].logit = -INFINITY;
934+
}
935+
}
936+
937+
return;
938+
}
939+
918940
for (size_t i = 0; i < cur_p->size; ++i) {
919941
cur_p->data[i].logit /= ctx->temp;
920942
}
@@ -964,6 +986,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
964986
if (ctx->delta > 0) {
965987
const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
966988
const float max_temp = ctx->temp + ctx->delta;
989+
967990
float exponent_val = ctx->exponent;
968991

969992
// no need to do anything if there is only one (or zero) candidates

tests/test-sampling.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,9 @@ static void test_perf() {
274274
int main(void) {
275275
ggml_time_init();
276276

277+
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f);
278+
test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f);
279+
277280
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1);
278281
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3);
279282
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);

0 commit comments

Comments
 (0)