Skip to content

Commit 77998ae

Browse files
committed
sampling : option to use internal set of candidates
ggml-ci
1 parent 9dd2061 commit 77998ae

File tree

8 files changed

+82
-25
lines changed

8 files changed

+82
-25
lines changed

common/sampling.cpp

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -199,9 +199,25 @@ llama_token llama_sampling_sample(
199199
int idx) {
200200
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
201201

202-
auto * cur_p = llama_sampling_get_candidates(smpl);
202+
// first, sample the token without any grammar constraints
203+
auto id = llama_sampling_sample(smpl, nullptr);
203204

204-
llama_sampling_grammar(smpl, cur_p);
205+
// create an array with a single token data element for the sampled id
206+
llama_token_data single_token_data = {id, 1.0f, 0.0f};
207+
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
205208

206-
return llama_sampling_sample(smpl, cur_p);
209+
llama_sampling_grammar(smpl, &single_token_data_array);
210+
211+
// check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
212+
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
213+
if (is_valid) {
214+
return id;
215+
}
216+
217+
// if the token is not valid, sample again, after applying the grammar constraints
218+
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));
219+
220+
llama_sampling_grammar(smpl, nullptr);
221+
222+
return llama_sampling_sample(smpl, nullptr);
207223
}

common/sampling.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,4 +67,4 @@ std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::
6767
llama_token llama_sampling_sample(
6868
struct llama_sampling * smpl,
6969
struct llama_context * ctx,
70-
int idx = -1);
70+
int idx);

examples/batched.swift/Sources/main.swift

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -136,28 +136,17 @@ while n_cur <= n_len {
136136
continue
137137
}
138138

139-
var n_vocab = llama_n_vocab(model)
140139
var logits = llama_get_logits_ith(context, i_batch[i])
141140

142-
var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
141+
llama_sampling_set_logits(smpl, logits)
143142

144-
for token_id in 0 ..< n_vocab {
145-
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
146-
}
147-
148-
var candidates_p: llama_token_data_array = .init(
149-
data: &candidates,
150-
size: candidates.count,
151-
sorted: false
152-
)
153-
154-
llama_sampling_top_k(smpl, &candidates_p)
155-
llama_sampling_top_p(smpl, &candidates_p)
156-
llama_sampling_temp (smpl, &candidates_p)
143+
llama_sampling_top_k(smpl, nil)
144+
llama_sampling_top_p(smpl, nil)
145+
llama_sampling_temp (smpl, nil)
157146

158-
let new_token_id = llama_sampling_sample_dist(smpl, &candidates_p)
147+
let new_token_id = llama_sampling_sample_dist(smpl, nil)
159148

160-
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
149+
// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);
161150

162151
// is it an end of stream? -> mark the stream as finished
163152
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {

examples/infill/infill.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -417,7 +417,7 @@ int main(int argc, char ** argv) {
417417
embd.clear();
418418

419419
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
420-
const llama_token id = llama_sampling_sample(smpl, ctx);
420+
const llama_token id = llama_sampling_sample(smpl, ctx, -1);
421421

422422
llama_sampling_accept(smpl, id, true);
423423

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
4343
static const char * sample(struct llama_sampling * smpl,
4444
struct llama_context * ctx_llama,
4545
int * n_past) {
46-
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
46+
const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
4747
llama_sampling_accept(smpl, id, true);
4848
static std::string ret;
4949
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/llava/minicpmv-cli.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -166,7 +166,7 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
166166
static const char * sample(struct llama_sampling * smpl,
167167
struct llama_context * ctx_llama,
168168
int * n_past) {
169-
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
169+
const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1);
170170
llama_sampling_accept(smpl, id, true);
171171
static std::string ret;
172172
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -650,7 +650,7 @@ int main(int argc, char ** argv) {
650650
LOG("saved session to %s\n", path_session.c_str());
651651
}
652652

653-
const llama_token id = llama_sampling_sample(smpl, ctx);
653+
const llama_token id = llama_sampling_sample(smpl, ctx, -1);
654654

655655
llama_sampling_accept(smpl, id, /* apply_grammar= */ true);
656656

src/llama.cpp

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20139,42 +20139,70 @@ llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * s
2013920139
void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2014020140
time_meas tm(smpl->t_sample_us);
2014120141

20142+
if (candidates == nullptr) {
20143+
candidates = &smpl->cur_p;
20144+
}
20145+
2014220146
llama_sampling_softmax_impl(candidates);
2014320147
}
2014420148

2014520149
void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2014620150
time_meas tm(smpl->t_sample_us);
2014720151

20152+
if (candidates == nullptr) {
20153+
candidates = &smpl->cur_p;
20154+
}
20155+
2014820156
llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
2014920157
}
2015020158

2015120159
void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2015220160
time_meas tm(smpl->t_sample_us);
2015320161

20162+
if (candidates == nullptr) {
20163+
candidates = &smpl->cur_p;
20164+
}
20165+
2015420166
llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
2015520167
}
2015620168

