
Commit 486ae64

Compute perplexity over prompt (abetlen#270)
* Compute perplexity over prompt
* More accurate perplexity calculation - over all logits in the context window (so 512x more tokens!)
* Output all perplexities
* Add timing/ETA
1 parent 3ab3e65 commit 486ae64
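For reference, the number the new code prints is standard perplexity: the exponential of the average negative log-likelihood the model assigns to each scored token (the patch's own comment puts it as "perplexity is e^(average negative log-likelihood)"). A compact statement of the definition, where N is simply the number of scored tokens and not a symbol taken from the patch:

```latex
% Perplexity over N scored tokens x_1, ..., x_N, each conditioned on the tokens before it
\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(x_i \mid x_{<i})\right)
```

Lower perplexity means the model is, on average, less surprised by the prompt text.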

File tree

3 files changed: +98 -13 lines changed


main.cpp

Lines changed: 92 additions & 11 deletions
@@ -560,7 +560,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<llama_vocab::id> & embd_inp,
         std::vector<float> & embd_w,
-        size_t & mem_per_token) {
+        size_t & mem_per_token,
+        bool return_all_logits = false) {
     const int N = embd_inp.size();

     const auto & hparams = model.hparams;
@@ -578,7 +579,7 @@ bool llama_eval(
     static void * buf = malloc(buf_size);

     if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
+        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
         //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

         // reallocate
@@ -764,9 +765,14 @@ bool llama_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }

     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
@@ -778,6 +784,76 @@ bool llama_eval(
     return true;
 }

+std::vector<double> softmax(const std::vector<float>& logits) {
+    std::vector<double> probs(logits.size());
+    float max_logit = logits[0];
+    for (float v : logits) max_logit = std::max(max_logit, v);
+    double sum_exp = 0.0;
+    for (size_t i = 0; i < logits.size(); i++) {
+        // Subtract the maximum logit value from the current logit value for numerical stability
+        float logit = logits[i] - max_logit;
+        double exp_logit = std::exp(logit);
+        sum_exp += exp_logit;
+        probs[i] = exp_logit;
+    }
+    for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+    return probs;
+}
+
+void perplexity(const llama_vocab &vocab, const llama_model &model, const gpt_params &params, size_t mem_per_token) {
+    // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
+    // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
+    // Output: `perplexity: 13.5106 [114/114]`
+    std::vector<llama_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
+
+    int count = 0;
+    double nll = 0.0;
+    int seq_count = tokens.size() / params.n_ctx;
+    printf("Calculating perplexity over %d chunks\n", seq_count);
+    for (int i = 0; i < seq_count; ++i) {
+        int start = i * params.n_ctx;
+        int end = start + params.n_ctx - 1;
+        std::vector<llama_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
+        std::vector<float> logits;
+        auto start_t = std::chrono::high_resolution_clock::now();
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
+            fprintf(stderr, "Failed to predict\n");
+            return;
+        }
+        auto end_t = std::chrono::high_resolution_clock::now();
+        if (i == 0) {
+            double seconds = std::chrono::duration<double>(end_t - start_t).count();
+            printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
+        }
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example, we have a context window of 512, we will compute perplexity for each of the
+        // last 256 tokens. Then, we split the input up into context window size chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
+        // perplexity is e^(average negative log-likelihood)
+        printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        fflush(stdout);
+    }
+    printf("\n");
+}
+
 static bool is_interacting = false;

 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
@@ -868,13 +944,22 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

+    std::vector<float> logits;
+
+    // determine the required inference memory per token:
+    size_t mem_per_token = 0;
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+
+    if (params.perplexity) {
+        perplexity(vocab, model, params, mem_per_token);
+        exit(0);
+    }
+
     int n_past = 0;

     int64_t t_sample_us = 0;
     int64_t t_predict_us = 0;

-    std::vector<float> logits;
-
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
     // tokenize the prompt
@@ -928,10 +1013,6 @@ int main(int argc, char ** argv) {

     std::vector<llama_vocab::id> embd;

-    // determine the required inference memory per token:
-    size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
-
     int last_n_size = params.repeat_last_n;
     std::vector<llama_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
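To see the accumulation in isolation, here is a small, self-contained C++ sketch of the two steps the new `perplexity()` loop performs for every scored position: a max-subtracted softmax over that position's logits, then adding the negative log-probability of the token that actually followed, and finally taking `exp(nll / count)`. The logits and next-token ids below are made-up placeholder values, not model output, and the helper mirrors (rather than reproduces) the commit's `softmax()`.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax: subtract the max logit before exponentiating.
static std::vector<double> softmax(const std::vector<float> & logits) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<double> probs(logits.size());
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); i++) {
        probs[i] = std::exp(double(logits[i]) - double(max_logit));
        sum_exp += probs[i];
    }
    for (double & p : probs) p /= sum_exp;
    return probs;
}

int main() {
    // Placeholder data: logits for 3 positions over a toy vocabulary of 4 tokens,
    // plus the id of the token that actually followed at each position.
    std::vector<std::vector<float>> logits = {
        { 2.0f, 0.5f, -1.0f, 0.0f },
        { 0.1f, 3.0f,  0.2f, 0.3f },
        { 1.0f, 1.0f,  1.0f, 1.0f },
    };
    std::vector<int> next_token = { 0, 1, 2 };

    double nll = 0.0;
    int count = 0;
    for (size_t j = 0; j < logits.size(); ++j) {
        // Probability the model assigned to the observed next token.
        double prob = softmax(logits[j])[next_token[j]];
        nll += -std::log(prob);
        ++count;
    }

    // Perplexity is e^(average negative log-likelihood).
    printf("perplexity: %.4f over %d tokens\n", std::exp(nll / count), count);
    return 0;
}
```

In the actual patch, the logits come from `llama_eval()` with `return_all_logits = true`, and only the second half of each context window is scored, so every scored token has at least half a window of preceding context.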

utils.cpp

Lines changed: 5 additions & 2 deletions
@@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
             params.antiprompt.push_back(argv[++i]);
+        } else if (arg == "--perplexity") {
+            params.perplexity = true;
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--n_parts") {
@@ -120,6 +122,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
+    fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
     fprintf(stderr, "\n");
@@ -596,7 +599,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t

     char * pdst = (char *) dst;

-    for (int j = 0; j < n; j += k) {
+    for (int j = 0; j < n; j += k) {
         uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
         uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
         uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));
@@ -619,7 +622,7 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t

         *(float *) pd = d;
         *(float *) pm = min;
-        pd += bs;
+        pd += bs;
         pm += bs;

         for (int l = 0; l < qk; l += 2) {

utils.h

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ struct gpt_params {
     bool interactive_start = false; // reverse prompt immediately
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool ignore_eos = false; // do not stop generating after eos
+    bool perplexity = false; // compute perplexity over the prompt
 };

 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
