Commit 76fd7d6

perplexity : avoid common_batch
ggml-ci
1 parent 8b80d68 commit 76fd7d6

File tree

1 file changed: examples/perplexity/perplexity.cpp (41 additions, 50 deletions)
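In short, the commit drops the common_batch C++ wrapper from the perplexity example and drives the llama_batch_ext C API directly. Below is a minimal sketch of the batch lifecycle the diff converges on, assuming the llama_batch_ext functions and the llama_batch_ext_ptr RAII wrapper from this branch; eval_tokens_sketch is a hypothetical helper for illustration, not code from the commit:

// Sketch only, not part of the commit.
// llama_batch_ext_ptr is assumed to come from the branch's C++ helpers.
#include "llama.h"

static bool eval_tokens_sketch(llama_context * ctx, const llama_token * tokens, int n_tokens) {
    // room for n_tokens tokens, 1 seq id per token (mirrors the init calls in this diff)
    llama_batch_ext_ptr batch(llama_batch_ext_init(n_tokens, 1));

    llama_batch_ext_clear(batch.get());
    for (int i = 0; i < n_tokens; i++) {
        // seq ids are now passed as a (pointer, count) pair instead of an initializer list
        llama_seq_id seq_id = 0;
        llama_batch_ext_add_text(batch.get(), tokens[i], i, &seq_id, 1, /*output=*/ i == n_tokens - 1);
    }

    return llama_decode_ext(ctx, batch.get()) == 0;
}
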
@@ -363,15 +363,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        common_batch batch(n_batch, 1);
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            batch.clear();
+            llama_batch_ext_clear(batch.get());
             for (int i = 0; i < batch_size; i++) {
-                batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true);
+                llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
             }
 
            //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
@@ -501,7 +502,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
     GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
 
-    common_batch batch(std::min(n_batch, n_ctx*n_seq), 1);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));
 
     std::vector<float> logits;
     if (num_batches > 1) {
@@ -552,7 +553,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
 
         int n_outputs = 0;
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
         for (int seq = 0; seq < n_seq_batch; seq++) {
             int seq_start = batch_start + seq*n_ctx;
 
@@ -567,7 +568,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
             for (int k = 0; k < batch_size; ++k) {
                 const llama_pos pos = j*n_batch + k;
                 bool output = pos >= first;
-                batch.add_text(tokens[seq_start + k], pos, seq, output);
+                llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output);
 
                 n_outputs += output ? 1 : 0;
             }
@@ -649,26 +650,15 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     return {tokens, ppl, logit_history, prob_history};
 }
 
-static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector<float> & batch_logits, int n_batch, int n_vocab) {
-    int prev_outputs = 0;
-    for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) {
-        const int n_tokens = std::min<int>(n_batch, batch.get_n_tokens() - i);
-
-        common_batch batch_view = batch.get_view(i, n_tokens);
-
-        const int ret = llama_decode_ext(ctx, batch_view.get());
-        if (ret != 0) {
-            LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-            return false;
-        }
-
-        int n_outputs = batch_view.n_outputs;
-
-        memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
-
-        prev_outputs += n_outputs;
+static bool decode_helper(llama_context * ctx, llama_batch_ext_ptr & batch, std::vector<float> & batch_logits, size_t n_outputs, int n_vocab) {
+    const int ret = llama_decode_ext(ctx, batch.get());
+    if (ret != 0) {
+        LOG_ERR("failed to decode the batch, ret = %d\n", ret);
+        return false;
     }
 
+    memcpy(batch_logits.data(), llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
     return true;
 }
 
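Note the behavioral change here: decode_helper no longer splits the batch into n_batch-sized views and accumulates logits view by view; it decodes the whole batch with a single llama_decode_ext call. Callers therefore must build batches that fit in one decode, and they now pass the number of tokens they flagged for output (i_logits at the call sites below) so that exactly n_outputs*n_vocab floats are copied out — which is also why the scoring functions drop their local n_batch. A hedged sketch of the new call-site contract, using names from the surrounding hunks:

// sketch of the new call-site contract (names taken from this diff):
// i_logits counts the tokens added with output == true while filling `batch`,
// so decode_helper copies exactly i_logits * n_vocab logits into batch_logits
std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
    LOG_ERR("%s: llama_decode() failed\n", __func__);
    return;
}
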
@@ -836,14 +826,12 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
     double acc = 0.0f;
 
     const int n_ctx = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    common_batch batch(n_ctx, 4);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -859,7 +847,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -875,7 +863,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
             }
 
             for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
-                batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
+                std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
+                llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
             }
             llama_batch_ext_set_output_last(batch.get());
             n_logits += 1;
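The multi-sequence variant follows the same pattern: where add_text_multi_seq took an initializer list of sequence ids, llama_batch_ext_add_text takes a pointer and a count. A sketch of the idea, with s0 and batch taken from the hunk above and prefix_tokens/common_prefix as illustrative stand-ins:

// each prefix token is added once but tagged with all 4 seq ids, so the 4
// endings share the prompt prefix in the KV cache instead of re-evaluating it
std::vector<llama_seq_id> seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 };
for (size_t i = 0; i < common_prefix; ++i) {
    llama_batch_ext_add_text(batch.get(), prefix_tokens[i], i, seq_ids.data(), seq_ids.size(), false);
}
llama_batch_ext_set_output_last(batch.get()); // request logits only for the last prefix token
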
@@ -885,7 +874,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
                 // TODO: don't evaluate the last token of each sequence
                 for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
                     const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    llama_seq_id seq_id = s0 + s;
+                    llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
                     n_logits += needs_logits;
                 }
             }
@@ -907,7 +897,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
         llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1118,14 +1108,12 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
     LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
 
     const int n_ctx = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    common_batch batch(n_ctx, 2);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1144,7 +1132,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         size_t i1 = i0;
         size_t i_logits = 0;
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
             int n_logits = 0;
@@ -1154,15 +1142,17 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
            }
 
            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
-                batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+                std::vector<llama_seq_id> seq_ids{ s0 + 0, s0 + 1 };
+                llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false);
            }
            llama_batch_ext_set_output_last(batch.get());
            n_logits += 1;
 
            for (int s = 0; s < 2; ++s) {
                // TODO: end before the last token, no need to predict past the end of the sequences
                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
-                    batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                    llama_seq_id seq_id = s0 + s;
+                    llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true);
                    n_logits += 1;
                }
            }
@@ -1184,7 +1174,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
         llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1472,14 +1462,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     LOG("\ntask\tacc_norm\n");
 
     const int n_ctx = llama_n_ctx(ctx);
-    const int n_batch = params.n_batch;
-
     const int n_vocab = llama_vocab_n_tokens(vocab);
 
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    common_batch batch(n_ctx, max_seq);
+    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));
 
     std::vector<float> tok_logits(n_vocab);
     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1499,7 +1487,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         size_t i1 = i0;
         size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
 
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         // batch as much tasks as possible into the available context
         // each task has 4 unique sequence ids - one for each ending
@@ -1518,11 +1506,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
            if (int(batch_indeces.size()) != num_answers) {
                batch_indeces.resize(num_answers);
            }
-            for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s;
+            for (int s = 0; s < num_answers; ++s) {
+                batch_indeces[s] = s0 + s;
+            }
 
            for (size_t i = 0; i < cur_task.common_prefix; ++i) {
-                //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
-                batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false);
+                llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false);
            }
            llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix
            n_logits += 1;
@@ -1532,7 +1521,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
                // TODO: don't evaluate the last token of each sequence
                for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
                    const bool needs_logits = i < seq_tokens_size - 1;
-                    batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+                    llama_seq_id seq_id = { s0 + s };
+                    llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits);
                    n_logits += needs_logits;
                }
            }
@@ -1556,7 +1546,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
         llama_kv_self_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+        if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) {
             LOG_ERR("%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -1743,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        common_batch batch(n_batch, 1);
+        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -1757,9 +1747,10 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                tokens[batch_start] = llama_vocab_bos(vocab);
            }
 
-            batch.clear();
+            llama_batch_ext_clear(batch.get());
            for (int i = 0; i < batch_size; i++) {
-                batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true);
+                llama_seq_id seq_id = 0;
+                llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true);
            }
 
            if (llama_decode_ext(ctx, batch.get())) {
