Skip to content

Commit 416f491

Browse files
Save and load example adjust
1 parent 6c4c88d commit 416f491

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

examples/save-load-state/save-load-state.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
6464
// first run
6565
printf("\n%s", params.prompt.c_str());
6666
for (auto i = 0; i < params.n_predict; i++) {
67-
auto next_token = llama_sample_top_p_top_k(
68-
ctx,
69-
&last_n_tokens_data.back() - params.repeat_last_n,
70-
params.repeat_last_n,
71-
40,
72-
1.0,
73-
1.0,
74-
1.1);
67+
auto logits = llama_get_logits(ctx);
68+
auto n_vocab = llama_n_vocab(ctx);
69+
std::vector<llama_token_data> candidates;
70+
candidates.reserve(n_vocab);
71+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
72+
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
73+
}
74+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
75+
auto next_token = llama_sample_token(ctx, &candidates_p);
7576
auto next_token_str = llama_token_to_str(ctx, next_token);
7677
last_n_tokens_data.push_back(next_token);
7778
printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
106107

107108
// second run
108109
for (auto i = 0; i < params.n_predict; i++) {
109-
auto next_token = llama_sample_top_p_top_k(
110-
ctx2,
111-
&last_n_tokens_data.back() - params.repeat_last_n,
112-
params.repeat_last_n,
113-
40,
114-
1.0,
115-
1.0,
116-
1.1);
110+
auto logits = llama_get_logits(ctx2);
111+
auto n_vocab = llama_n_vocab(ctx2);
112+
std::vector<llama_token_data> candidates;
113+
candidates.reserve(n_vocab);
114+
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
115+
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
116+
}
117+
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
118+
auto next_token = llama_sample_token(ctx2, &candidates_p);
117119
auto next_token_str = llama_token_to_str(ctx2, next_token);
118120
last_n_tokens_data.push_back(next_token);
119121
printf("%s", next_token_str);

0 commit comments

Comments
 (0)