Skip to content

Commit dbf8544

Browse files
committed
llama : use struct llama_sampling in the sampling API
ggml-ci
1 parent f866cb9 commit dbf8544

File tree

30 files changed

+438
-396
lines changed

30 files changed

+438
-396
lines changed

common/common.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2125,7 +2125,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
21252125
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
21262126
llama_kv_cache_clear(lctx);
21272127
llama_synchronize(lctx);
2128-
llama_reset_timings(lctx);
2128+
llama_reset_timings(lctx, nullptr, nullptr);
21292129
}
21302130

21312131
return std::make_tuple(model, lctx);

common/sampling.cpp

+32-33
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22

33
#include <random>
44

5-
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) {
5+
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
66
struct llama_sampling_context * result = new llama_sampling_context();
77

88
result->params = params;
9-
result->seq_id = seq_id;
10-
result->ctx = ctx;
9+
result->smpl = smpl;
1110
result->grammar = nullptr;
1211

1312
// if there is a grammar, parse it
@@ -43,7 +42,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
4342

4443
result->n_valid = 0;
4544

46-
llama_sampling_set_rng_seed(result, params.seed);
45+
llama_sampling_set_rng_seed(result->smpl, params.seed);
4746

4847
return result;
4948
}
@@ -79,13 +78,6 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
7978
ctx->n_valid = 0;
8079
}
8180

