
Commit f866cb9

llama : move sampling rngs from common to llama
ggml-ci
1 parent 938943c commit f866cb9

22 files changed, +342 -344 lines
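This commit moves the sampling RNG out of common and into llama: llama_sampling_context no longer owns a std::mt19937, and the llama_context instead keeps one RNG per sequence. Creating a sampling context now requires the llama_context and a llama_seq_id, and seeding and token drawing route through the new, explicitly temporary (deprecated) llama_set_rng_seed_seq and llama_sample_token_seq APIs. A minimal caller-side sketch, assuming ctx and sparams are an already-initialized llama_context and llama_sampling_params (not shown in this diff):

// Before this commit, the RNG lived inside the sampling context in common/:
//     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
// After this commit, the sampling context is bound to a llama_context and a
// sequence id, and the RNG for that sequence lives inside the llama_context:
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, /* seq_id = */ 0);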

common/sampling.cpp

+10 -8

@@ -1,11 +1,13 @@
-#define LLAMA_API_INTERNAL
 #include "sampling.h"
+
 #include <random>
 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
     result->params = params;
+    result->seq_id = seq_id;
+    result->ctx = ctx;
     result->grammar = nullptr;
 
     // if there is a grammar, parse it
@@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s
     if (seed == LLAMA_DEFAULT_SEED) {
         seed = std::random_device{}();
     }
-    ctx->rng.seed(seed);
+    llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id);
 }
 
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl(
         bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const float   temp            = params.temp;
-    const int     mirostat        = params.mirostat;
-    const float   mirostat_tau    = params.mirostat_tau;
-    const float   mirostat_eta    = params.mirostat_eta;
+    const float temp         = params.temp;
+    const int   mirostat     = params.mirostat;
+    const float mirostat_tau = params.mirostat_tau;
+    const float mirostat_eta = params.mirostat_eta;
 
     std::vector<float> original_logits;
     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
@@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl(
 
     sampler_queue(ctx_main, params, cur_p, min_keep);
 
-    id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
+    id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id);
 
     //{
     //    const int n_top = 10;
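
The two changed call sites above are the behavioural core of the change on the common side: seeding and the final token draw now go to the per-sequence RNG owned by the llama_context instead of the context-local std::mt19937. A usage sketch, assuming ctx_sampling was created with llama_sampling_init(sparams, ctx, seq_id) and using the existing llama_sampling_sample helper from common/sampling.h:

// Seeds RNG[seq_id] inside ctx; previously this did ctx_sampling->rng.seed(seed).
llama_sampling_set_rng_seed(ctx_sampling, 1234);

// The usual sampling helper now draws from that per-sequence stream, because
// llama_sampling_sample_impl ends in llama_sample_token_seq(ctx, &cur_p, seq_id).
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /* ctx_cfg = */ nullptr);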

common/sampling.h

+5 -3

@@ -70,9 +70,12 @@ struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;
 
+    llama_seq_id seq_id;
+
     // mirostat sampler state
     float mirostat_mu;
 
+    llama_context * ctx; // TMP
     llama_grammar * grammar;
 
     // internal
@@ -81,15 +84,14 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
    std::vector<llama_token_data> cur;
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
 
-    std::mt19937 rng;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
 #include "common.h"
 
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);

examples/gbnf-validator/gbnf-validator.cpp

+1 -2

@@ -1,8 +1,7 @@
-#define LLAMA_API_INTERNAL
-
 #include "grammar-parser.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"
 #include "unicode.h"
 
 #include <cstdio>

examples/infill/infill.cpp

+1 -1

@@ -346,7 +346,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
 
     while (n_remain != 0 || params.interactive) {
         // predict

examples/llava/llava-cli.cpp

+1 -1

@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
 
     LOG_TEE("\n");
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->ctx_llama, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

examples/lookahead/lookahead.cpp

+1 -1

@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);
 
     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);

examples/lookup/lookup.cpp

+1 -1

@@ -106,7 +106,7 @@ int main(int argc, char ** argv){
 
     bool has_eos = false;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx, 0);
 
     std::vector<llama_token> draft;

examples/main/main.cpp

+1 -1

@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams, ctx, 0);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);

examples/parallel/parallel.cpp

+1 -1

@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.ctx_sampling = llama_sampling_init(params.sparams, ctx, i);
     }
 
     std::vector<llama_token> tokens_system;
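
Each parallel client now binds its sampling context to its own sequence id, so concurrent slots draw from independent RNG streams inside the shared llama_context instead of contending on a single generator. A hypothetical extension (not part of this diff; base_seed is an assumed variable) that would make each client's stream reproducible:

// Seed each client's per-sequence RNG from a common base seed.
for (size_t i = 0; i < clients.size(); ++i) {
    llama_set_rng_seed_seq(ctx, base_seed + (uint32_t) i, (llama_seq_id) i);
}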

examples/quantize-stats/quantize-stats.cpp

+1 -1

@@ -1,7 +1,7 @@
-#define LLAMA_API_INTERNAL
 #include "common.h"
 #include "ggml.h"
 #include "llama.h"
+#include "llama-impl.h"
 
 #include <algorithm>
 #include <cassert>

examples/server/server.cpp

+1 -1

@@ -1090,7 +1090,7 @@ struct server_context {
         if (slot.ctx_sampling != nullptr) {
             llama_sampling_free(slot.ctx_sampling);
         }
-        slot.ctx_sampling = llama_sampling_init(slot.sparams);
+        slot.ctx_sampling = llama_sampling_init(slot.sparams, ctx, slot.id);
         if (slot.ctx_sampling == nullptr) {
             // for now, the only error that may happen here is invalid grammar
             send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);

examples/speculative/speculative.cpp

+2 -2

@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
     bool has_eos = false;
 
     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, ctx_tgt, 0);
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
@@ -186,7 +186,7 @@ int main(int argc, char ** argv) {
     }
 
     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        drafts[s].ctx_sampling = llama_sampling_init(params.sparams, ctx_dft, s);
     }
 
     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);

include/llama.h

+12 -57

@@ -40,7 +40,7 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 7
+#define LLAMA_SESSION_VERSION 8
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
 #define LLAMA_STATE_SEQ_VERSION 1
@@ -1031,6 +1031,9 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
 
+    LLAMA_API DEPRECATED(void llama_set_rng_seed_seq(struct llama_context * ctx, uint32_t seed, llama_seq_id),
+            "temporary API, until llama_sampling_context is implemented, do not use");
+
     /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
     /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
     LLAMA_API void llama_sample_repetition_penalties(
@@ -1137,11 +1140,18 @@
             struct llama_context * ctx,
             llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
+    /// @details Randomly selects a token from the candidates based on their probabilities using RNG[0] of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
             llama_token_data_array * candidates);
 
+    /// @details Same as llama_sample_token, but uses a sequence-specific RNG[seq_id].
+    LLAMA_API DEPRECATED(llama_token llama_sample_token_seq(
+            struct llama_context * ctx,
+            llama_token_data_array * candidates,
+            llama_seq_id seq_id),
+        "temporary API, until llama_sampling_context is implemented, do not use");
+
     //
     // Model split
     //
@@ -1175,59 +1185,4 @@
 }
 #endif
 
-// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
-#ifdef LLAMA_API_INTERNAL
-
-#include <random>
-#include <string>
-#include <vector>
-
-struct ggml_tensor;
-
-const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
-    struct llama_context * ctx
-);
-
-struct llama_partial_utf8 {
-    uint32_t value;    // bit value so far (unshifted)
-    int      n_remain; // num bytes remaining; -1 indicates invalid sequence
-};
-
-struct llama_grammar_candidate {
-    size_t             index;
-    const uint32_t   * code_points;
-    llama_partial_utf8 partial_utf8;
-};
-
-using llama_grammar_rule  = std::vector<      llama_grammar_element>;
-using llama_grammar_stack = std::vector<const llama_grammar_element *>;
-
-using llama_grammar_rules      = std::vector<llama_grammar_rule>;
-using llama_grammar_stacks     = std::vector<llama_grammar_stack>;
-using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
-
-const llama_grammar_rules  & llama_grammar_get_rules (const struct llama_grammar * grammar);
-      llama_grammar_stacks & llama_grammar_get_stacks(      struct llama_grammar * grammar);
-
-void llama_grammar_accept(
-        const llama_grammar_rules  & rules,
-        const llama_grammar_stacks & stacks,
-        const uint32_t chr,
-              llama_grammar_stacks & new_stacks);
-
-std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
-        const llama_grammar_rules      & rules,
-        const llama_grammar_stack      & stack,
-        const llama_grammar_candidates & candidates);
-
-std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
-        const std::string & src,
-        llama_partial_utf8 partial_start);
-
-// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
-// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
-llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng);
-
 #endif // LLAMA_H
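
Both new entry points are deliberately wrapped in DEPRECATED: they are stop-gaps until llama_sampling_context itself is implemented inside llama, and they replace the internal llama_sample_token_with_rng workaround removed above (the session file version also bumps from 7 to 8 in this commit). A minimal sketch of direct use, e.g. from a test, assuming ctx is a valid llama_context and cur_p is an already-prepared llama_token_data_array of candidates:

// Seed and sample only sequence 2; other sequences' RNG streams are untouched.
llama_set_rng_seed_seq(ctx, 42, /* seq_id = */ 2);
const llama_token id = llama_sample_token_seq(ctx, &cur_p, 2);
// llama_sample_token(ctx, &cur_p) keeps its old behaviour and uses RNG[0].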

src/llama-grammar.cpp

+20 -31

@@ -445,15 +445,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
     delete grammar;
 }
 
-struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
+struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) {
+    llama_grammar * result = new llama_grammar{ grammar.rules, grammar.stacks, grammar.partial_utf8 };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
         for (size_t ie = 0; ie < result->stacks[is].size(); ie++) {
-            for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) {
-                for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) {
-                    if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) {
+            for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) {
+                for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) {
+                    if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) {
                         result->stacks[is][ie] = &result->rules[ir0][ir1];
                     }
                 }
@@ -464,14 +464,9 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
     return result;
 }
 
-void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    GGML_ASSERT(grammar);
-    GGML_ASSERT(vocab);
-
-    int64_t t_start_sample_us = ggml_time_us();
-
+void llama_grammar_sample_impl(const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) {
     bool allow_eog = false;
-    for (const auto & stack : grammar->stacks) {
+    for (const auto & stack : grammar.stacks) {
         if (stack.empty()) {
             allow_eog = true;
             break;
@@ -486,54 +481,48 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string & piece = vocab->cache_token_to_piece.at(id);
+        const std::string & piece = vocab.cache_token_to_piece.at(id);
 
-        if (llama_token_is_eog_impl(*vocab, id)) {
+        if (llama_token_is_eog_impl(vocab, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
 
-    const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar);
+    const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar);
     for (const auto & reject : rejects) {
         candidates->data[reject.index].logit = -INFINITY;
     }
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
 }
 
-void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    if (llama_token_is_eog_impl(*vocab, token)) {
-        for (const auto & stack : grammar->stacks) {
+void llama_grammar_accept_token_impl(struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) {
+    if (llama_token_is_eog_impl(vocab, token)) {
+        for (const auto & stack : grammar.stacks) {
             if (stack.empty()) {
                 return;
             }
         }
         GGML_ASSERT(false);
     }
 
-    const std::string & piece = vocab->cache_token_to_piece.at(token);
+    const std::string & piece = vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
+    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
     const auto & code_points = decoded.first;
 
     llama_grammar_stacks tmp_new_stacks;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
-        llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks);
-        grammar->stacks = tmp_new_stacks;
+        llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks);
+        grammar.stacks = tmp_new_stacks;
    }
 
-    grammar->partial_utf8 = decoded.second;
-    GGML_ASSERT(!grammar->stacks.empty());
-
-    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    grammar.partial_utf8 = decoded.second;
+    GGML_ASSERT(!grammar.stacks.empty());
 }
