 #include <cmath>
 #include <ctime>
 #include <sstream>
+#include <cstring>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -209,50 +210,97 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     double acc = 0.0f;
     const int n_vocab = llama_n_vocab(ctx);
 
+    std::vector<float> tok_logits(n_vocab);
+
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
 
         // Tokenize the context to count tokens
         std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
         size_t context_size = context_embd.size();
 
-        for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+        // Do the 1st ending
+        // In this case we include the context when evaluating
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_size = query_embd.size();
+        //printf("First query: %d\n", (int)query_size);
+
+        // Stop if the query won't fit in the ctx window
+        if (query_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : number of tokens in query %zu > n_ctx\n", __func__, query_size);
+            return;
+        }
+
+        // Speed up small evaluations by evaluating at least 32 tokens
+        if (query_size < 32) {
+            query_embd.resize(32);
+        }
+
+        // Evaluate the query
+        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        auto query_logits = llama_get_logits(ctx);
+
+        std::memcpy(tok_logits.data(), query_logits + (context_size - 1)*n_vocab, n_vocab*sizeof(float));
+        const auto first_probs = softmax(tok_logits);
+
+        hs_data[task_idx].ending_logprob_count[0] = 1;
+        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+
+        // Calculate the logprobs over the ending
+        for (size_t j = context_size; j < query_size - 1; j++) {
+
+            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+
+            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+            hs_data[task_idx].ending_logprob[0] += std::log(prob);
+            hs_data[task_idx].ending_logprob_count[0]++;
+        }
+
+        // Calculate the mean token logprob for acc_norm
+        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+
+        // Do the remaining endings
+        // For these, we use the bare ending with n_past = context_size
+        //
+        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
 
             // Tokenize the query
-            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
-            size_t query_size = query_embd.size();
+            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+            query_size = query_embd.size();
+            //printf("Second query: %d\n", (int)query_size);
 
             // Stop if the query won't fit in the ctx window
-            if (query_size > (size_t)params.n_ctx) {
+            if (context_size + query_size > (size_t)params.n_ctx) {
                 fprintf(stderr, "%s : number of tokens in query %zu > n_ctx\n", __func__, query_size);
                 return;
             }
 
             // Speed up small evaluations by evaluating at least 32 tokens
-            if (query_size < 32) {
-                query_embd.resize(32);
-            }
+            // No, resizing to 32 is actually slightly slower (at least on CUDA)
+            //if (query_size < 32) {
+            //    query_embd.resize(32);
+            //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            const auto query_logits = llama_get_logits(ctx);
-            std::vector<float> logits;
-            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+            query_logits = llama_get_logits(ctx);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
-            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
+            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
-            for (size_t j = context_size-1; j < query_size - 1; j++) {
-                // Calculate probability of next token, given the previous ones.
-                const std::vector<float> tok_logits(
-                    logits.begin() + (j + 0) * n_vocab,
-                    logits.begin() + (j + 1) * n_vocab);
+            for (size_t j = 0; j < query_size - 1; j++) {
+                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
 
-                const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+                const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
                 hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
                 hs_data[task_idx].ending_logprob_count[ending_idx]++;
@@ -267,9 +315,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
         // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = -1;
-        double ending_logprob_max_val = -INFINITY;
-        for (size_t j=0; j < 4; j++) {
+        size_t ending_logprob_max_idx = 0;
+        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
+        for (size_t j = 1; j < 4; j++) {
             if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
                 ending_logprob_max_idx = j;
                 ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
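
Note: the scoring pattern the patch settles on is to copy one row of logits out of the model output, softmax it, and accumulate the log of the probability assigned to the token that actually follows; the mean of those log-probs is what acc_norm uses, and evaluating the bare endings with n_past = context_size means the shared context is only evaluated once per task. Below is a minimal standalone sketch of that per-token step, not part of the patch: the logits and token id are toy values, and softmax is written out here since the helper the patch calls is defined elsewhere in the file.

```cpp
// Standalone sketch of the per-token log-prob step used in the patch.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax, analogous to the softmax() helper the patch calls.
static std::vector<float> softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(logits[i] - max_logit); // subtract max for stability
        sum += probs[i];
    }
    for (float & p : probs) p /= (float) sum;
    return probs;
}

int main() {
    // One row of logits for a toy 4-token vocabulary, i.e. what the patch
    // copies into tok_logits with std::memcpy from llama_get_logits(ctx).
    const std::vector<float> tok_logits = {1.0f, 2.5f, 0.3f, -1.0f};

    // Log-probability of the token that actually follows (hypothetical id 1),
    // i.e. std::log(softmax(tok_logits)[query_embd[j + 1]]) in the patch.
    const int next_tok = 1;
    const float logprob = std::log(softmax(tok_logits)[next_tok]);

    printf("logprob of token %d: %f\n", next_tok, logprob);
    return 0;
}
```

Reusing one preallocated tok_logits buffer via std::memcpy also replaces the old code's per-token vector construction and full-query logits copy, which is part of where the speedup comes from.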