@@ -28,7 +28,6 @@ std::vector<float> softmax(const std::vector<float>& logits) {
2828}
2929
3030void perplexity_v2 (llama_context * ctx, const gpt_params & params) {
31-
3231 // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
3332 // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
3433 // Output: `perplexity: 13.5106 [114/114]`
@@ -38,7 +37,13 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
3837 fprintf (stderr, " %s: stride is %d but must be greater than zero!\n " ,__func__,params.ppl_stride );
3938 return ;
4039 }
41- auto tokens = ::llama_tokenize (ctx, params.prompt , true );
40+
41+ const bool is_spm = llama_vocab_type (ctx) == LLAMA_VOCAB_TYPE_SPM;
42+ const bool add_bos = is_spm;
43+
44+ fprintf (stderr, " %s: tokenizing the input ..\n " , __func__);
45+
46+ auto tokens = ::llama_tokenize (ctx, params.prompt , add_bos);
4247
4348 const int calc_chunk = params.n_ctx ;
4449
@@ -86,7 +91,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
8691 const auto token_org = tokens[batch_start];
8792
8893 // add BOS token for the first batch of each chunk
89- if (j == 0 ) {
94+ if (add_bos && j == 0 ) {
9095 tokens[batch_start] = llama_token_bos (ctx);
9196 }
9297
@@ -136,7 +141,6 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
136141}
137142
138143void perplexity (llama_context * ctx, const gpt_params & params) {
139-
140144 if (params.ppl_stride > 0 ) {
141145 perplexity_v2 (ctx, params);
142146 return ;
@@ -146,7 +150,13 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
146150 // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
147151 // Output: `perplexity: 13.5106 [114/114]`
148152 // BOS tokens will be added for each chunk before eval
149- auto tokens = ::llama_tokenize (ctx, params.prompt , true );
153+
154+ const bool is_spm = llama_vocab_type (ctx) == LLAMA_VOCAB_TYPE_SPM;
155+ const bool add_bos = is_spm;
156+
157+ fprintf (stderr, " %s: tokenizing the input ..\n " , __func__);
158+
159+ auto tokens = ::llama_tokenize (ctx, params.prompt , add_bos);
150160
151161 const int n_chunk_max = tokens.size () / params.n_ctx ;
152162
@@ -177,7 +187,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
177187 const auto token_org = tokens[batch_start];
178188
179189 // add BOS token for the first batch of each chunk
180- if (j == 0 ) {
190+ if (add_bos && j == 0 ) {
181191 tokens[batch_start] = llama_token_bos (ctx);
182192 }
183193
@@ -295,8 +305,10 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
295305 size_t hs_task_count = prompt_lines.size ()/6 ;
296306 fprintf (stderr, " %s : loaded %zu tasks from prompt.\n " , __func__, hs_task_count);
297307
308+ const bool is_spm = llama_vocab_type (ctx) == LLAMA_VOCAB_TYPE_SPM;
309+
298310 // This is needed as usual for LLaMA models
299- bool prepend_bos = true ;
311+ const bool add_bos = is_spm ;
300312
301313 // Number of tasks to use when computing the score
302314 if ( params.hellaswag_tasks < hs_task_count ) {
@@ -352,14 +364,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
352364 std::vector<float > tok_logits (n_vocab);
353365
354366 for (size_t task_idx = 0 ; task_idx < hs_task_count; task_idx++) {
355-
356367 // Tokenize the context to count tokens
357- std::vector<int > context_embd = ::llama_tokenize (ctx, hs_data[task_idx].context , prepend_bos );
368+ std::vector<int > context_embd = ::llama_tokenize (ctx, hs_data[task_idx].context , add_bos );
358369 size_t context_size = context_embd.size ();
359370
360371 // Do the 1st ending
361372 // In this case we include the context when evaluating
362- auto query_embd = ::llama_tokenize (ctx, hs_data[task_idx].context + hs_data[task_idx].ending [0 ], prepend_bos );
373+ auto query_embd = ::llama_tokenize (ctx, hs_data[task_idx].context + hs_data[task_idx].ending [0 ], add_bos );
363374 auto query_size = query_embd.size ();
364375 // printf("First query: %d\n",(int)query_size);
365376
0 commit comments