2015720169
void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2015820170
time_meas tm(smpl->t_sample_us);
2015920171

20172+
if (candidates == nullptr) {
20173+
candidates = &smpl->cur_p;
20174+
}
20175+
2016020176
llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
2016120177
}
2016220178

2016320179
void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2016420180
time_meas tm(smpl->t_sample_us);
2016520181

20182+
if (candidates == nullptr) {
20183+
candidates = &smpl->cur_p;
20184+
}
20185+
2016620186
llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
2016720187
}
2016820188

2016920189
void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2017020190
time_meas tm(smpl->t_sample_us);
2017120191

20192+
if (candidates == nullptr) {
20193+
candidates = &smpl->cur_p;
20194+
}
20195+
2017220196
llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep);
2017320197
}
2017420198

2017520199
void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2017620200
time_meas tm(smpl->t_sample_us);
2017720201

20202+
if (candidates == nullptr) {
20203+
candidates = &smpl->cur_p;
20204+
}
20205+
2017820206
if (smpl->params.dynatemp_range > 0) {
2017920207
const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
2018020208
const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range);
@@ -20188,6 +20216,10 @@ void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array *
2018820216
void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2018920217
time_meas tm(smpl->t_grammar_us);
2019020218

20219+
if (candidates == nullptr) {
20220+
candidates = &smpl->cur_p;
20221+
}
20222+
2019120223
if (smpl->grammar) {
2019220224
llama_sampling_grammar_impl(candidates, *smpl->grammar);
2019320225

@@ -20200,6 +20232,10 @@ void llama_sampling_penalties(
2020020232
llama_token_data_array * candidates) {
2020120233
time_meas tm(smpl->t_sample_us);
2020220234

20235+
if (candidates == nullptr) {
20236+
candidates = &smpl->cur_p;
20237+
}
20238+
2020320239
const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());
2020420240

2020520241
const float penalty_repeat = smpl->params.penalty_repeat;
@@ -20224,6 +20260,10 @@ void llama_sampling_penalties(
2022420260
llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2022520261
time_meas tm(smpl->t_sample_us);
2022620262

20263+
if (candidates == nullptr) {
20264+
candidates = &smpl->cur_p;
20265+
}
20266+
2022720267
const auto type = smpl->params.mirostat;
2022820268

2022920269
llama_token res;
@@ -20254,6 +20294,10 @@ llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_t
2025420294
llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2025520295
time_meas tm(smpl->t_sample_us);
2025620296

20297+
if (candidates == nullptr) {
20298+
candidates = &smpl->cur_p;
20299+
}
20300+
2025720301
auto res = llama_sampling_sample_greedy_impl(candidates);
2025820302

2025920303
smpl->n_sample++;
@@ -20264,6 +20308,10 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok
2026420308
llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2026520309
time_meas tm(smpl->t_sample_us);
2026620310

20311+
if (candidates == nullptr) {
20312+
candidates = &smpl->cur_p;
20313+
}
20314+
2026720315
auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng);
2026820316

2026920317
smpl->n_sample++;
@@ -20274,6 +20322,10 @@ llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token
2027420322
llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
2027520323
time_meas tm(smpl->t_sample_us);
2027620324

20325+
if (candidates == nullptr) {
20326+
candidates = &smpl->cur_p;
20327+
}
20328+
2027720329
const auto & params = smpl->params;
2027820330

2027920331
const float temp = params.temp;

0 commit comments

Comments (0)