Skip to content

Commit b4fbb2a

Browse files
committed
examples : simplify sampling using new API
ggml-ci
1 parent 77998ae commit b4fbb2a

File tree

7 files changed

+35
-91
lines changed

7 files changed

+35
-91
lines changed

examples/batched/batched.cpp

+9-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#include "llama.h"
33

44
#include <algorithm>
5-
#include <cmath>
65
#include <cstdio>
76
#include <string>
87
#include <vector>
@@ -66,6 +65,8 @@ int main(int argc, char ** argv) {
6665
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
6766

6867
auto sparams = llama_sampling_default_params();
68+
69+
sparams.seed = params.sparams.seed;
6970
sparams.top_k = 40;
7071
sparams.top_p = 0.9f;
7172
sparams.temp = 0.4f;
@@ -171,25 +172,17 @@ int main(int argc, char ** argv) {
171172
continue;
172173
}
173174

174-
auto n_vocab = llama_n_vocab(model);
175-
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
176-
177-
std::vector<llama_token_data> candidates;
178-
candidates.reserve(n_vocab);
179-
180-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
181-
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
182-
}
175+
const auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
183176

184-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
177+
llama_sampling_set_logits(smpl, logits);
185178

186-
llama_sampling_top_k(smpl, &candidates_p);
187-
llama_sampling_top_p(smpl, &candidates_p);
188-
llama_sampling_temp (smpl, &candidates_p);
179+
llama_sampling_top_k(smpl, nullptr);
180+
llama_sampling_top_p(smpl, nullptr);
181+
llama_sampling_temp (smpl, nullptr);
189182

190-
const llama_token new_token_id = llama_sampling_sample_dist(smpl, &candidates_p);
183+
const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr);
191184

192-
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
185+
//const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
193186

194187
// is it an end of generation? -> mark the stream as finished
195188
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

examples/gritlm/gritlm.cpp

+2-7
Original file line numberDiff line numberDiff line change
@@ -118,14 +118,9 @@ static std::string generate(llama_context * ctx, llama_sampling * smpl, const st
118118
llama_decode(ctx, bat);
119119
auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
120120

121-
auto candidates = std::vector<llama_token_data>(llama_n_vocab(model));
122-
auto n_candidates = (int32_t)candidates.size();
123-
for (int32_t token = 0; token < n_candidates; token++) {
124-
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
125-
}
126-
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
121+
llama_sampling_set_logits(smpl, logits);
127122

128-
llama_token token = llama_sampling_sample_greedy(smpl, &candidates_p);
123+
llama_token token = llama_sampling_sample_greedy(smpl, nullptr);
129124
if (token == eos_token) {
130125
break;
131126
}

examples/llama.android/llama/src/main/cpp/llama-android.cpp

+2-9
Original file line numberDiff line numberDiff line change
@@ -396,17 +396,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
396396
auto n_vocab = llama_n_vocab(model);
397397
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
398398

399-
std::vector<llama_token_data> candidates;
400-
candidates.reserve(n_vocab);
401-
402-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
403-
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
404-
}
405-
406-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
399+
llama_sampling_set_logits(sampling, logits);
407400

408401
// sample the most likely token
409-
const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);
402+
const auto new_token_id = llama_sampling_sample_greedy(sampling, nullptr);
410403

411404
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
412405
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