82-
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
83-
if (seed == LLAMA_DEFAULT_SEED) {
84-
seed = std::random_device{}();
85-
}
86-
llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id);
87-
}
88-
8981
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
9082
if (dst->grammar) {
9183
llama_grammar_free(dst->grammar);
@@ -230,10 +222,13 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
230222

231223
// no reasons to expose this function in header
232224
static void sampler_queue(
233-
struct llama_context * ctx_main,
234-
const llama_sampling_params & params,
225+
struct llama_sampling_context * ctx_sampling,
235226
llama_token_data_array & cur_p,
236227
size_t min_keep) {
228+
llama_sampling * smpl = ctx_sampling->smpl;
229+
230+
const llama_sampling_params & params = ctx_sampling->params;
231+
237232
const float temp = params.temp;
238233
const float dynatemp_range = params.dynatemp_range;
239234
const float dynatemp_exponent = params.dynatemp_exponent;
@@ -246,18 +241,18 @@ static void sampler_queue(
246241

247242
for (auto sampler_type : samplers_sequence) {
248243
switch (sampler_type) {
249-
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
250-
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
251-
case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
252-
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
253-
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
244+
case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break;
245+
case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break;
246+
case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break;
247+
case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break;
248+
case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break;
254249
case llama_sampler_type::TEMPERATURE:
255250
if (dynatemp_range > 0) {
256251
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
257252
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
258-
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
253+
llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
259254
} else {
260-
llama_sample_temp(ctx_main, &cur_p, temp);
255+
llama_sampling_temp(smpl, &cur_p, temp);
261256
}
262257
break;
263258
default : break;
@@ -271,6 +266,8 @@ static llama_token llama_sampling_sample_impl(
271266
struct llama_context * ctx_cfg,
272267
const int idx,
273268
bool is_resampling) {
269+
llama_sampling * smpl = ctx_sampling->smpl;
270+
274271
const llama_sampling_params & params = ctx_sampling->params;
275272

276273
const float temp = params.temp;
@@ -287,26 +284,26 @@ static llama_token llama_sampling_sample_impl(
287284

288285
if (temp < 0.0) {
289286
// greedy sampling, with probs
290-
llama_sample_softmax(ctx_main, &cur_p);
287+
llama_sampling_softmax(smpl, &cur_p);
291288
id = cur_p.data[0].id;
292289
} else if (temp == 0.0) {
293290
// greedy sampling, no probs
294-
id = llama_sample_token_greedy(ctx_main, &cur_p);
291+
id = llama_sampling_sample_greedy(smpl, &cur_p);
295292
} else {
296293
if (mirostat == 1) {
297294
const int mirostat_m = 100;
298-
llama_sample_temp(ctx_main, &cur_p, temp);
299-
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
295+
llama_sampling_temp(smpl, &cur_p, temp);
296+
id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
300297
} else if (mirostat == 2) {
301-
llama_sample_temp(ctx_main, &cur_p, temp);
302-
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
298+
llama_sampling_temp(smpl, &cur_p, temp);
299+
id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
303300
} else {
304301
// temperature sampling
305302
size_t min_keep = std::max(1, params.min_keep);
306303

307-
sampler_queue(ctx_main, params, cur_p, min_keep);
304+
sampler_queue(ctx_sampling, cur_p, min_keep);
308305

309-
id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id);
306+
id = llama_sampling_sample(smpl, &cur_p);
310307

311308
//{
312309
// const int n_top = 10;
@@ -315,11 +312,11 @@ static llama_token llama_sampling_sample_impl(
315312
// for (int i = 0; i < n_top; i++) {
316313
// const llama_token id = cur_p.data[i].id;
317314
// (void)id; // To avoid a warning that id is unused when logging is disabled.
318-
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
315+
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p);
319316
// }
320317
//}
321318

322-
//LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
319+
//LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str());
323320
}
324321
}
325322

@@ -360,6 +357,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
360357
const int idx,
361358
bool apply_grammar,
362359
std::vector<float> * original_logits) {
360+
llama_sampling * smpl = ctx_sampling->smpl;
361+
363362
const llama_sampling_params & params = ctx_sampling->params;
364363

365364
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -390,7 +389,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
390389

391390
if (ctx_cfg) {
392391
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
393-
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
392+
llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale);
394393
}
395394

396395
cur.resize(n_vocab);
@@ -407,7 +406,7 @@ static llama_token_data_array llama_sampling_prepare_impl(
407406
if (penalty_tokens_used_size) {
408407
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
409408

410-
llama_sample_repetition_penalties(ctx_main, &cur_p,
409+
llama_sampling_repetition_penalties(smpl, &cur_p,
411410
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
412411
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
413412

@@ -445,7 +444,7 @@ llama_token_data_array llama_sampling_prepare(
445444
const int idx,
446445
bool apply_grammar,
447446
std::vector<float> * original_logits) {
448-
return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
447+
return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
449448
}
450449

451450
void llama_sampling_accept(

common/sampling.h

+2-7
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,10 @@ struct llama_sampling_context {
7070
// parameters that will be used for sampling
7171
llama_sampling_params params;
7272

73-
llama_seq_id seq_id;
74-
7573
// mirostat sampler state
7674
float mirostat_mu;
7775

78-
llama_context * ctx; // TMP
76+
llama_sampling * smpl;
7977
llama_grammar * grammar;
8078

8179
// internal
@@ -91,7 +89,7 @@ struct llama_sampling_context {
9189
#include "common.h"
9290

9391
// Create a new sampling context instance.
94-
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id);
92+
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl);
9593

9694
void llama_sampling_free(struct llama_sampling_context * ctx);
9795

@@ -100,9 +98,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
10098
// - reset grammar
10199
void llama_sampling_reset(llama_sampling_context * ctx);
102100

103-
// Set the sampler seed
104-
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
105-
106101
// Copy the sampler context
107102
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
108103

examples/batched-bench/batched-bench.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ int main(int argc, char ** argv) {
200200
}
201201
}
202202

203-
llama_print_timings(ctx);
203+
llama_print_timings(ctx, nullptr, nullptr);
204204

205205
llama_batch_free(batch);
206206

examples/batched/batched.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ int main(int argc, char ** argv) {
6464
ctx_params.n_batch = std::max(n_predict, n_parallel);
6565

6666
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
67+
llama_sampling * smpl = llama_get_sampling(ctx);
6768

6869
if (ctx == NULL) {
6970
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@@ -180,13 +181,13 @@ int main(int argc, char ** argv) {
180181
const float top_p = 0.9f;
181182
const float temp = 0.4f;
182183

183-
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
184-
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
185-
llama_sample_temp (ctx, &candidates_p, temp);
184+
llama_sampling_top_k(smpl, &candidates_p, top_k, 1);
185+
llama_sampling_top_p(smpl, &candidates_p, top_p, 1);
186+
llama_sampling_temp (smpl, &candidates_p, temp);
186187

187-
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
188+
const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p);
188189

189-
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
190+
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
190191

191192
// is it an end of generation? -> mark the stream as finished
192193
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
@@ -244,12 +245,13 @@ int main(int argc, char ** argv) {
244245
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
245246
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
246247

247-
llama_print_timings(ctx);
248+
llama_print_timings(ctx, smpl, nullptr);
248249

249250
fprintf(stderr, "\n");
250251

251252
llama_batch_free(batch);
252253

254+
llama_sampling_free(smpl);
253255
llama_free(ctx);
254256
llama_free_model(model);
255257

examples/embedding/embedding.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ int main(int argc, char ** argv) {
258258
}
259259

260260
// clean up
261-
llama_print_timings(ctx);
261+
llama_print_timings(ctx, nullptr, nullptr);
262262
llama_batch_free(batch);
263263
llama_free(ctx);
264264
llama_free_model(model);

examples/eval-callback/eval-callback.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ int main(int argc, char ** argv) {
182182
return 1;
183183
}
184184

185-
llama_print_timings(ctx);
185+
llama_print_timings(ctx, nullptr, nullptr);
186186

187187
llama_free(ctx);
188188
llama_free_model(model);

examples/gritlm/gritlm.cpp

+15-14
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
1010
std::vector<std::vector<float>> result;
1111

12-
const llama_model * mdl = llama_get_model(ctx);
12+
const llama_model * model = llama_get_model(ctx);
1313

1414
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
1515

@@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
1818

1919
const std::string input_string = instruction + sentences[i];
2020

21-
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
21+
std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);
2222

2323
const int32_t n_toks = inputs.size();
2424

2525
// GritLM seems to have EOS = ""
2626
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
27-
// inputs.push_back(llama_token_eos(mdl));
27+
// inputs.push_back(llama_token_eos(model));
2828

2929
// we want to ignore instruction tokens for mean pooling
30-
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
30+
const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();
3131

3232
#ifdef GRIT_DEBUG
3333
// debug tokens - should be matching as referenced in the GritLM sample
@@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
5151
llama_decode(ctx, batch);
5252

5353
// get embedding dimensions
54-
uint64_t n_embd = llama_n_embd(mdl);
54+
uint64_t n_embd = llama_n_embd(model);
5555

5656
// allocate embedding output
5757
std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -95,16 +95,17 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
9595
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
9696
std::string result;
9797

98-
const llama_model * mdl = llama_get_model(ctx);
99-
llama_token eos_token = llama_token_eos(mdl);
98+
const llama_model * model = llama_get_model(ctx);
99+
llama_sampling * smpl = llama_get_sampling(ctx);
100+
llama_token eos_token = llama_token_eos(model);
100101

101102
llama_kv_cache_clear(ctx);
102103
llama_set_embeddings(ctx, false);
103104
llama_set_causal_attn(ctx, true);
104105

105106
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
106107

107-
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
108+
std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
108109
int32_t i_current_token = 0;
109110

110111
while (true) {
@@ -118,14 +119,14 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
118119
llama_decode(ctx, bat);
119120
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
120121

121-
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
122+
auto candidates = std::vector<llama_token_data>(llama_n_vocab(model));
122123
auto n_candidates = (int32_t)candidates.size();
123124
for (int32_t token = 0; token < n_candidates; token++) {
124125
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
125126
}
126127
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
127128

128-
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
129+
llama_token token = llama_sampling_sample_greedy(smpl, &candidates_p);
129130
if (token == eos_token) {
130131
break;
131132
}
@@ -167,10 +168,10 @@ int main(int argc, char * argv[]) {
167168

168169
llama_backend_init();
169170

170-
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
171+
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
171172

172173
// create generation context
173-
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
174+
llama_context * ctx = llama_new_context_with_model(model, cparams);
174175

175176
// ### Embedding/Representation ###
176177
// samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -191,7 +192,7 @@ int main(int argc, char * argv[]) {
191192
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
192193
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
193194

194-
const int n_embd = llama_n_embd(mdl);
195+
const int n_embd = llama_n_embd(model);
195196

196197
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
197198
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
@@ -212,7 +213,7 @@ int main(int argc, char * argv[]) {
212213
}
213214

214215
llama_free(ctx);
215-
llama_free_model(mdl);
216+
llama_free_model(model);
216217
llama_backend_free();
217218

218219
return 0;

0 commit comments

Comments
 (0)