#include "build-info.h"

#include <cmath>
+ #include <cstdio>
+ #include <cstring>
#include <ctime>
#include <sstream>
- #include <cstring>
#include <thread>
#include <mutex>
+ #include <tuple>
+ #include <utility>
+ #include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -29,20 +33,20 @@ std::vector<float> softmax(const std::vector<float>& logits) {
    return probs;
}

- float log_softmax(int n_vocab, const float * logits, int tok) {
+ std::tuple<double, float, float> log_softmax(int n_vocab, const float * logits, int tok) {
    float max_logit = logits[0];
    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
    double sum_exp = 0.0;
    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
-     return logits[tok] - max_logit - log(sum_exp);
+     return std::make_tuple(-(logits[tok] - max_logit - log(sum_exp)), logits[tok], expf(logits[tok] - max_logit) / sum_exp);
}
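Note on the change above: log_softmax now returns the negative log-likelihood, the raw logit, and the softmax probability of the target token in a single pass, using the numerically stable identity log p(tok) = logit[tok] - max - log(sum_i exp(logit_i - max)). A self-contained sketch with made-up logits (illustration only, not part of the patch):

// Sketch: the three values packed into the tuple, computed the same way.
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    const float logits[3] = {1.0f, 2.0f, 0.5f};
    const int   tok       = 1;

    float max_logit = logits[0];
    for (int i = 1; i < 3; ++i) max_logit = std::max(max_logit, logits[i]);

    double sum_exp = 0.0;
    for (int i = 0; i < 3; ++i) sum_exp += std::exp(logits[i] - max_logit);

    const double nll  = -(logits[tok] - max_logit - std::log(sum_exp)); // std::get<0>
    const float  lgt  = logits[tok];                                    // std::get<1>
    const double prob = std::exp(logits[tok] - max_logit) / sum_exp;    // std::get<2>

    // prob and exp(-nll) are the same quantity computed two ways
    printf("logit=%.3f nll=%.6f prob=%.6f exp(-nll)=%.6f\n", lgt, nll, prob, std::exp(-nll));
    return 0;
}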

- void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
-                     double & nll, double & nll2) {
+ void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+                     double & nll, double & nll2, float * logit_history, float * prob_history) {

    std::mutex mutex;
    int counter = 0;
-     auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
+     auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
        double local_nll = 0, local_nll2 = 0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
@@ -52,34 +56,44 @@ void process_logits(int n_vocab, const float * logits, const int * tokens, int n
                break;
            }
            lock.unlock();
-             double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
-             local_nll += v;
-             local_nll2 += v*v;
+             const std::tuple<double, float, float> v = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+             const double v0 = std::get<0>(v);
+             local_nll += v0;
+             local_nll2 += v0*v0;
+
+             logit_history[i] = std::get<1>(v);
+             prob_history[i]  = std::get<2>(v);
        }
    };
-     for (auto & w : workers) w = std::thread(compute);
+     for (auto & w : workers) w = std::thread(compute);
    compute();
-     for (auto & w : workers) w.join();
+     for (auto & w : workers) w.join();

}
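Note on process_logits above: a mutex-guarded counter hands each token index to exactly one thread, each thread accumulates local nll/nll2 sums and folds them into the shared totals only when it runs out of work (still holding the lock), and the per-index writes to logit_history/prob_history need no extra synchronization. A minimal self-contained sketch of the same work-distribution pattern (illustration only; the toy workload and worker count are assumptions, not taken from the patch):

// Sketch: shared counter + per-thread local sums, as in process_logits.
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

int main() {
    const int n = 1000;
    double total = 0.0;
    std::mutex mutex;
    int counter = 0;

    auto compute = [&]() {
        double local = 0.0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            const int i = counter++;
            if (i >= n) {
                total += local;   // merge while still holding the lock
                break;
            }
            lock.unlock();
            local += i;           // per-index work happens outside the lock
        }
    };

    std::vector<std::thread> workers(3);
    for (auto & w : workers) w = std::thread(compute);
    compute();
    for (auto & w : workers) w.join();

    printf("total = %.0f (expected %.0f)\n", total, 0.5 * n * (n - 1));
    return 0;
}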

- void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+ std::tuple<std::vector<llama_token>, std::vector<float>, std::vector<float>, float>
+ perplexity_v2(llama_context * ctx, const gpt_params & params) {
    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
    // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
    // Output: `perplexity: 13.5106 [114/114]`
    // BOS tokens will be added for each chunk before eval

-     if (params.ppl_stride <= 0) {
-         fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
-         return;
-     }
-
    const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
    const bool add_bos = is_spm;

    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

-     auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+     std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+     std::vector<float> logit_history;
+     std::vector<float> prob_history;
+
+     logit_history.resize(tokens.size());
+     prob_history.resize(tokens.size());
+
+     if (params.ppl_stride <= 0) {
+         fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
+         return std::make_tuple(tokens, logit_history, prob_history, -1);
+     }

    const int calc_chunk = params.n_ctx;

@@ -88,7 +102,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
    if (int(tokens.size()) <= calc_chunk) {
        fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                tokens.size(), params.n_ctx, params.ppl_stride);
-         return;
+         return std::make_tuple(tokens, logit_history, prob_history, -1);
    }

    const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
@@ -120,7 +134,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
            // fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                // fprintf(stderr, "%s : failed to eval\n", __func__);
-                 return;
+                 return std::make_tuple(tokens, logit_history, prob_history, -1);
            }

            // save original token and restore it after eval
@@ -161,6 +175,8 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
                logits.begin() + (j + 1) * n_vocab);

            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+             logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
+             prob_history[start + j + 1]  = prob;

            nll += -std::log(prob);
            ++count;
@@ -174,12 +190,15 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
        fflush(stdout);
    }
    printf("\n");
+
+     return std::make_tuple(tokens, logit_history, prob_history, std::exp(nll / count));
}

- void perplexity(llama_context * ctx, const gpt_params & params) {
+ std::tuple<std::vector<llama_token>, std::vector<float>, std::vector<float>, float>
+ perplexity(llama_context * ctx, const gpt_params & params) {
+
    if (params.ppl_stride > 0) {
-         perplexity_v2(ctx, params);
-         return;
+         return perplexity_v2(ctx, params);
    }

    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
@@ -193,11 +212,17 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
    auto tim1 = std::chrono::high_resolution_clock::now();
    fprintf(stderr, "%s: tokenizing the input ..\n", __func__);

-     auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+     std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);

    auto tim2 = std::chrono::high_resolution_clock::now();
    fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());

+     std::vector<float> logit_history;
+     logit_history.resize(tokens.size());
+
+     std::vector<float> prob_history;
+     prob_history.resize(tokens.size());
+
    const int n_chunk_max = tokens.size() / params.n_ctx;

    const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
@@ -236,7 +261,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {

            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                fprintf(stderr, "%s : failed to eval\n", __func__);
-                 return;
+                 return std::make_tuple(tokens, logit_history, prob_history, -1);
            }

            // restore the original token in case it was set to BOS
@@ -272,7 +297,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
        // last 256 tokens. Then, we split the input up into context window size chunks to
        // process the entire prompt.
        const int first = std::min(512, params.n_ctx/2);
-         process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
+         process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+                        workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
        count += params.n_ctx - first - 1;

        // perplexity is e^(average negative log-likelihood)
@@ -287,16 +313,19 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
        fflush(stdout);
    }
    printf("\n");
+
    nll2 /= count;
    nll /= count;
+     const double ppl = exp(nll);
    nll2 -= nll * nll;
    if (nll2 > 0) {
        nll2 = sqrt(nll2/(count-1));
-         double ppl = exp(nll);
        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
    } else {
        printf("Unexpected negative standard deviation of log(prob)\n");
    }
+
+     return std::make_tuple(tokens, logit_history, prob_history, ppl);
}
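For reference (my summary, not text from the patch): with v = -log p(token), the code reports PPL = exp(mean(v)) and an uncertainty of nll2*ppl, where nll2 ends up holding sqrt((mean(v^2) - mean(v)^2)/(count-1)), the standard error of the mean log-loss; multiplying by ppl propagates that error through exp, since d(exp(x))/dx = exp(x). A self-contained sketch of the same arithmetic on made-up sums:

// Sketch: how the final "PPL = ... +/- ..." numbers fall out of the sums.
#include <cmath>
#include <cstdio>

int main() {
    // pretend these were accumulated over `count` tokens by process_logits
    double nll  = 2000.0;   // sum of -log p(token)
    double nll2 = 4200.0;   // sum of (-log p(token))^2
    const int count = 1000;

    nll2 /= count;
    nll  /= count;
    const double ppl = std::exp(nll);
    nll2 -= nll * nll;                          // variance of the per-token log-loss
    if (nll2 > 0) {
        nll2 = std::sqrt(nll2 / (count - 1));   // standard error of the mean
        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2 * ppl);
    }
    return 0;
}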

std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
@@ -604,13 +633,56 @@ int main(int argc, char ** argv) {
            params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
    }

+     std::vector<llama_token> tokens;
+     std::vector<float> logits;
+     std::vector<float> probs;
+     double perplexity_value = -1;
    if (params.hellaswag) {
        hellaswag_score(ctx, params);
    } else {
-         perplexity(ctx, params);
+         auto ret = perplexity(ctx, params);
+         tokens = std::get<0>(ret);
+         logits = std::get<1>(ret);
+         probs  = std::get<2>(ret);
+         perplexity_value = std::get<3>(ret);
    }

    llama_print_timings(ctx);
+
+     if (params.hellaswag && !params.logdir.empty()) {
+         fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
+     }
+
+     if (!params.hellaswag && !params.logdir.empty()) {
+         const std::string timestamp = get_sortable_timestamp();
+
+         const bool success = create_directory_with_parents(params.logdir);
+         if (success) {
+
+             FILE * logfile = fopen((params.logdir + timestamp + ".yml").c_str(), "w");
+             fprintf(logfile, "binary: perplexity\n");
+             char model_type[128];
+             llama_model_desc(model, model_type, sizeof(model_type));
+             dump_non_result_info_yaml(logfile, params, ctx, timestamp, tokens, model_type);
+
+             fprintf(logfile, "\n");
+             fprintf(logfile, "######################\n");
+             fprintf(logfile, "# Perplexity Results #\n");
+             fprintf(logfile, "######################\n");
+             fprintf(logfile, "\n");
+
+             dump_vector_float_yaml(logfile, "logits", logits);
+             fprintf(logfile, "ppl_value: %f\n", perplexity_value);
+             dump_vector_float_yaml(logfile, "probs", probs);
+
+             llama_dump_timing_info_yaml(logfile, ctx);
+             fclose(logfile);
+         } else {
+             fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                     __func__, params.logdir.c_str());
+         }
+     }
+
    llama_free(ctx);
    llama_free_model(model);

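A side note on the std::get<N>(ret) unpacking in main above: it keeps the code buildable as C++11 (an assumption about the project's language level, not something stated in the patch); with C++17 the same tuple could be unpacked with structured bindings. A toy, self-contained comparison:

// Sketch: std::get<N> (as in the patch) vs. C++17 structured bindings.
#include <cstdio>
#include <tuple>
#include <vector>

int main() {
    std::tuple<std::vector<int>, std::vector<float>, std::vector<float>, float> ret =
        std::make_tuple(std::vector<int>{1, 2, 3}, std::vector<float>{0.5f}, std::vector<float>{0.9f}, 7.39f);

    const float ppl_a = std::get<3>(ret);              // C++11 style, as in the patch
    const auto & [tokens, logits, probs, ppl_b] = ret; // C++17 (requires -std=c++17)

    printf("ppl_a=%f ppl_b=%f tokens=%zu logits=%zu probs=%zu\n",
           ppl_a, ppl_b, tokens.size(), logits.size(), probs.size());
    return 0;
}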