+2-10
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,9 @@ actor LlamaContext {
149149
let n_vocab = llama_n_vocab(model)
150150
let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
151151

152-
var candidates = Array<llama_token_data>()
153-
candidates.reserveCapacity(Int(n_vocab))
152+
llama_sampling_set_logits(sampling, logits);
154153

155-
for token_id in 0..<n_vocab {
156-
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
157-
}
158-
candidates.withUnsafeMutableBufferPointer() { buffer in
159-
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
160-
161-
new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
162-
}
154+
new_token_id = llama_sampling_sample_greedy(sampling, nil)
163155

164156
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
165157
print("\n")

examples/passkey/passkey.cpp

+3-11
Original file line numberDiff line numberDiff line change
@@ -216,20 +216,12 @@ int main(int argc, char ** argv) {
216216
while (n_cur <= n_len) {
217217
// sample the next token
218218
{
219-
auto n_vocab = llama_n_vocab(model);
220-
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
219+
const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
221220

222-
std::vector<llama_token_data> candidates;
223-
candidates.reserve(n_vocab);
224-
225-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
226-
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
227-
}
228-
229-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
221+
llama_sampling_set_logits(smpl, logits);
230222

231223
// sample the most likely token
232-
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
224+
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
233225

234226
// is it an end of generation?
235227
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {

examples/save-load-state/save-load-state.cpp

+14-27
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,11 @@ int main(int argc, char ** argv) {
6969
printf("\nfirst run: %s", params.prompt.c_str());
7070

7171
for (auto i = 0; i < params.n_predict; i++) {
72-
auto * logits = llama_get_logits(ctx);
73-
auto n_vocab = llama_n_vocab(model);
72+
const auto * logits = llama_get_logits(ctx);
7473

75-
std::vector<llama_token_data> candidates;
76-
candidates.reserve(n_vocab);
77-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
78-
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
79-
}
80-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
81-
auto next_token = llama_sampling_sample_dist(smpl, &candidates_p);
74+
llama_sampling_set_logits(smpl, logits);
75+
76+
auto next_token = llama_sampling_sample_dist(smpl, nullptr);
8277
auto next_token_str = llama_token_to_piece(ctx, next_token);
8378

8479
printf("%s", next_token_str.c_str());
@@ -131,15 +126,11 @@ int main(int argc, char ** argv) {
131126

132127
// second run
133128
for (auto i = 0; i < params.n_predict; i++) {
134-
auto * logits = llama_get_logits(ctx2);
135-
auto n_vocab = llama_n_vocab(model);
136-
std::vector<llama_token_data> candidates;
137-
candidates.reserve(n_vocab);
138-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
139-
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
140-
}
141-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
142-
auto next_token = llama_sampling_sample_dist(smpl2, &candidates_p);
129+
const auto * logits = llama_get_logits(ctx2);
130+
131+
llama_sampling_set_logits(smpl2, logits);
132+
133+
auto next_token = llama_sampling_sample_dist(smpl2, nullptr);
143134
auto next_token_str = llama_token_to_piece(ctx2, next_token);
144135

145136
printf("%s", next_token_str.c_str());
@@ -224,15 +215,11 @@ int main(int argc, char ** argv) {
224215

225216
// third run with seq 1 instead of 0
226217
for (auto i = 0; i < params.n_predict; i++) {
227-
auto * logits = llama_get_logits(ctx3);
228-
auto n_vocab = llama_n_vocab(model);
229-
std::vector<llama_token_data> candidates;
230-
candidates.reserve(n_vocab);
231-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
232-
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
233-
}
234-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
235-
auto next_token = llama_sampling_sample_dist(smpl3, &candidates_p);
218+
const auto * logits = llama_get_logits(ctx3);
219+
220+
llama_sampling_set_logits(smpl3, logits);
221+
222+
auto next_token = llama_sampling_sample_dist(smpl3, nullptr);
236223
auto next_token_str = llama_token_to_piece(ctx3, next_token);
237224

238225
printf("%s", next_token_str.c_str());

examples/simple/simple.cpp

+3-11
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,12 @@ int main(int argc, char ** argv) {
112112
while (n_cur <= n_predict) {
113113
// sample the next token
114114
{
115-
auto n_vocab = llama_n_vocab(model);
116-
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
115+
const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
117116

118-
std::vector<llama_token_data> candidates;
119-
candidates.reserve(n_vocab);
120-
121-
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
122-
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
123-
}
124-
125-
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
117+
llama_sampling_set_logits(smpl, logits);
126118

127119
// sample the most likely token
128-
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
120+
const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr);
129121

130122
// is it an end of generation?
131123
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {

0 commit comments

Comments
 (0)