Commit 8b20858

perplexity : faster Winogrande via batching (#5024)
* perplexity : faster Winogrande via batching

  ggml-ci

* perplexity : remove unused function
* perplexity : only tokenize selected tasks for Winogrande
1 parent 57e2a7a commit 8b20858

File tree: 1 file changed (+158 −125)

examples/perplexity/perplexity.cpp

Lines changed: 158 additions & 125 deletions
```diff
@@ -423,26 +423,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }
 
-static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
-    int n_past, int n_batch, int n_vocab) {
-    std::vector<float> result;
-    result.reserve(tokens.size() * n_vocab);
-    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
-    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
-        size_t n_tokens = tokens.size() - i_chunk * n_batch;
-        n_tokens = std::min(n_tokens, size_t(n_batch));
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return {};
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+        llama_batch batch_view = {
+            n_tokens,
+            batch.token    + i,
+            nullptr,
+            batch.pos      + i,
+            batch.n_seq_id + i,
+            batch.seq_id   + i,
+            batch.logits   + i,
+            0, 0, 0, // unused
+        };
+
+        const int ret = llama_decode(ctx, batch_view);
+        if (ret != 0) {
+            LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+            return false;
         }
 
-        const auto logits = llama_get_logits(ctx);
-        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
-
-        n_past += n_tokens;
+        memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
     }
-    return result;
+
+    return true;
 }
 
 static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
```
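`decode_helper` is now a free function shared by the HellaSwag and Winogrande scorers: it walks one large multi-sequence `llama_batch` in `n_batch`-sized views and copies each view's logits into a caller-provided buffer. A condensed usage sketch, assuming a loaded `llama_context * ctx`, a tokenized sequence `tokens`, and `n_ctx`/`n_batch`/`n_vocab` in scope (the `llama_batch_add`/`llama_batch_clear` helpers come from `common.h`, as used elsewhere in this file):

```cpp
// Sketch: score one token sequence with the new decode_helper.
llama_batch batch = llama_batch_init(n_ctx, 0, 1);   // one sequence
std::vector<float> batch_logits(n_vocab * n_ctx);    // one logit row per batch slot

llama_batch_clear(batch);
for (size_t i = 0; i < tokens.size(); ++i) {
    // request logits for every token so each next-token probability can be read
    llama_batch_add(batch, tokens[i], i, { 0 }, true);
}

llama_kv_cache_clear(ctx); // sequences restart at position 0

if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
    fprintf(stderr, "llama_decode() failed\n");
}

llama_batch_free(batch);
```

Indexing into `batch_logits` is positional: row `i` holds the logits produced by batch token `i`, which is what the scoring loops below rely on. The HellaSwag path is updated accordingly: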
```diff
@@ -576,7 +581,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
         // determine the common prefix of the endings
         hs_cur.common_prefix = 0;
-        hs_cur.required_tokens = 0;
         for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
             if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
                 hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
```
```diff
@@ -609,45 +613,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
-    const int max_tasks_per_batch = params.n_parallel;
+    const int max_tasks_per_batch = 32;
     const int max_seq = 4*max_tasks_per_batch;
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(n_ctx*n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
 
     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
     std::vector<std::thread> workers(std::thread::hardware_concurrency());
 
-    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-                0, 0, 0, // unused
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
-        }
-
-        return true;
-    };
-
     for (size_t i0 = 0; i0 < hs_task_count; i0++) {
         int n_cur = 0;
 
```
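`max_tasks_per_batch` is now a fixed 32 rather than `params.n_parallel`; with four endings per HellaSwag task that allows up to `max_seq = 128` sequences per batch, and `batch_logits` holds one row of `n_vocab` floats per context slot. A worked sizing example (the vocabulary and context sizes are illustrative assumptions, not values from this commit):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    // Illustrative numbers: a 32000-token LLaMA vocabulary, 2048-token context.
    const size_t n_vocab = 32000;
    const size_t n_ctx   = 2048;

    const int max_tasks_per_batch = 32;
    const int max_seq = 4*max_tasks_per_batch; // 4 endings per task -> 128 sequences

    std::vector<float> batch_logits(n_vocab*n_ctx);
    printf("%d sequences, logit buffer: %zu floats (%.0f MiB)\n",
           max_seq, batch_logits.size(), batch_logits.size()*sizeof(float)/1048576.0);
    return 0;
}
```

For these numbers the buffer is 65,536,000 floats, about 250 MiB. The call site then passes the shared buffer explicitly: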
```diff
@@ -696,7 +673,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         llama_kv_cache_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, n_batch)) {
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
             fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
```
```diff
@@ -772,6 +749,13 @@ struct winogrande_entry {
     std::string second;
     std::array<std::string, 2> choices;
     int answer;
+
+    size_t i_batch;
+    size_t common_prefix;
+    size_t required_tokens;
+    size_t n_base1; // number of tokens for context + choice 1
+    size_t n_base2; // number of tokens for context + choice 2
+    std::vector<llama_token> seq_tokens[2];
 };
 
 static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
```
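The new `winogrande_entry` fields cache each task's tokenization up front so it happens once. `required_tokens` is the number of batch positions the task occupies when the two sentences share their common prefix: the prefix once, plus each choice's unique tail. A toy computation with made-up token counts:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // Made-up token counts for one Winogrande task.
    const size_t len0 = 24;          // tokens in first + choices[0] + second
    const size_t len1 = 26;          // tokens in first + choices[1] + second
    const size_t common_prefix = 10; // longest shared leading run of tokens

    // the prefix is decoded once; only the two tails are decoded separately
    const size_t required_tokens = common_prefix
                                 + (len0 - common_prefix)
                                 + (len1 - common_prefix);

    printf("required_tokens = %zu (vs %zu decoding both sentences in full)\n",
           required_tokens, len0 + len1);
    return 0;
}
```

Since the two sentences differ only from the choice word onward, the shared prefix covers roughly the tokens of `task.first`, and the saving grows with the prompt length. The tokenization pass itself sits at the top of the next hunk: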
```diff
@@ -875,115 +859,164 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         data = std::move(selected);
     }
 
+    fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
+
     // This is needed as usual for LLaMA models
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
+    for (auto & task : data) {
+        task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
+        task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
+
+        task.common_prefix = 0;
+        for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
+            if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
+                break;
+            }
+            task.common_prefix++;
+        }
+
+        task.required_tokens = task.common_prefix +
+            task.seq_tokens[0].size() - task.common_prefix +
+            task.seq_tokens[1].size() - task.common_prefix;
+
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
+    }
+
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
+
+    const int max_tasks_per_batch = 128;
+    const int max_seq = 2*max_tasks_per_batch;
+
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
 
     int n_correct = 0;
     int n_done = 0;
 
-    for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
-        const auto& task = data[task_idx];
+    for (size_t i0 = 0; i0 < data.size(); i0++) {
+        int n_cur = 0;
 
-        auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
-        auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
-        auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
+        size_t i1 = i0;
+        size_t i_batch = 0;
 
-        auto sentence_1st = task.first + task.choices[0] + task.second;
-        auto sentence_2nd = task.first + task.choices[1] + task.second;
-        auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
-        auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
+        llama_batch_clear(batch);
 
-        if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
-            return;
-        }
+        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+            const int s0 = 2*(i1 - i0);
+            if (s0 + 2 > max_seq) {
+                break;
+            }
+
+            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
+                llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true;
 
-        auto query_1st_size = query_1st.size();
-        auto query_2nd_size = query_2nd.size();
+            for (int s = 0; s < 2; ++s) {
+                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
 
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        // For Winogrande this seems to slow it down rather than speed it up.
-        //if (query_1st.size() < 32) query_1st.resize(32);
-        //if (query_2nd.size() < 32) query_2nd.resize(32);
+            data[i1].i_batch = i_batch;
+            i_batch += data[i1].required_tokens;
 
-        llama_kv_cache_clear(ctx);
-        auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
+            n_cur += data[i1].required_tokens;
+            if (++i1 == data.size()) {
+                break;
+            }
+        }
+
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
 
         llama_kv_cache_clear(ctx);
-        auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);
 
-        if (logits_1st.empty() || logits_2nd.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
 
```
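The core of the speedup is the `while` loop above: each task gets two sequence ids (`s0`, `s0 + 1`), the common prefix is added to the batch once tagged with both ids so its KV cache entries are shared, and only the last prefix token requests logits (its output predicts the first token of either tail). Condensed to a single task, as an annotated excerpt of that loop (`task` stands for one tokenized `winogrande_entry`):

```cpp
// s0 is this task's first sequence id within the batch
const int s0 = 2*(i1 - i0);

// shared prefix: added once, tagged with both sequence ids, no logits kept...
for (size_t i = 0; i < task.common_prefix; ++i) {
    llama_batch_add(batch, task.seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
// ...except for the last prefix token, whose logits score each tail's first token
batch.logits[batch.n_tokens - 1] = true;

// per-choice tails: each goes to its own sequence; keep all logits for scoring
for (int s = 0; s < 2; ++s) {
    for (size_t i = task.common_prefix; i < task.seq_tokens[s].size(); ++i) {
        llama_batch_add(batch, task.seq_tokens[s][i], i, { s0 + s }, true);
    }
}
```

The hunk continues with the scoring pass that replaces the old per-task `evaluate_tokens` calls: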
```diff
-        bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
-                           query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;
-
-        float score_1st = 0;
-        bool is_nan_1st = false;
-        const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
-        const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
-        for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
-            std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_1st[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_1st.c_str(), base_context.size());
-                is_nan_1st = true;
-                break;
+        for (size_t i = i0; i < i1; ++i) {
+            auto & task = data[i];
+
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
+            float score_1st = 0;
+            bool is_nan_1st = false;
+            const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            size_t li = n_base1 - 1;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
+                const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
+                if (std::isnan(prob) || !prob) {
+                    fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
+                            prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
+                    is_nan_1st = true;
+                    break;
+                }
+                score_1st += std::log(prob);
             }
-            score_1st += std::log(prob);
-        }
-        score_1st /= (query_1st_size - base_1.size() - last_1st);
-
-        float score_2nd = 0;
-        bool is_nan_2nd = false;
-        const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
-        const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
-        for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
-            std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_2nd[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_2nd.c_str(), base_context.size());
-                is_nan_2nd = true;
-                break;
+            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
+
+            float score_2nd = 0;
+            bool is_nan_2nd = false;
+            const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
+                const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
+                if (std::isnan(prob) || !prob) {
+                    fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
+                            prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
+                    is_nan_2nd = true;
+                    break;
+                }
+                score_2nd += std::log(prob);
             }
-            score_2nd += std::log(prob);
-        }
-        score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
+            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
 
-        if (is_nan_1st || is_nan_2nd) {
-            continue;
-        }
+            if (is_nan_1st || is_nan_2nd) {
+                continue;
+            }
 
-        if (std::isnan(score_1st) || std::isnan(score_2nd)) {
-            printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
-            printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
-            printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
-            printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
-            printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
-            continue;
-        }
+            if (std::isnan(score_1st) || std::isnan(score_2nd)) {
+                printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
+                printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
+                printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
+                printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
+                printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
+                continue;
+            }
 
-        int result = score_1st > score_2nd ? 1 : 2;
+            int result = score_1st > score_2nd ? 1 : 2;
+
+            if (result == task.answer) {
+                ++n_correct;
+            }
+            ++n_done;
 
-        if (result == task.answer) {
-            ++n_correct;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
+            fflush(stdout);
         }
-        ++n_done;
 
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
-        fflush(stdout);
+        i0 = i1 - 1;
     }
 
     printf("\n");
```
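Reading the scores back is a matter of mapping sentence positions to rows of `batch_logits`. Each task's slice starts at row `task.i_batch`; within it the shared prefix occupies the first `common_prefix` rows, then choice 0's tail, then choice 1's tail, and the `li` counters above walk these rows. A toy sanity check of that layout, reusing the made-up counts from the earlier example:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
    // Made-up counts, matching the earlier required_tokens example.
    const size_t len0 = 24, len1 = 26, common_prefix = 10;

    // row layout of one task's slice of batch_logits (relative to task.i_batch)
    const size_t tail0_first = common_prefix;                          // after the prefix
    const size_t tail1_first = common_prefix + (len0 - common_prefix); // after tail 0
    const size_t total       = tail1_first + (len1 - common_prefix);   // == required_tokens

    printf("prefix       : rows [0, %zu)\n", common_prefix);
    printf("choice 0 tail: rows [%zu, %zu)\n", tail0_first, tail1_first);
    printf("choice 1 tail: rows [%zu, %zu) of %zu total\n", tail1_first, total, total);
    return 0;
}
```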
