@@ -371,8 +371,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params

         llama_batch_ext_clear(batch.get());
         for (int i = 0; i < batch_size; i++) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
+            batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
         }

         // LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
@@ -568,7 +567,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
             for (int k = 0; k < batch_size; ++k) {
                 const llama_pos pos = j*n_batch + k;
                 bool output = pos >= first;
-                llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output);
+                batch.add_text(tokens[seq_start + k], pos, seq, output);

                 n_outputs += output ? 1 : 0;
             }
@@ -864,7 +863,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {

             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
                 std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
-                llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
+                batch.add_text(hs_cur.seq_tokens[0][i], i, seq_ids, false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
@@ -875,7 +874,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
                     llama_seq_id seq_id = s0 + s;
-                    llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
+                    batch.add_text(hs_cur.seq_tokens[s][i], i, seq_id, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1143,16 +1142,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params)

             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
                 std::vector<llama_seq_id> seq_ids{ s0 + 0, s0 + 1 };
-                llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
+                batch.add_text(data[i1].seq_tokens[0][i], i, seq_ids, false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;

             for (int s = 0; s < 2; ++s) {
                 // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    llama_seq_id seq_id = s0 + s;
-                    llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true);
+                    batch.add_text(data[i1].seq_tokens[s][i], i, s0 + s, true);
                     n_logits += 1;
                 }
             }
@@ -1511,7 +1509,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             }

             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
-                llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false);
+                batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false);
             }
             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -1521,8 +1519,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    llama_seq_id seq_id = { s0 + s };
-                    llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
+                    batch.add_text(cur_task.seq_tokens[s][i], i, s0 + s, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1749,8 +1746,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {

         llama_batch_ext_clear(batch.get());
         for (int i = 0; i < batch_size; i++) {
-            llama_seq_id seq_id = 0;
-            llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
+            batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
         }

         if (llama_decode_ext(ctx, batch.get())) {
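For reference, the `batch.add_text(...)` calls introduced above imply two overloads on a small C++ wrapper around `llama_batch_ext`: one taking a single `llama_seq_id` (replacing the `&seq_id, 1` pattern the hunks delete) and one taking a vector of sequence ids (replacing `seq_ids.data(), seq_ids.size()`). Below is a minimal sketch of such a wrapper, assuming it simply forwards to the C functions visible in the removed lines; the struct name and ownership details are illustrative, not the actual llama.cpp API.

    #include <vector>
    // types and C functions come from llama.h: llama_token, llama_pos,
    // llama_seq_id, llama_batch_ext, llama_batch_ext_add_text

    // hypothetical wrapper sketch -- not the real helper, just the shape
    // the call sites above assume
    struct batch_wrapper {
        llama_batch_ext * ptr; // assumed owned elsewhere (e.g. via a smart pointer)

        llama_batch_ext * get() const { return ptr; }

        // single-sequence overload: wraps the id in a one-element array,
        // removing the per-call-site `llama_seq_id seq_id = ...;` boilerplate
        void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output) {
            llama_batch_ext_add_text(ptr, token, pos, &seq_id, 1, output);
        }

        // multi-sequence overload: one token shared by several sequences,
        // as in the hellaswag/winogrande/multiple-choice common-prefix loops
        void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> & seq_ids, bool output) {
            llama_batch_ext_add_text(ptr, token, pos, seq_ids.data(), seq_ids.size(), output);
        }
    };

Folding the one-element-array dance into the wrapper is what lets most hunks above drop a line: call sites pass `0`, `seq`, or `s0 + s` directly, while the vector overload keeps the shared-prefix call sites equally terse.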