
Commit 5e9ff54

ikawrakow and Kawrakow authored
More efficient Hellaswag implementation (#2677)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 1f0bccb commit 5e9ff54

File tree

1 file changed: +70 -22 lines changed


examples/perplexity/perplexity.cpp

Lines changed: 70 additions & 22 deletions
@@ -5,6 +5,7 @@
 #include <cmath>
 #include <ctime>
 #include <sstream>
+#include <cstring>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -209,50 +210,97 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     double acc = 0.0f;
     const int n_vocab = llama_n_vocab(ctx);
 
+    std::vector<float> tok_logits(n_vocab);
+
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
 
         // Tokenize the context to count tokens
         std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
         size_t context_size = context_embd.size();
 
-        for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+        // Do the 1st ending
+        // In this case we include the context when evaluating
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_size = query_embd.size();
+        //printf("First query: %d\n",(int)query_size);
+
+        // Stop if query wont fit the ctx window
+        if (query_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
+            return;
+        }
+
+        // Speedup small evaluations by evaluating atleast 32 tokens
+        if (query_size < 32) {
+            query_embd.resize(32);
+        }
+
+        // Evaluate the query
+        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        auto query_logits = llama_get_logits(ctx);
+
+        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        const auto first_probs = softmax(tok_logits);
+
+        hs_data[task_idx].ending_logprob_count[0] = 1;
+        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+
+        // Calculate the logprobs over the ending
+        for (size_t j = context_size; j < query_size - 1; j++) {
+
+            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+
+            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+            hs_data[task_idx].ending_logprob[0] += std::log(prob);
+            hs_data[task_idx].ending_logprob_count[0]++;
+        }
+
+        // Calculate the mean token logprob for acc_norm
+        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+
+        // Do the remaining endings
+        // For these, we use the bare ending with n_past = context_size
+        //
+        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
 
             // Tokenize the query
-            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
-            size_t query_size = query_embd.size();
+            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+            query_size = query_embd.size();
+            //printf("Second query: %d\n",(int)query_size);
 
             // Stop if query wont fit the ctx window
-            if (query_size > (size_t)params.n_ctx) {
+            if (context_size + query_size > (size_t)params.n_ctx) {
                 fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
                 return;
             }
 
             // Speedup small evaluations by evaluating atleast 32 tokens
-            if (query_size < 32) {
-                query_embd.resize(32);
-            }
+            // No, resizing to 32 is actually slightly slower (at least on CUDA)
+            //if (query_size < 32) {
+            //    query_embd.resize(32);
+            //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            const auto query_logits = llama_get_logits(ctx);
-            std::vector<float> logits;
-            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+            query_logits = llama_get_logits(ctx);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
-            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
+            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
-            for (size_t j = context_size-1; j < query_size - 1; j++) {
-                // Calculate probability of next token, given the previous ones.
-                const std::vector<float> tok_logits(
-                    logits.begin() + (j + 0) * n_vocab,
-                    logits.begin() + (j + 1) * n_vocab);
+            for (size_t j = 0; j < query_size - 1; j++) {
+                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
 
-                const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+                const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
                 hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
                 hs_data[task_idx].ending_logprob_count[ending_idx]++;
@@ -267,9 +315,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
         // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = -1;
-        double ending_logprob_max_val = -INFINITY;
-        for (size_t j=0; j < 4; j++) {
+        size_t ending_logprob_max_idx = 0;
+        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
+        for (size_t j = 1; j < 4; j++) {
             if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
                 ending_logprob_max_idx = j;
                 ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
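Editor's note (not part of the commit): the speedup in this diff comes from evaluating each task's context only once. The first ending is scored together with the context (n_past = 0); for the remaining three endings only the bare ending tokens are evaluated with n_past = context_size, so the model reuses the context already held in its KV cache, and the logits of the last context token (first_probs) are reused to score the first token of each ending. A second, smaller change is copying one row of logits at a time into the preallocated tok_logits buffer instead of building a fresh std::vector of all query logits per ending. The sketch below illustrates the evaluation flow under stated assumptions: eval_tokens and token_logprob are hypothetical stand-ins for llama_eval/llama_get_logits and the softmax-plus-log step, not real llama.cpp calls.

// Minimal sketch of the KV-cache-reuse scoring flow; hypothetical interface, not llama.cpp code.
#include <cstddef>
#include <vector>

using Token  = int;
using Row    = std::vector<float>;   // n_vocab logits for one position
using Logits = std::vector<Row>;     // one row per evaluated token

// Hypothetical model call: evaluates `tokens` starting at position `n_past`,
// reusing whatever KV cache already covers positions [0, n_past).
Logits eval_tokens(const std::vector<Token> & tokens, std::size_t n_past);

// Hypothetical helper: log of softmax(row)[tok].
double token_logprob(const Row & row, Token tok);

void score_task(const std::vector<Token> & context,
                const std::vector<std::vector<Token>> & endings,   // 4 candidate endings
                double logprob[4]) {
    const std::size_t ctx_n = context.size();

    // Ending 0: context + ending evaluated in one pass (n_past = 0).
    std::vector<Token> full = context;
    full.insert(full.end(), endings[0].begin(), endings[0].end());
    const Logits rows = eval_tokens(full, 0);

    logprob[0] = 0.0;
    for (std::size_t j = ctx_n - 1; j + 1 < full.size(); j++) {
        logprob[0] += token_logprob(rows[j], full[j + 1]);
    }

    // Endings 1..3: the context is already in the KV cache, so only the bare
    // ending tokens are evaluated, with n_past = ctx_n. The last context row
    // (rows[ctx_n - 1]) is reused to score each ending's first token.
    for (std::size_t e = 1; e < 4; e++) {
        const std::vector<Token> & end_toks = endings[e];
        const Logits end_rows = eval_tokens(end_toks, ctx_n);

        logprob[e] = token_logprob(rows[ctx_n - 1], end_toks[0]);
        for (std::size_t j = 0; j + 1 < end_toks.size(); j++) {
            logprob[e] += token_logprob(end_rows[j], end_toks[j + 1]);
        }
    }
}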
