Commit 2134cab: add cpp batch.add_text wrapper
Parent: c5a0176

2 files changed: +36, -13 lines

examples/perplexity/perplexity.cpp (9 additions, 13 deletions)
@@ -371,8 +371,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
 
             llama_batch_ext_clear(batch.get());
             for (int i = 0; i < batch_size; i++) {
-                llama_seq_id seq_id = 0;
-                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
+                batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
             }
 
             //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);

@@ -568,7 +567,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
             for (int k = 0; k < batch_size; ++k) {
                 const llama_pos pos = j*n_batch + k;
                 bool output = pos >= first;
-                llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output);
+                batch.add_text(tokens[seq_start + k], pos, seq, output);
 
                 n_outputs += output ? 1 : 0;
             }

@@ -864,7 +863,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
 
         for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
             std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
-            llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
+            batch.add_text(hs_cur.seq_tokens[0][i], i, seq_ids, false);
         }
         llama_batch_ext_set_output_last(batch.get());
         n_logits += 1;

@@ -875,7 +874,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                 const bool needs_logits = i < seq_tokens_size - 1;
                 llama_seq_id seq_id = s0 + s;
-                llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
+                batch.add_text(hs_cur.seq_tokens[s][i], i, seq_id, needs_logits);
                 n_logits += needs_logits;
             }
         }

@@ -1143,16 +1142,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
 
         for (size_t i = 0; i < data[i1].common_prefix; ++i) {
             std::vector<llama_seq_id> seq_ids{ s0 + 0, s0 + 1 };
-            llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
+            batch.add_text(data[i1].seq_tokens[0][i], i, seq_ids, false);
         }
         llama_batch_ext_set_output_last(batch.get());
         n_logits += 1;
 
         for (int s = 0; s < 2; ++s) {
             // TODO: end before the last token, no need to predict past the end of the sequences
             for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                llama_seq_id seq_id = s0 + s;
-                llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true);
+                batch.add_text(data[i1].seq_tokens[s][i], i, s0 + s, true);
                 n_logits += 1;
             }
         }

@@ -1511,7 +1509,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         }
 
         for (size_t i = 0; i < cur_task.common_prefix; ++i) {
-            llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false);
+            batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
         }
         llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
         n_logits += 1;

@@ -1521,8 +1519,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             // TODO: don't evaluate the last token of each sequence
             for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                 const bool needs_logits = i < seq_tokens_size - 1;
-                llama_seq_id seq_id = { s0 + s };
-                llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
+                batch.add_text(cur_task.seq_tokens[s][i], i, s0 + s, needs_logits);
                 n_logits += needs_logits;
             }
         }

@@ -1749,8 +1746,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
 
         llama_batch_ext_clear(batch.get());
         for (int i = 0; i < batch_size; i++) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
+            batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
        }
 
         if (llama_decode_ext(ctx, batch.get())) {
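
Every call site changes the same way: the temporary llama_seq_id variable and the explicit &seq_id, 1 pointer/length pair disappear, and the sequence ID (or a std::vector of IDs) is passed directly to the wrapper defined in include/llama-cpp.h below. A minimal before/after sketch of that pattern (the batch, tokens, i, and pos names follow the diff above and are illustrative only):

    // before: the raw C API needs an addressable llama_seq_id and an explicit count
    llama_seq_id seq_id = 0;
    llama_batch_ext_add_text(batch.get(), tokens[i], pos, &seq_id, 1, true);

    // after: the wrapper takes the ID by value and forwards &seq_id, 1 itself
    batch.add_text(tokens[i], pos, 0, true);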

include/llama-cpp.h (27 additions, 0 deletions)
@@ -64,4 +64,31 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
                                llama_seq_id seq_id) {
         return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id));
     }
+
+    // Wrapper to add a sequence of tokens to the batch
+    void add_seq(const std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) {
+        size_t n_tokens = tokens.size();
+        for (size_t i = 0; i < n_tokens; i++) {
+            llama_batch_ext_add_text(this->get(), tokens[i], i + pos0, &seq_id, 1, false);
+        }
+        if (output_last) {
+            llama_batch_ext_set_output_last(this->get());
+        }
+    }
+
+    // Wrapper to add a single token to the batch, supporting multiple sequence IDs
+    void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> & seq_id, bool output_last) {
+        llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false);
+        if (output_last) {
+            llama_batch_ext_set_output_last(this->get());
+        }
+    }
+
+    // Wrapper to add a single token to the batch (single sequence ID)
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) {
+        llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false);
+        if (output_last) {
+            llama_batch_ext_set_output_last(this->get());
+        }
+    }
 };
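
For context, a minimal usage sketch of the new wrappers (batch construction is elided, since this commit only adds the add_seq/add_text helpers; the prompt token values are placeholders, and llama_batch_ext_clear / llama_decode_ext are the same calls that appear in the diff above):

    // assumes `batch` is a valid llama_batch_ext_ptr and `ctx` a llama_context *
    llama_batch_ext_clear(batch.get());

    // add a whole sequence starting at position 0 on sequence 0,
    // requesting logits for the last token only
    std::vector<llama_token> prompt = { 1, 15043, 3186 }; // placeholder token IDs
    batch.add_seq(prompt, 0, 0, true);

    // add one more token shared by two sequences (e.g. a common prefix),
    // without requesting logits for it
    std::vector<llama_seq_id> seq_ids = { 0, 1 };
    batch.add_text(prompt.back(), prompt.size(), seq_ids, false);

    if (llama_decode_ext(ctx, batch.get())) {
        // decode failed
    }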
