
Commit 3802ff2

add batch.clear() and batch.n_tokens()
1 parent 2cec1cf commit 3802ff2


16 files changed, +77 -69 lines

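The call sites below switch from the C-style helpers llama_batch_ext_clear() and llama_batch_ext_get_n_tokens() to new member wrappers on llama_batch_ext_ptr. The wrapper definitions themselves are not part of this excerpt; the following is only a minimal sketch of what they presumably look like, assuming llama_batch_ext_ptr is the unique_ptr-style holder already used at these call sites and that the deleter type named here (llama_batch_ext_deleter) stands in for whatever the real header uses:

// Hypothetical sketch only -- the actual definition lives in the header that
// declares llama_batch_ext_ptr, which is not shown in this commit excerpt.
struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
    using std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>::unique_ptr;

    // wraps llama_batch_ext_clear() on the owned batch
    void clear() {
        llama_batch_ext_clear(this->get());
    }

    // wraps llama_batch_ext_get_n_tokens() on the owned batch
    int32_t n_tokens() const {
        return llama_batch_ext_get_n_tokens(this->get());
    }

    // add_text(token, pos, seq_id, output) and the other helpers used by the
    // call sites below are assumed to exist already and are unchanged here.
};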

common/speculative.cpp

Lines changed: 4 additions & 4 deletions
@@ -204,7 +204,7 @@ llama_tokens common_speculative_gen_draft(
    }

    // prepare a batch to evaluate any new tokens in the prompt
-   llama_batch_ext_clear(batch.get());
+   batch.clear();

    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
@@ -214,7 +214,7 @@ llama_tokens common_speculative_gen_draft(
    }

    // we should rarely end-up here during normal decoding
-   if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
+   if (batch.n_tokens() > 0) {
        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

        llama_decode_ext(ctx, batch.get());
@@ -224,7 +224,7 @@ llama_tokens common_speculative_gen_draft(

    LOG_DBG("%s: n_past = %d\n", __func__, n_past);

-   llama_batch_ext_clear(batch.get());
+   batch.clear();
    batch.add_text(id_last, n_past, 0, true);

    prompt.push_back(id_last);
@@ -237,7 +237,7 @@ llama_tokens common_speculative_gen_draft(

    // sample n_draft tokens from the draft model
    for (int i = 0; i < params.n_draft; ++i) {
-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        common_sampler_sample(smpl, ctx, 0, true);

examples/gritlm/gritlm.cpp

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));

    for (uint64_t i = 0; i < sentences.size(); i++) {
-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        const std::string input_string = instruction + sentences[i];

@@ -111,7 +111,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
    int32_t i_current_token = 0;

    while (true) {
-       llama_batch_ext_clear(batch.get());
+       batch.clear();
        {
            const int32_t n_inputs = inputs.size();

@@ -123,7 +123,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

        llama_decode_ext(ctx, batch.get());

-       llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
+       llama_token token = llama_sampler_sample(smpl, ctx, batch.n_tokens() - 1);

        if (token == eos_token) {
            break;

examples/llava/gemma3-cli.cpp

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ struct gemma3_context {

static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) {
    llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true);
-   llama_batch_ext_clear(ctx.batch.get());
+   ctx.batch.clear();
    for (llama_token & t : tokens) {
        ctx.batch.add_text(t, ctx.n_past++, 0, false);
    }
@@ -178,7 +178,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
        fflush(stdout);

        // eval the token
-       llama_batch_ext_clear(ctx.batch.get());
+       ctx.batch.clear();
        ctx.batch.add_text(token_id, ctx.n_past++, 0, true);
        if (llama_decode_ext(ctx.lctx, ctx.batch.get())) {
            LOG_ERR("failed to decode token\n");

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
@@ -197,7 +197,7 @@ int main(int argc, char ** argv){
        // clean the cache of draft tokens that weren't accepted
        llama_kv_self_seq_rm(ctx, 0, n_past, -1);

-       llama_batch_ext_clear(batch_tgt.get());
+       batch_tgt.clear();
        batch_tgt.add_text(draft[0], n_past, 0, true);

        // Draft already contains a single token sampled from the model:

examples/parallel/parallel.cpp

Lines changed: 8 additions & 8 deletions
@@ -217,23 +217,23 @@ int main(int argc, char ** argv) {
            common_kv_cache_dump_view_seqs(kvc_view, 40);
        }

-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        // decode any currently ongoing sequences
        for (auto & client : clients) {
            if (client.seq_id == -1) {
                continue;
            }

-           client.i_batch = llama_batch_ext_get_n_tokens(batch.get());
+           client.i_batch = batch.n_tokens();

            llama_seq_id seq_id = client.id + 1;
            batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true);

            client.n_decoded += 1;
        }

-       if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+       if (batch.n_tokens() == 0) {
            // all sequences have ended - clear the entire KV cache
            for (int i = 1; i <= n_clients; ++i) {
                llama_kv_self_seq_rm(ctx, i, -1, -1);
@@ -245,7 +245,7 @@ int main(int argc, char ** argv) {
        }

        // insert new sequences for decoding
-       if (cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+       if (cont_batching || batch.n_tokens() == 0) {
            for (auto & client : clients) {
                if (client.seq_id == -1 && g_seq_id < n_seq) {
                    client.seq_id = g_seq_id;
@@ -269,13 +269,13 @@ int main(int argc, char ** argv) {
                    }

                    // extract the logits only for the last token
-                   if (llama_batch_ext_get_n_tokens(batch.get()) > 0) {
+                   if (batch.n_tokens() > 0) {
                        llama_batch_ext_set_output_last(batch.get());
                    }

                    client.n_prompt = tokens_prompt.size();
                    client.n_decoded = 0;
-                   client.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
+                   client.i_batch = batch.n_tokens() - 1;

                    LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);

@@ -289,14 +289,14 @@ int main(int argc, char ** argv) {
            }
        }

-       if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+       if (batch.n_tokens() == 0) {
            break;
        }

        // process in chunks of params.n_batch
        int32_t n_batch = params.n_batch;

-       int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch.get());
+       int32_t n_tokens_in_batch = batch.n_tokens();
        for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) {
            // experiment: process in powers of 2
            //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {

examples/passkey/passkey.cpp

Lines changed: 4 additions & 4 deletions
@@ -141,7 +141,7 @@ int main(int argc, char ** argv) {
            n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
        }

-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
            batch.add_text(tokens_list[i + j], n_past++, 0, false);
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {

        n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;

-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
            batch.add_text(tokens_list[i + j], n_past++, 0, false);
@@ -224,7 +224,7 @@ int main(int argc, char ** argv) {
    while (n_cur <= n_len) {
        // sample the next token
        {
-           const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1);
+           const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens() - 1);

            // is it an end of generation?
            if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
@@ -238,7 +238,7 @@ int main(int argc, char ** argv) {
            n_decode += 1;

            // prepare the next batch
-           llama_batch_ext_clear(batch.get());
+           batch.clear();

            // push this new token for next evaluation
            llama_seq_id seq_id = 0;

examples/perplexity/perplexity.cpp

Lines changed: 6 additions & 6 deletions
@@ -369,7 +369,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
            const int batch_start = start + j * n_batch;
            const int batch_size = std::min(end - batch_start, n_batch);

-           llama_batch_ext_clear(batch.get());
+           batch.clear();
            for (int i = 0; i < batch_size; i++) {
                batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
            }
@@ -552,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &

            int n_outputs = 0;

-           llama_batch_ext_clear(batch.get());
+           batch.clear();
            for (int seq = 0; seq < n_seq_batch; seq++) {
                int seq_start = batch_start + seq*n_ctx;

@@ -846,7 +846,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
        size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        // batch as much tasks as possible into the available context
        // each task has 4 unique sequence ids - one for each ending
@@ -1131,7 +1131,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
        size_t i1 = i0;
        size_t i_logits = 0;

-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
            int n_logits = 0;
@@ -1485,7 +1485,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
        size_t i1 = i0;
        size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-       llama_batch_ext_clear(batch.get());
+       batch.clear();

        // batch as much tasks as possible into the available context
        // each task has 4 unique sequence ids - one for each ending
@@ -1744,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
            tokens[batch_start] = llama_vocab_bos(vocab);
        }

-       llama_batch_ext_clear(batch.get());
+       batch.clear();
        for (int i = 0; i < batch_size; i++) {
            batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
        }

examples/run/run.cpp

Lines changed: 2 additions & 2 deletions
@@ -954,7 +954,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) {
    const int n_ctx = llama_n_ctx(ctx.get());
    const int n_ctx_used = llama_kv_self_used_cells(ctx.get());
-   if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) {
+   if (n_ctx_used + batch.n_tokens() > n_ctx) {
        printf(LOG_COL_DEFAULT "\n");
        printe("context size exceeded\n");
        return 1;
@@ -1001,7 +1001,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
            return 1;
        }

-       llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get());
+       llama_data.n_past += batch.n_tokens();

        // sample the next token, check is it an end of generation?
        new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);

examples/save-load-state/save-load-state.cpp

Lines changed: 4 additions & 4 deletions
@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {

    // evaluate prompt
    llama_decode_ext(ctx, batch.get());
-   n_past += llama_batch_ext_get_n_tokens(batch.get());
+   n_past += batch.n_tokens();

    // save state (rng, logits, embedding and kv_cache) to file
    {
@@ -79,7 +79,7 @@ int main(int argc, char ** argv) {
        printf("%s", next_token_str.c_str());
        result0 += next_token_str;

-       llama_batch_ext_clear(batch.get());
+       batch.clear();
        batch.add_text(next_token, 0, 0, true);

        if (llama_decode_ext(ctx, batch.get())) {
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
        printf("%s", next_token_str.c_str());
        result1 += next_token_str;

-       llama_batch_ext_clear(batch.get());
+       batch.clear();
        batch.add_text(next_token, 0, 0, true);

        if (llama_decode_ext(ctx2, batch.get())) {
@@ -212,7 +212,7 @@ int main(int argc, char ** argv) {
        printf("%s", next_token_str.c_str());
        result2 += next_token_str;

-       llama_batch_ext_clear(batch.get());
+       batch.clear();
        batch.add_text(next_token, 0, 1, true);

        if (llama_decode_ext(ctx3, batch.get())) {
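
Taken together, the per-token decode loops touched by this commit follow the pattern below. This is a sketch assembled from the hunks above (using the same names as the call sites), not code taken verbatim from the commit:

// clear the batch, push the sampled token, decode, then advance n_past
batch.clear();                                // was: llama_batch_ext_clear(batch.get())
batch.add_text(next_token, n_past, 0, true);  // token, position, seq_id, output logits
if (llama_decode_ext(ctx, batch.get())) {
    // handle decode failure
}
n_past += batch.n_tokens();                   // was: llama_batch_ext_get_n_tokens(batch.get())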
