@@ -363,15 +363,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         // clear the KV cache
         llama_kv_self_clear(ctx);

-        common_batch batch(n_batch, 1);
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));

         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);

-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int i = 0; i < batch_size; i++) {
-                batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
+                llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
             }

             //LOG_DBG("    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
@@ -501,7 +502,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
     GGML_ASSERT(params.n_ctx == n_seq * n_ctx);

-    common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));

     std::vector<float> logits;
     if (num_batches > 1) {
@@ -552,7 +553,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &

             int n_outputs = 0;

-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 int seq_start = batch_start + seq*n_ctx;

@@ -567,7 +568,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
                 for (int k = 0; k < batch_size; ++k) {
                     const llama_pos pos = j*n_batch + k;
                     bool output = pos >= first;
-                    batch.add_text(tokens[seq_start + k], pos, seq, output);
+                    llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output);

                     n_outputs += output ? 1 : 0;
                 }
@@ -649,26 +650,15 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     return {tokens, ppl, logit_history, prob_history};
 }

-static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
-    int prev_outputs = 0;
-    for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
-        const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
-
-        common_batch batch_view = batch.get_view(i, n_tokens);
-
-        const int ret = llama_decode_ext(ctx, batch_view.get());
-        if (ret != 0) {
-            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-            return false;
-        }
-
-        int n_outputs = batch_view.n_outputs;
-
-        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
-
-        prev_outputs += n_outputs;
+static bool decode_helper(llama_context * ctx, llama_batch_ext_ptr & batch, std::vector<float> & batch_logits, size_t n_outputs, int n_vocab) {
+    const int ret = llama_decode_ext(ctx, batch.get());
+    if (ret != 0) {
+        LOG_ERR("failed to decode the batch, ret = %d\n", ret);
+        return false;
     }

+    memcpy(batch_logits.data(), llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
     return true;
 }

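The refactored decode_helper no longer walks the batch in n_batch-sized views: it submits the whole batch to llama_decode_ext in one call and copies exactly n_outputs rows of logits, which is why the callers below now pass i_logits instead of n_batch. A minimal usage sketch follows; it relies only on the llama_batch_ext_* calls and the decode_helper signature introduced in this patch, while the helper name and surrounding setup are illustrative and not part of the change.

// Illustrative sketch (not part of this patch): score one token sequence with the new API.
// Assumes the llama_batch_ext_* functions and decode_helper exactly as defined above.
static bool score_tokens_sketch(llama_context * ctx, const std::vector<llama_token> & tokens, std::vector<float> & logits) {
    const llama_vocab * vocab = llama_model_get_vocab(llama_get_model(ctx));
    const int n_vocab = llama_vocab_n_tokens(vocab);

    llama_batch_ext_ptr batch(llama_batch_ext_init((int32_t) tokens.size(), 1));
    llama_batch_ext_clear(batch.get());

    for (size_t i = 0; i < tokens.size(); ++i) {
        llama_seq_id seq_id = 0; // single sequence
        llama_batch_ext_add_text(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, true); // request logits for every position
    }

    // every token requested logits, so n_outputs == tokens.size()
    logits.resize(tokens.size() * (size_t) n_vocab);
    return decode_helper(ctx, batch, logits, tokens.size(), n_vocab);
}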
@@ -836,14 +826,12 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
     double acc = 0.0f;

     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);

     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    common_batch batch(n_ctx, 4);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));

     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -859,7 +847,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-        batch.clear();
+        llama_batch_ext_clear(batch.get());

         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -875,7 +863,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             }

             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+                std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
+                llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
@@ -885,7 +874,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    llama_seq_id seq_id = s0 + s;
+                    llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
                     n_logits += needs_logits;
                 }
             }
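The scoring functions in this file (hellaswag, winogrande, multiple choice) all fill the batch with the same shape: the shared prefix is added once under every ending's sequence id, the last prefix token is flagged for logits, and each ending then continues under its own id. A condensed sketch of that pattern is shown below; the llama_batch_ext_* calls match this patch, but the helper name and the prefix/endings layout are only illustrative.

// Illustrative sketch (not part of this patch): add one task with a shared prefix and N endings.
static void add_task_sketch(llama_batch_ext * batch,
                            const std::vector<llama_token> & prefix,
                            const std::vector<std::vector<llama_token>> & endings,
                            llama_seq_id s0) {
    // the common prefix is evaluated once, shared by all ending-specific sequences
    std::vector<llama_seq_id> seq_ids(endings.size());
    for (size_t s = 0; s < endings.size(); ++s) {
        seq_ids[s] = s0 + (llama_seq_id) s;
    }
    for (size_t i = 0; i < prefix.size(); ++i) {
        llama_batch_ext_add_text(batch, prefix[i], (llama_pos) i, seq_ids.data(), seq_ids.size(), false);
    }
    llama_batch_ext_set_output_last(batch); // logits are still needed for the last prefix token

    // each ending continues its own sequence and requests logits for the tokens it scores
    for (size_t s = 0; s < endings.size(); ++s) {
        llama_seq_id seq_id = s0 + (llama_seq_id) s;
        for (size_t i = 0; i < endings[s].size(); ++i) {
            llama_batch_ext_add_text(batch, endings[s][i], (llama_pos) (prefix.size() + i), &seq_id, 1, true);
        }
    }
}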
@@ -907,7 +897,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         llama_kv_self_clear(ctx);

         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1118,14 +1108,12 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);

     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);

     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    common_batch batch(n_ctx, 2);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));

     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1144,7 +1132,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         size_t i1 = i0;
         size_t i_logits = 0;

-        batch.clear();
+        llama_batch_ext_clear(batch.get());

         while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
             int n_logits = 0;
@@ -1154,15 +1142,17 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
             }

             for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+                std::vector<llama_seq_id> seq_ids{ s0 + 0, s0 + 1 };
+                llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;

             for (int s = 0; s < 2; ++s) {
                 // TODO: end before the last token, no need to predict past the end of the sequences
                 for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    llama_seq_id seq_id = s0 + s;
+                    llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true);
                     n_logits += 1;
                 }
             }
@@ -1184,7 +1174,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         llama_kv_self_clear(ctx);

         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1472,14 +1462,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     LOG("\ntask\tacc_norm\n");

     const int n_ctx   = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);

     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));

-    common_batch batch(n_ctx, max_seq);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));

     std::vector<float> tok_logits(n_vocab);
     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1499,7 +1487,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch

-        batch.clear();
+        llama_batch_ext_clear(batch.get());

         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -1518,11 +1506,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
             if (int(batch_indeces.size()) != num_answers) {
                 batch_indeces.resize(num_answers);
             }
-            for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
+            for (int s = 0; s < num_answers; ++s) {
+                batch_indeces[s] = s0 + s;
+            }

             for (size_t i = 0; i < cur_task.common_prefix; ++i) {
-                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-                batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
+                llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
             n_logits += 1;
@@ -1532,7 +1521,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    llama_seq_id seq_id = { s0 + s };
+                    llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -1556,7 +1546,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         llama_kv_self_clear(ctx);

         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1743,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);

-        common_batch batch(n_batch, 1);
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));

         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -1757,9 +1747,10 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                 tokens[batch_start] = llama_vocab_bos(vocab);
             }

-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int i = 0; i < batch_size; i++) {
-                batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
+                llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
             }

             if (llama_decode_ext(ctx, batch.get())) {