@@ -122,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         printf("\n");
     }
 
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+                                             int n_vocab, int n_thread) {
+    std::vector<float> result;
+    result.reserve(tokens.size() * n_vocab);
+    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
+    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
+        size_t n_tokens = tokens.size() - i_chunk * n_batch;
+        n_tokens = std::min(n_tokens, size_t(n_batch));
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return {};
+        }
+
+        const auto logits = llama_get_logits(ctx);
+        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
+
+        n_past += n_tokens;
+    }
+    return result;
+}
+
 void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
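The helper added above splits a query into ⌈tokens.size() / n_batch⌉ chunks, feeds each chunk to llama_eval, and advances n_past by the chunk size, so the concatenated result holds one row of n_vocab logits per token in evaluation order. For example, a 100-token query with n_batch = 32 is evaluated as four calls covering 32, 32, 32, and 4 tokens; before this change, a single llama_eval call had to fit the entire query in one batch. A usage sketch follows the last hunk below.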
@@ -235,15 +256,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }
 
-        // Evaluate the query
-        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
 
-        auto query_logits = llama_get_logits(ctx);
-
-        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
         const auto first_probs = softmax(tok_logits);
 
         hs_data[task_idx].ending_logprob_count[0] = 1;
@@ -252,7 +271,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         // Calculate the logprobs over the ending
         for (size_t j = context_size; j < query_size - 1; j++) {
 
-            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
             const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
@@ -271,7 +290,6 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             // Tokenize the query
            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
             query_size = query_embd.size();
-            // printf("Second query: %d\n",(int)query_size);
 
             // Stop if query wont fit the ctx window
             if (context_size + query_size > (size_t)params.n_ctx) {
@@ -286,19 +304,18 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             // }
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
+            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+            if (logits.empty()) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            query_logits = llama_get_logits(ctx);
-
             hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
             hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
             for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
                 const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
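For reference, a minimal sketch of how the new helper is driven (the llama_n_vocab call and the literal query text are illustrative assumptions; in the patch itself, n_vocab and query_embd already exist inside hellaswag_score):

    // Hypothetical call site mirroring the patch's usage.
    const int n_vocab = llama_n_vocab(ctx);   // vocab size (assumed to be available)
    std::vector<int> query_embd = ::llama_tokenize(ctx, "an example query", false);

    // Evaluate in n_batch-sized chunks starting at n_past = 0; an empty
    // result signals that llama_eval failed on some chunk.
    auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0,
                                            params.n_batch, n_vocab, params.n_threads);
    if (logits.empty()) {
        fprintf(stderr, "eval failed\n");
        return;
    }
    // The logits for query position j start at logits.data() + j*n_vocab.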