88#include < cstring>
99#include < ctime>
1010#include < cfloat>
11+ #include < chrono>
1112#include < cmath>
1213#include < numeric>
1314#include < random>
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
162163 cur_p->size = k;
163164}
164165
166+ static uint32_t get_rng_seed (uint32_t seed) {
167+ if (seed == LLAMA_DEFAULT_SEED) {
168+ // use system clock if std::random_device is not a true RNG
169+ static bool is_rd_prng = std::random_device ().entropy () == 0 ;
170+ if (is_rd_prng) {
171+ return (uint32_t ) std::chrono::system_clock::now ().time_since_epoch ().count ();
172+ }
173+ std::random_device rd;
174+ return rd ();
175+ }
176+ return seed;
177+ }
178+
165179// llama_sampler API
166180
167181const char * llama_sampler_name (const struct llama_sampler * smpl) {
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
387401
388402struct llama_sampler_dist {
389403 const uint32_t seed;
404+ uint32_t seed_cur;
390405
391406 std::mt19937 rng;
392407};
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
416431
417432static void llama_sampler_dist_reset (struct llama_sampler * smpl) {
418433 auto * ctx = (llama_sampler_dist *) smpl->ctx ;
419- ctx->rng = std::mt19937 (ctx->seed );
434+ ctx->seed_cur = get_rng_seed (ctx->seed );
435+ ctx->rng .seed (ctx->seed_cur );
420436}
421437
422438static void llama_sampler_dist_free (struct llama_sampler * smpl) {
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
433449};
434450
435451struct llama_sampler * llama_sampler_init_dist (uint32_t seed) {
452+ auto seed_cur = get_rng_seed (seed);
436453 return new llama_sampler {
437454 /* .iface = */ &llama_sampler_dist_i,
438455 /* .ctx = */ new llama_sampler_dist {
439- /* .seed = */ seed,
440- /* .rng = */ std::mt19937 (seed),
456+ /* .seed = */ seed,
457+ /* .seed_cur = */ seed_cur,
458+ /* .rng = */ std::mt19937 (seed_cur),
441459 },
442460 };
443461}
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
10321050 const int32_t n_vocab;
10331051
10341052 const uint32_t seed;
1053+ uint32_t seed_cur;
10351054
10361055 const float tau;
10371056 const float eta;
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
11001119static void llama_sampler_mirostat_reset (struct llama_sampler * smpl) {
11011120 auto * ctx = (llama_sampler_mirostat *) smpl->ctx ;
11021121 ctx->mu = 2 .0f *ctx->tau ;
1103- ctx->rng = std::mt19937 (ctx->seed );
1122+ ctx->seed_cur = get_rng_seed (ctx->seed );
1123+ ctx->rng .seed (ctx->seed_cur );
11041124}
11051125
11061126static void llama_sampler_mirostat_free (struct llama_sampler * smpl) {
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
11171137};
11181138
11191139struct llama_sampler * llama_sampler_init_mirostat (int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1140+ auto seed_cur = get_rng_seed (seed);
11201141 return new llama_sampler {
11211142 /* .iface = */ &llama_sampler_mirostat_i,
11221143 /* .ctx = */ new llama_sampler_mirostat {
1123- /* .n_vocab = */ n_vocab,
1124- /* .seed = */ seed,
1125- /* .tau = */ tau,
1126- /* .eta = */ eta,
1127- /* .m = */ m,
1128- /* .mu = */ 2 .0f *tau,
1129- /* .rng = */ std::mt19937 (seed),
1144+ /* .n_vocab = */ n_vocab,
1145+ /* .seed = */ seed,
1146+ /* .seed_cur = */ seed_cur,
1147+ /* .tau = */ tau,
1148+ /* .eta = */ eta,
1149+ /* .m = */ m,
1150+ /* .mu = */ 2 .0f *tau,
1151+ /* .rng = */ std::mt19937 (seed_cur),
11301152 },
11311153 };
11321154}
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
11351157
11361158struct llama_sampler_mirostat_v2 {
11371159 const uint32_t seed;
1160+ uint32_t seed_cur;
11381161
11391162 const float tau;
11401163 const float eta;
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
11791202static void llama_sampler_mirostat_v2_reset (struct llama_sampler * smpl) {
11801203 auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx ;
11811204 ctx->mu = 2 .0f *ctx->tau ;
1182- ctx->rng = std::mt19937 (ctx->seed );
1205+ ctx->seed_cur = get_rng_seed (ctx->seed );
1206+ ctx->rng .seed (ctx->seed_cur );
11831207}
11841208
11851209static struct llama_sampler * llama_sampler_mirostat_v2_clone (const struct llama_sampler * smpl) {
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
12121236};
12131237
12141238struct llama_sampler * llama_sampler_init_mirostat_v2 (uint32_t seed, float tau, float eta) {
1239+ auto seed_cur = get_rng_seed (seed);
12151240 return new llama_sampler {
12161241 /* .iface = */ &llama_sampler_mirostat_v2_i,
12171242 /* .ctx = */ new llama_sampler_mirostat_v2 {
1218- /* .seed = */ seed,
1219- /* .tau = */ tau,
1220- /* .eta = */ eta,
1221- /* .mu = */ 2 .0f *tau,
1222- /* .rng = */ std::mt19937 (seed),
1243+ /* .seed = */ seed,
1244+ /* .seed_cur = */ seed_cur,
1245+ /* .tau = */ tau,
1246+ /* .eta = */ eta,
1247+ /* .mu = */ 2 .0f *tau,
1248+ /* .rng = */ std::mt19937 (seed_cur),
12231249 },
12241250 };
12251251}
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
15051531 ignore_eos = false ;
15061532 }
15071533
1534+ penalty_last_n = std::max (penalty_last_n, 0 );
1535+
15081536 return new llama_sampler {
15091537 /* .iface = */ &llama_sampler_penalties_i,
15101538 /* .ctx = */ new llama_sampler_penalties {
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
15681596 }
15691597 }
15701598}
1599+
15711600static struct llama_sampler * llama_sampler_logit_bias_clone (const struct llama_sampler * smpl) {
15721601 const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx ;
15731602 return llama_sampler_init_logit_bias (ctx->n_vocab , ctx->logit_bias .size (), ctx->logit_bias .data ());
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
15991628 },
16001629 };
16011630}
1631+
1632+ // utils
1633+
1634+ uint32_t llama_sampler_get_seed (const struct llama_sampler * smpl) {
1635+ if (smpl->iface == &llama_sampler_dist_i) {
1636+ return ((const llama_sampler_dist *) smpl->ctx )->seed_cur ;
1637+ }
1638+
1639+ if (smpl->iface == &llama_sampler_mirostat_i) {
1640+ return ((const llama_sampler_mirostat *) smpl->ctx )->seed_cur ;
1641+ }
1642+
1643+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
1644+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx )->seed_cur ;
1645+ }
1646+
1647+ if (smpl->iface == &llama_sampler_chain_i) {
1648+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx ;
1649+ for (auto it = ctx->samplers .rbegin (); it != ctx->samplers .rend (); ++it) {
1650+ const uint32_t seed = llama_sampler_get_seed (*it);
1651+ if (seed != LLAMA_DEFAULT_SEED) {
1652+ return seed;
1653+ }
1654+ }
1655+ }
1656+
1657+ return LLAMA_DEFAULT_SEED;
1658+ }
0 commit comments