@@ -423,26 +423,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     return {tokens, ppl, logit_history, prob_history};
 }
 
-static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
-    int n_past, int n_batch, int n_vocab) {
-    std::vector<float> result;
-    result.reserve(tokens.size() * n_vocab);
-    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
-    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
-        size_t n_tokens = tokens.size() - i_chunk * n_batch;
-        n_tokens = std::min(n_tokens, size_t(n_batch));
-        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
-        if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return {};
+static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+        llama_batch batch_view = {
+            n_tokens,
+            batch.token    + i,
+            nullptr,
+            batch.pos      + i,
+            batch.n_seq_id + i,
+            batch.seq_id   + i,
+            batch.logits   + i,
+            0, 0, 0, // unused
+        };
+
+        const int ret = llama_decode(ctx, batch_view);
+        if (ret != 0) {
+            LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+            return false;
         }
 
-        const auto logits = llama_get_logits(ctx);
-        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
-
-        n_past += n_tokens;
+        memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
     }
-    return result;
+
+    return true;
 }
 
 static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
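The batching helper above replaces the sequential `evaluate_tokens`: instead of clearing the KV cache chunk by chunk, it walks one packed `llama_batch` in `n_batch`-sized views, decodes each view, and copies the logits into a flat `batch_logits` buffer indexed by the token's position in the batch. It is hoisted to file scope so that both `hellaswag_score` and `winogrande_score` can share it. A minimal sketch of the calling pattern it assumes (hypothetical, not part of this commit; it relies on the `llama_batch_init`/`llama_batch_clear`/`llama_batch_add` helpers and on `batch_logits` being sized to at least `tokens.size()*n_vocab`):

```cpp
// Hypothetical caller: decode one sequence through decode_helper.
// Requesting logits for every token keeps row i of batch_logits valid.
static bool eval_one_sequence(llama_context * ctx, const std::vector<llama_token> & tokens,
        int32_t n_batch, int32_t n_vocab, std::vector<float> & batch_logits) {
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    llama_batch_clear(batch);

    for (size_t i = 0; i < tokens.size(); ++i) {
        llama_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, true); // sequence id 0
    }

    llama_kv_cache_clear(ctx);

    const bool ok = decode_helper(ctx, batch, batch_logits, n_batch, n_vocab);

    llama_batch_free(batch);
    // on success, the logits for token i start at batch_logits.data() + i*n_vocab
    return ok;
}
```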
@@ -576,7 +581,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
         // determine the common prefix of the endings
         hs_cur.common_prefix = 0;
-        hs_cur.required_tokens = 0;
         for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
             if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
                 hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
@@ -609,45 +613,18 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     const int n_ctx   = llama_n_ctx(ctx);
     const int n_batch = params.n_batch;
 
-    const int max_tasks_per_batch = params.n_parallel;
+    const int max_tasks_per_batch = 32;
     const int max_seq = 4*max_tasks_per_batch;
 
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
-    std::vector<float> batch_logits(n_ctx*n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
 
     std::vector<std::pair<size_t, llama_token>> eval_pairs;
     std::vector<float> eval_results;
     std::vector<std::thread> workers(std::thread::hardware_concurrency());
 
-    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-                0, 0, 0, // unused
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
-        }
-
-        return true;
-    };
-
     for (size_t i0 = 0; i0 < hs_task_count; i0++) {
         int n_cur = 0;
 
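For reference, the aggregate initializer used for `batch_view` (in the former lambda above and in the new file-scope `decode_helper`) follows the `llama_batch` layout of this revision of `llama.h`, reproduced below with comments abbreviated; the trailing `0, 0, 0` fill the legacy `all_pos_0`/`all_pos_1`/`all_seq_id` helper fields, which are ignored when explicit `pos`/`seq_id` arrays are supplied:

```cpp
// llama.h (this revision), abbreviated
typedef struct llama_batch {
    int32_t n_tokens;

    llama_token  *  token;
    float        *  embd;
    llama_pos    *  pos;
    int32_t      *  n_seq_id;
    llama_seq_id ** seq_id;
    int8_t       *  logits;

    // legacy helpers, used only when pos == NULL / seq_id == NULL
    llama_pos    all_pos_0;
    llama_pos    all_pos_1;
    llama_seq_id all_seq_id;
} llama_batch;
```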
@@ -696,7 +673,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         llama_kv_cache_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (!decode_helper(ctx, batch, n_batch)) {
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
             fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -772,6 +749,13 @@ struct winogrande_entry {
     std::string second;
     std::array<std::string, 2> choices;
     int answer;
+
+    size_t i_batch;
+    size_t common_prefix;
+    size_t required_tokens;
+    size_t n_base1; // number of tokens for context + choice 1
+    size_t n_base2; // number of tokens for context + choice 2
+    std::vector<llama_token> seq_tokens[2];
 };
 
 static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
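The new fields cache per-task tokenization results so that `winogrande_score` can pack many tasks into a single batch: both full sentences as tokens, the length of their shared prefix, and the total token budget (`required_tokens`) a task consumes. The underlying arithmetic, factored into a standalone sketch (names hypothetical; the commit inlines this logic in `winogrande_score` below):

```cpp
// Length of the shared token prefix of two tokenized sentences.
static size_t count_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t n = 0;
    while (n < a.size() && n < b.size() && a[n] == b[n]) {
        ++n;
    }
    return n;
}

// Tokens a task adds to the batch: the shared prefix once, plus the distinct
// tail of each of the two endings:
//     required_tokens = prefix + (len0 - prefix) + (len1 - prefix)
```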
@@ -875,115 +859,164 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
         data = std::move(selected);
     }
 
+    fprintf(stderr, "%s : tokenizing selected tasks\n", __func__);
+
     // This is needed as usual for LLaMA models
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
 
+    for (auto & task : data) {
+        task.seq_tokens[0] = ::llama_tokenize(ctx, task.first + task.choices[0] + task.second, add_bos);
+        task.seq_tokens[1] = ::llama_tokenize(ctx, task.first + task.choices[1] + task.second, add_bos);
+
+        task.common_prefix = 0;
+        for (size_t k = 0; k < task.seq_tokens[0].size(); k++) {
+            if (task.seq_tokens[0][k] != task.seq_tokens[1][k]) {
+                break;
+            }
+            task.common_prefix++;
+        }
+
+        task.required_tokens = task.common_prefix +
+            task.seq_tokens[0].size() - task.common_prefix +
+            task.seq_tokens[1].size() - task.common_prefix;
+
+        task.n_base1 = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos).size();
+        task.n_base2 = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos).size();
+    }
+
     fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx   = llama_n_ctx(ctx);
+    const int n_batch = params.n_batch;
+
+    const int max_tasks_per_batch = 128;
+    const int max_seq = 2*max_tasks_per_batch;
+
+    llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_vocab*n_ctx);
 
     int n_correct = 0;
     int n_done    = 0;
 
-    for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
-        const auto & task = data[task_idx];
+    for (size_t i0 = 0; i0 < data.size(); i0++) {
+        int n_cur = 0;
 
-        auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
-        auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
-        auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
+        size_t i1 = i0;
+        size_t i_batch = 0;
 
-        auto sentence_1st = task.first + task.choices[0] + task.second;
-        auto sentence_2nd = task.first + task.choices[1] + task.second;
-        auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
-        auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
+        llama_batch_clear(batch);
 
-        if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
-            fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
-            return;
-        }
+        while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+            const int s0 = 2*(i1 - i0);
+            if (s0 + 2 > max_seq) {
+                break;
+            }
+
+            for (size_t i = 0; i < data[i1].common_prefix; ++i) {
+                llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
+            }
+            batch.logits[batch.n_tokens - 1] = true;
 
-        auto query_1st_size = query_1st.size();
-        auto query_2nd_size = query_2nd.size();
+            for (int s = 0; s < 2; ++s) {
+                for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
+                    llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+                }
+            }
 
-        // Speedup small evaluations by evaluating atleast 32 tokens
-        // For Winogrande this seems to slow it down rather than speed it up.
-        //if (query_1st.size() < 32) query_1st.resize(32);
-        //if (query_2nd.size() < 32) query_2nd.resize(32);
+            data[i1].i_batch = i_batch;
+            i_batch += data[i1].required_tokens;
 
-        llama_kv_cache_clear(ctx);
-        auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
+            n_cur += data[i1].required_tokens;
+            if (++i1 == data.size()) {
+                break;
+            }
+        }
+
+        if (i0 == i1) {
+            fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
+            return;
+        }
 
         llama_kv_cache_clear(ctx);
-        auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);
 
-        if (logits_1st.empty() || logits_2nd.empty()) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+        // decode all tasks [i0, i1)
+        if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
+            fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
 
-        bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
-                           query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;
-
-        float score_1st = 0;
-        bool is_nan_1st = false;
-        const auto & base_1 = skip_choice ? base_ctx_1st : base_context;
-        const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
-        for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
-            std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_1st[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_1st.c_str(), base_context.size());
-                is_nan_1st = true;
-                break;
+        for (size_t i = i0; i < i1; ++i) {
+            auto & task = data[i];
+
+            const bool skip_choice =
+                task.seq_tokens[0].size() - task.common_prefix > k_min_trailing_ctx &&
+                task.seq_tokens[1].size() - task.common_prefix > k_min_trailing_ctx;
+
+            float score_1st = 0;
+            bool is_nan_1st = false;
+            const auto & n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
+            const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
+            size_t li = n_base1 - 1;
+            for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
+                const float prob = softmax(tok_logits)[task.seq_tokens[0][j+1]];
+                if (std::isnan(prob) || !prob) {
+                    fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
+                            prob, j, (task.first + task.choices[0] + task.second).c_str(), n_base1);
+                    is_nan_1st = true;
+                    break;
+                }
+                score_1st += std::log(prob);
             }
-            score_1st += std::log(prob);
-        }
-        score_1st /= (query_1st_size - base_1.size() - last_1st);
-
-        float score_2nd = 0;
-        bool is_nan_2nd = false;
-        const auto & base_2 = skip_choice ? base_ctx_2nd : base_context;
-        const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
-        for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
-            std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
-            const float prob = softmax(tok_logits)[query_2nd[j+1]];
-            if (std::isnan(prob) || !prob) {
-                fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
-                        prob, j, sentence_2nd.c_str(), base_context.size());
-                is_nan_2nd = true;
-                break;
+            score_1st /= (task.seq_tokens[0].size() - n_base1 - last_1st);
+
+            float score_2nd = 0;
+            bool is_nan_2nd = false;
+            const auto & n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
+            const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
+            li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+            for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(task.i_batch + li++), n_vocab*sizeof(float));
+                const float prob = softmax(tok_logits)[task.seq_tokens[1][j+1]];
+                if (std::isnan(prob) || !prob) {
+                    fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
+                            prob, j, (task.first + task.choices[1] + task.second).c_str(), n_base2);
+                    is_nan_2nd = true;
+                    break;
+                }
+                score_2nd += std::log(prob);
             }
-            score_2nd += std::log(prob);
-        }
-        score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
+            score_2nd /= (task.seq_tokens[1].size() - n_base2 - last_2nd);
 
-        if (is_nan_1st || is_nan_2nd) {
-            continue;
-        }
+            if (is_nan_1st || is_nan_2nd) {
+                continue;
+            }
 
-        if (std::isnan(score_1st) || std::isnan(score_2nd)) {
-            printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
-            printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
-            printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
-            printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
-            printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
-            continue;
-        }
+            if (std::isnan(score_1st) || std::isnan(score_2nd)) {
+                printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
+                printf("Q1: <%s> - %zu tokens\n", (task.first + task.choices[0] + task.second).c_str(), task.seq_tokens[0].size());
+                printf("Q2: <%s> - %zu tokens\n", (task.first + task.choices[1] + task.second).c_str(), task.seq_tokens[1].size());
+                printf("B : <%s> - %zu tokens\n", task.first.c_str(), task.common_prefix);
+                printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", n_base1, n_base2, skip_choice);
+                continue;
+            }
 
-        int result = score_1st > score_2nd ? 1 : 2;
+            int result = score_1st > score_2nd ? 1 : 2;
+
+            if (result == task.answer) {
+                ++n_correct;
+            }
+            ++n_done;
 
-        if (result == task.answer) {
-            ++n_correct;
+            // Print the accumulated accuracy mean x 100
+            printf("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n", i+1, 100.0 * n_correct/n_done, score_1st, score_2nd, result, task.answer);
+            fflush(stdout);
         }
-        ++n_done;
 
-        // Print the accumulated accuracy mean x 100
-        printf("%zu\t%.4lf\t%10.6f  %10.6f  %d  %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
-        fflush(stdout);
+        i0 = i1 - 1;
     }
 
     printf("\n");