
Commit cb1c072

HellaSwag: split token evaluation into batches if needed (#2681)

Authored by ikawrakow and Kawrakow
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 9e232f0 commit cb1c072
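
This commit fixes HellaSwag scoring for queries longer than the batch size: previously, hellaswag_score() passed each tokenized query to llama_eval() in a single call, which can fail once a query exceeds n_batch tokens. The new hellaswag_evaluate_tokens() helper evaluates the tokens in chunks of at most n_batch, advancing n_past after each chunk and concatenating the per-chunk logits into one buffer, so the callers' indexing into the logits stays unchanged; failure is now detected by an empty result instead of the llama_eval() return code.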

File tree

1 file changed: +28 −11 lines


examples/perplexity/perplexity.cpp

Lines changed: 28 additions & 11 deletions
@@ -122,6 +122,27 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
+                                             int n_vocab, int n_thread) {
+    std::vector<float> result;
+    result.reserve(tokens.size() * n_vocab);
+    size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
+    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
+        size_t n_tokens = tokens.size() - i_chunk * n_batch;
+        n_tokens = std::min(n_tokens, size_t(n_batch));
+        if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return {};
+        }
+
+        const auto logits = llama_get_logits(ctx);
+        result.insert(result.end(), logits, logits + n_tokens * n_vocab);
+
+        n_past += n_tokens;
+    }
+    return result;
+}
+
 void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     // Calculates hellaswag score (acc_norm) from prompt
     //
@@ -235,15 +256,13 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             query_embd.resize(32);
         }
 
-        // Evaluate the query
-        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+        auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
+        if (logits.empty()) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
 
-        auto query_logits = llama_get_logits(ctx);
-
-        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
         const auto first_probs = softmax(tok_logits);
 
         hs_data[task_idx].ending_logprob_count[0] = 1;
@@ -252,7 +271,7 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         // Calculate the logprobs over the ending
         for (size_t j = context_size; j < query_size - 1; j++) {
 
-            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
             const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
@@ -271,7 +290,6 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             // Tokenize the query
             query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
             query_size = query_embd.size();
-            //printf("Second query: %d\n",(int)query_size);
 
             // Stop if query wont fit the ctx window
             if (context_size + query_size > (size_t)params.n_ctx) {
@@ -286,19 +304,18 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
+            logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab, params.n_threads);
+            if (logits.empty()) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            query_logits = llama_get_logits(ctx);
-
             hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
             hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
             for (size_t j = 0; j < query_size - 1; j++) {
-                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+                std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
 
                 const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
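As a minimal sketch of the batching arithmetic introduced above, the standalone program below walks the same chunk loop as hellaswag_evaluate_tokens(), with llama_eval() stubbed out by a printf trace; the 100-token query and the chunk size of 32 are illustrative values, not part of the commit.

// sketch_batching.cpp - the chunking pattern from hellaswag_evaluate_tokens(),
// reduced to plain C++. Each chunk holds at most n_batch tokens, and n_past
// advances so every chunk continues the same context.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> tokens(100, 1); // stand-in for a tokenized query
    const int n_batch = 32;                // max tokens per llama_eval() call

    const size_t n_chunk = (tokens.size() + n_batch - 1) / n_batch; // ceiling division: 4 chunks
    int n_past = 0;
    for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
        // remaining tokens, capped at n_batch (only the last chunk may be shorter)
        const size_t n_tokens = std::min(tokens.size() - i_chunk * n_batch, size_t(n_batch));
        // here the real helper calls:
        //   llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)
        // and appends n_tokens * n_vocab logits to its result vector
        printf("chunk %zu: %zu tokens, n_past = %d\n", i_chunk, n_tokens, n_past);
        n_past += (int)n_tokens;
    }
    return 0;
}

For this input it prints chunks of 32, 32, 32, and 4 tokens starting at n_past 0, 32, 64, and 96. Because each chunk's logits are appended in order, the concatenated buffer lines up with what a single llama_eval() call over the whole query would return, which is why the memcpy indexing in hellaswag_score() needed no other change.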