
Commit b5fe67f

Perplexity: Compute scores correlated to HellaSwag (#2312)
* Add parameter --perplexity-lines to perplexity.cpp
1 parent 24baa54 commit b5fe67f

3 files changed, 82 insertions(+), 2 deletions(-)

examples/common.cpp

Lines changed: 4 additions & 1 deletion
@@ -387,6 +387,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.antiprompt.push_back(argv[i]);
         } else if (arg == "--perplexity") {
            params.perplexity = true;
+        } else if (arg == "--perplexity-lines") {
+            params.perplexity_lines = true;
         } else if (arg == "--ignore-eos") {
             params.logit_bias[llama_token_eos()] = -INFINITY;
         } else if (arg == "--no-penalize-nl") {
@@ -512,7 +514,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
-    fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
+    fprintf(stderr, "  --perplexity          compute perplexity over each ctx window of the prompt\n");
+    fprintf(stderr, "  --perplexity-lines    compute perplexity over each line of the prompt\n");
     fprintf(stderr, "  --keep                number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
     fprintf(stderr, "  --chunks N            max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
     if (llama_mlock_supported()) {
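
With this change, per-line perplexity can be requested from the command line, for example (the model and prompt-file paths here are placeholders for illustration, not taken from the commit):

    ./perplexity -m models/7B/ggml-model-q4_0.bin -f prompts.txt --perplexity-lines

Each line of the prompt file is then tokenized and scored independently, as implemented in perplexity.cpp below.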

examples/common.h

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ struct gpt_params {
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool penalize_nl = true; // consider newlines as a repeatable token
     bool perplexity = false; // compute perplexity over the prompt
+    bool perplexity_lines = false; // compute perplexity over each line of the prompt
     bool use_mmap = true; // use mmap for faster loads
     bool use_mlock = false; // use mlock to keep model in memory
     bool mem_test = false; // compute maximum memory usage

examples/perplexity/perplexity.cpp

Lines changed: 77 additions & 1 deletion
@@ -4,6 +4,7 @@

 #include <cmath>
 #include <ctime>
+#include <sstream>

 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }

+void perplexity_lines(llama_context * ctx, const gpt_params & params) {
+    // Calculates perplexity over each line of the prompt
+
+    std::vector<std::string> prompt_lines;
+    std::istringstream strstream(params.prompt);
+    std::string line;
+
+    while (std::getline(strstream, line, '\n')) {
+        prompt_lines.push_back(line);
+    }
+
+    const int n_vocab = llama_n_vocab(ctx);
+
+    int counttotal = 0;
+    size_t n_lines = prompt_lines.size();
+
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
+
+    printf("\nLine\tPPL line\tPPL cumulative\n");
+
+    for (size_t i = 0; i < n_lines; ++i) {
+
+        // Tokenize and insert BOS at start
+        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+
+        size_t batch_size = batch_embd.size();
+
+        // Stop if line is too long
+        if (batch_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : tokens in line %lu > n_ctx\n", __func__, i);
+            return;
+        }
+
+        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        const auto batch_logits = llama_get_logits(ctx);
+        std::vector<float> logits;
+        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+        double nllline = 0.0;
+        int countline = 0;
+
+        // Perplexity over second half of the line
+        for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[batch_embd[j + 1]];
+
+            nllline += -std::log(prob);
+            ++countline;
+        }
+
+        nll += nllline;
+        counttotal += countline;
+
+        // perplexity is e^(average negative log-likelihood)
+        printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline / countline), std::exp(nll / counttotal));
+        fflush(stdout);
+    }
+
+    printf("\n");
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;

@@ -168,7 +240,11 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }

-    perplexity(ctx, params);
+    if (params.perplexity_lines) {
+        perplexity_lines(ctx, params);
+    } else {
+        perplexity(ctx, params);
+    }

    llama_print_timings(ctx);
    llama_free(ctx);
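
The per-token probability in perplexity_lines() comes from a softmax helper that is already defined earlier in perplexity.cpp and is therefore not shown in this diff. As a rough, self-contained sketch of the arithmetic involved, the toy program below (made-up logits and token ids, not upstream code) computes PPL = exp((1/N) * sum of -log p(next token)), which is exactly the quantity printed per line and cumulatively above.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row of logits. This mirrors the helper
// that perplexity.cpp already provides; the exact upstream implementation is
// not part of this commit's diff.
static std::vector<float> softmax(const std::vector<float> & logits) {
    std::vector<float> probs(logits.size());
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum_exp += probs[i];
    }
    for (float & p : probs) {
        p = (float)(p / sum_exp);
    }
    return probs;
}

int main() {
    // Toy stand-in for llama_get_logits(): three positions over a 4-token
    // vocabulary. The numbers are invented purely for illustration.
    const std::vector<std::vector<float>> logits = {
        {2.0f, 0.5f, -1.0f, 0.0f},
        {0.1f, 3.0f,  0.3f, 0.2f},
        {1.0f, 1.0f,  2.5f, 0.0f},
    };
    // "Observed" next token at each position (stand-in for batch_embd[j + 1]).
    const std::vector<int> next_token = {0, 1, 2};

    double nll   = 0.0;
    int    count = 0;
    for (size_t j = 0; j < logits.size(); ++j) {
        const float prob = softmax(logits[j])[next_token[j]];
        nll += -std::log(prob); // accumulate negative log-likelihood
        ++count;
    }

    // As in perplexity_lines(): perplexity = e^(average negative log-likelihood).
    printf("PPL = %.6f\n", std::exp(nll / count));
    return 0;
}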
