More efficient HellaSwag implementation #2677

Merged · 1 commit · Aug 20, 2023

92 changes: 70 additions & 22 deletions examples/perplexity/perplexity.cpp
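
In outline, the change restructures `hellaswag_score` so that the shared context is evaluated only once per task: the first ending is scored from a full `context + ending[0]` pass, and the three remaining endings are then fed in as bare continuations with `n_past = context_size`, reusing the KV cache the first pass built for the context. A second, smaller win is that per-token logits are now copied into a single reusable `tok_logits` buffer with `std::memcpy` (hence the new `<cstring>` include) instead of constructing a fresh `std::vector<float>` for every token. A minimal sketch of the evaluation pattern, using the `llama_eval(ctx, tokens, n_tokens, n_past, n_threads)` call as it appears in this file (the `tok(...)` helper is illustrative shorthand for tokenization, not a real API):

```cpp
// Before: each of the 4 endings re-evaluates the context from scratch.
//   for (int i = 0; i < 4; ++i) {
//       auto q = tok(context + ending[i]);                // context repeated 4x
//       llama_eval(ctx, q.data(), q.size(), /*n_past=*/0, n_threads);
//   }
//
// After: one full pass caches the context; endings 1..3 continue from it.
//   auto q0 = tok(context + ending[0]);
//   llama_eval(ctx, q0.data(), q0.size(), /*n_past=*/0, n_threads);
//   for (int i = 1; i < 4; ++i) {
//       auto qi = tok(ending[i]);                         // bare ending only
//       llama_eval(ctx, qi.data(), qi.size(), /*n_past=*/context_size, n_threads);
//   }
```
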
@@ -5,6 +5,7 @@
 #include <cmath>
 #include <ctime>
 #include <sstream>
+#include <cstring>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -209,50 +210,97 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     double acc = 0.0f;
     const int n_vocab = llama_n_vocab(ctx);
 
+    std::vector<float> tok_logits(n_vocab);
+
     for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
 
         // Tokenize the context to count tokens
         std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, prepend_bos);
         size_t context_size = context_embd.size();
 
-        for (size_t ending_idx=0;ending_idx<4;ending_idx++) {
+        // Do the 1st ending
+        // In this case we include the context when evaluating
+        auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], prepend_bos);
+        auto query_size = query_embd.size();
+        //printf("First query: %d\n",(int)query_size);
+
+        // Stop if query won't fit the ctx window
+        if (query_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : number of tokens in query %zu > n_ctx\n", __func__, query_size);
+            return;
+        }
+
+        // Speed up small evaluations by evaluating at least 32 tokens
+        if (query_size < 32) {
+            query_embd.resize(32);
+        }
+
+        // Evaluate the query
+        if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        auto query_logits = llama_get_logits(ctx);
+
+        std::memcpy(tok_logits.data(), query_logits + (context_size-1)*n_vocab, n_vocab*sizeof(float));
+        const auto first_probs = softmax(tok_logits);
+
+        hs_data[task_idx].ending_logprob_count[0] = 1;
+        hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
+
+        // Calculate the logprobs over the ending
+        for (size_t j = context_size; j < query_size - 1; j++) {
+
+            std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
+
+            const float prob = softmax(tok_logits)[query_embd[j + 1]];
+
+            hs_data[task_idx].ending_logprob[0] += std::log(prob);
+            hs_data[task_idx].ending_logprob_count[0]++;
+        }
+
+        // Calculate the mean token logprob for acc_norm
+        hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
+
+        // Do the remaining endings
+        // For these, we use the bare ending with n_past = context_size
+        //
+        for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
 
             // Tokenize the query
-            std::vector<int> query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[ending_idx], prepend_bos);
-            size_t query_size = query_embd.size();
+            query_embd = ::llama_tokenize(ctx, hs_data[task_idx].ending[ending_idx], false);
+            query_size = query_embd.size();
+            //printf("Second query: %d\n",(int)query_size);
 
             // Stop if query won't fit the ctx window
-            if (query_size > (size_t)params.n_ctx) {
+            if (context_size + query_size > (size_t)params.n_ctx) {
                 fprintf(stderr, "%s : number of tokens in query %zu > n_ctx\n", __func__, query_size);
                 return;
             }
 
-            // Speed up small evaluations by evaluating at least 32 tokens
-            if (query_size < 32) {
-                query_embd.resize(32);
-            }
+            // No, resizing to 32 is actually slightly slower (at least on CUDA)
+            //if (query_size < 32) {
+            //    query_embd.resize(32);
+            //}
 
             // Evaluate the query
-            if (llama_eval(ctx, query_embd.data(), query_embd.size(), 0, params.n_threads)) {
+            if (llama_eval(ctx, query_embd.data(), query_embd.size(), context_size, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
-            const auto query_logits = llama_get_logits(ctx);
-            std::vector<float> logits;
-            logits.insert(logits.end(), query_logits, query_logits + query_size * n_vocab);
+            query_logits = llama_get_logits(ctx);
 
-            hs_data[task_idx].ending_logprob_count[ending_idx] = 0;
-            hs_data[task_idx].ending_logprob[ending_idx] = 0.0f;
+            hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
+            hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
 
             // Calculate the logprobs over the ending
-            for (size_t j = context_size-1; j < query_size - 1; j++) {
-                // Calculate probability of next token, given the previous ones.
-                const std::vector<float> tok_logits(
-                    logits.begin() + (j + 0) * n_vocab,
-                    logits.begin() + (j + 1) * n_vocab);
+            for (size_t j = 0; j < query_size - 1; j++) {
+                std::memcpy(tok_logits.data(), query_logits + j*n_vocab, n_vocab*sizeof(float));
 
-                const float prob = softmax(tok_logits)[query_embd[ j + 1]];
+                const float prob = softmax(tok_logits)[query_embd[j + 1]];
 
                 hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
                 hs_data[task_idx].ending_logprob_count[ending_idx]++;
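
Taken together, both branches accumulate the same quantity per ending: its length-normalized log-likelihood under the model. With context $c$ and ending tokens $t_1, \dots, t_N$,

$$\mathrm{ending\_logprob} = \frac{1}{N} \sum_{i=1}^{N} \log p(t_i \mid c, t_1, \dots, t_{i-1}),$$

where $N$ is `ending_logprob_count`. The division by $N$ (shown above for ending 0, and applied to the other endings in the elided lines) is the acc_norm convention: without it, longer endings would be penalized simply for contributing more factors to the product.
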
@@ -267,9 +315,9 @@ void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }
 
         // Find the ending with maximum logprob
-        size_t ending_logprob_max_idx = -1;
-        double ending_logprob_max_val = -INFINITY;
-        for (size_t j=0; j < 4; j++) {
+        size_t ending_logprob_max_idx = 0;
+        double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
+        for (size_t j = 1; j < 4; j++) {
             if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
                 ending_logprob_max_idx = j;
                 ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
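
This last hunk is an independent tidy-up of the argmax. Seeding `ending_logprob_max_idx` with `-1` stores `SIZE_MAX` in a `size_t`, and with `ending_logprob_max_val = -INFINITY` that sentinel could survive the loop if no log-prob compared greater (e.g. all four were `-INFINITY`), leaving an out-of-range index. Seeding with ending 0 and scanning from `j = 1` always yields a valid index. The same idiom in generic form, as a self-contained illustration:

```cpp
#include <cstddef>

// Argmax over a fixed-size array: seed with element 0 and scan the rest,
// so the result is a valid index even if the values never compare greater
// (e.g. all equal, or all -inf).
template <typename T, std::size_t N>
std::size_t argmax(const T (&v)[N]) {
    static_assert(N > 0, "argmax requires a non-empty array");
    std::size_t best = 0;
    for (std::size_t j = 1; j < N; j++) {
        if (v[j] > v[best]) {
            best = j;
        }
    }
    return best;
}
```
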
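One helper the diff leans on but does not show is `softmax`, defined earlier in perplexity.cpp and used here to turn a row of vocabulary logits into probabilities. A minimal, numerically stable sketch of such a helper (an illustration of the shape of the function, not a verbatim copy of the one in the file):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Numerically stable softmax: subtract the max logit before exponentiating
// so large logits cannot overflow expf(); then normalize to sum to 1.
static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = expf(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) {
        p /= sum;
    }
    return probs;
}
```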