@@ -560,7 +560,8 @@ bool llama_eval(
        const int n_past,
        const std::vector<llama_vocab::id> & embd_inp,
        std::vector<float> & embd_w,
-        size_t & mem_per_token) {
+        size_t & mem_per_token,
+        bool return_all_logits = false) {
    const int N = embd_inp.size();

    const auto & hparams = model.hparams;
@@ -578,7 +579,7 @@ bool llama_eval(
    static void * buf = malloc(buf_size);

    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
        // fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

        // reallocate
@@ -764,9 +765,14 @@ bool llama_eval(
    // embd_w.resize(n_vocab*N);
    // memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
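
With `return_all_logits` set, the caller receives N * n_vocab floats in `embd_w`, laid out row-major: the logits for input position j occupy the range [j*n_vocab, (j+1)*n_vocab). A minimal sketch of slicing out one position under that layout assumption (the helper name is illustrative, not part of the patch; the perplexity code below does the same thing inline when building `tok_logits`):

    // Illustrative only: assumes logits holds N * n_vocab floats, one contiguous row per input position.
    #include <vector>

    static std::vector<float> logits_for_position(const std::vector<float> & logits, int n_vocab, int j) {
        return std::vector<float>(logits.begin() + j * n_vocab,
                                  logits.begin() + (j + 1) * n_vocab);
    }
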
@@ -778,6 +784,76 @@ bool llama_eval(
    return true;
}

+std::vector<double> softmax(const std::vector<float>& logits) {
+    std::vector<double> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) max_logit = std::max(max_logit, v);
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        float logit = logits[i] - max_logit;
+        double exp_logit = std::exp(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    return probs;
+}
+
+void perplexity(const llama_vocab &vocab, const llama_model &model, const gpt_params &params, size_t mem_per_token) {
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    std::vector<llama_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
+
+    int count = 0;
+    double nll = 0.0;
+    int seq_count = tokens.size() / params.n_ctx;
+    printf("Calculating perplexity over %d chunks\n", seq_count);
+    for (int i = 0; i < seq_count; ++i) {
+        int start = i * params.n_ctx;
+        int end = start + params.n_ctx - 1;
+        std::vector<llama_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
+        std::vector<float> logits;
+        auto start_t = std::chrono::high_resolution_clock::now();
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
+            fprintf(stderr, "Failed to predict\n");
+            return;
+        }
+        auto end_t = std::chrono::high_resolution_clock::now();
+        if (i == 0) {
+            double seconds = std::chrono::duration<double>(end_t - start_t).count();
+            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
+        }
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example: with a context window of 512, we compute perplexity for each of the
+        // last 256 tokens. Then, we split the input up into context-window-sized chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
static bool is_interacting = false;

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
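
The comment block in `perplexity()` above describes the method; the quantity it prints is the exponential of the average negative log-likelihood of the predicted tokens. A standalone sketch of just that formula, with made-up probabilities standing in for the per-token values the patch obtains from `softmax(tok_logits)[tokens[start + j + 1]]`:

    // Illustrative only: perplexity = exp(average negative log-likelihood).
    // The probabilities below are made-up stand-ins for softmax(tok_logits)[next_token].
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<double> p_next = { 0.12, 0.40, 0.05, 0.22 };
        double nll = 0.0;
        for (double p : p_next) nll += -std::log(p);
        printf("perplexity: %.4f\n", std::exp(nll / p_next.size()));  // lower is better
        return 0;
    }
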
@@ -868,13 +944,22 @@ int main(int argc, char ** argv) {
            params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
    }

+    std::vector<float> logits;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    if (params.perplexity) {
+        perplexity(vocab, model, params, mem_per_token);
+        exit(0);
+    }
+
    int n_past = 0;

    int64_t t_sample_us = 0;
    int64_t t_predict_us = 0;

-    std::vector<float> logits;
-
    // Add a space in front of the first character to match OG llama tokenizer behavior
    params.prompt.insert(0, 1, ' ');
    // tokenize the prompt
@@ -928,10 +1013,6 @@ int main(int argc, char ** argv) {

    std::vector<llama_vocab::id> embd;

-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
    int last_n_size = params.repeat_last_n;
    std::vector<llama_vocab::id> last_n_tokens(last_n_size);
    std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);