@@ -4,6 +4,7 @@
 
 #include <cmath>
 #include <ctime>
+#include <sstream>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -120,6 +121,77 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+void perplexity_lines(llama_context * ctx, const gpt_params & params) {
+    // Calculates perplexity over each line of the prompt
+
+    std::vector<std::string> prompt_lines;
+    std::istringstream strstream(params.prompt);
+    std::string line;
+
+    while (std::getline(strstream, line, '\n')) {
+        prompt_lines.push_back(line);
+    }
+
+    const int n_vocab = llama_n_vocab(ctx);
+
+    int counttotal = 0;
+    size_t n_lines = prompt_lines.size();
+
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %lu lines\n", __func__, n_lines);
+
+    printf("\nLine\tPPL line\tPPL cumulative\n");
+
+    for (size_t i = 0; i < n_lines; ++i) {
+
+        // Tokenize and insert BOS at start
+        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+
+        size_t batch_size = batch_embd.size();
+
+        // Stop if line is too long
+        if (batch_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : tokens in line %lu > n_ctx\n", __func__, i);
+            return;
+        }
+
+        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        const auto batch_logits = llama_get_logits(ctx);
+        std::vector<float> logits;
+        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+        double nllline = 0.0;
+        int countline = 0;
+
+        // Perplexity over second half of the line
+        for (size_t j = batch_size/2; j < batch_size - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[batch_embd[j + 1]];
+
+            nllline += -std::log(prob);
+            ++countline;
+        }
+
+        nll += nllline;
+        counttotal += countline;
+
+        // perplexity is e^(average negative log-likelihood)
+        printf("%lu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline/countline), std::exp(nll / counttotal));
+        fflush(stdout);
+    }
+
+    printf("\n");
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -168,7 +240,11 @@ int main(int argc, char ** argv) {
             params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    perplexity(ctx, params);
+    if (params.perplexity_lines) {
+        perplexity_lines(ctx, params);
+    } else {
+        perplexity(ctx, params);
+    }
 
     llama_print_timings(ctx);
     llama_free(ctx);
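For reference, the `PPL line` value printed by the new function is the exponentiated average negative log-likelihood over the second half of each line's tokens, using the first half (including the BOS token) as context. Writing a line's tokens as t_0, ..., t_{n-1} with t_0 = BOS, the inner loop computes

\[
\mathrm{PPL}_{\text{line}} = \exp\!\left( -\frac{1}{N} \sum_{j=\lfloor n/2 \rfloor}^{n-2} \log p\bigl(t_{j+1} \mid t_0, \dots, t_j\bigr) \right),
\qquad N = (n - 1) - \lfloor n/2 \rfloor ,
\]

and the `PPL cumulative` column applies the same formula to every position scored so far across all lines.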
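The added loop also relies on a `softmax` helper that is defined earlier in `perplexity.cpp` and therefore does not appear in this diff. As a rough sketch only (the in-file implementation may differ in details), a numerically stable version subtracts the maximum logit before exponentiating; it needs nothing beyond <vector> and <cmath>, both already available in this translation unit:

static std::vector<float> softmax(const std::vector<float> & logits) {
    // Sketch only: the actual helper lives earlier in perplexity.cpp.
    std::vector<float> probs(logits.size());

    // Subtract the largest logit before exponentiating to avoid overflow.
    float max_logit = logits[0];
    for (float v : logits) {
        if (v > max_logit) {
            max_logit = v;
        }
    }

    double sum_exp = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        const float e = std::exp(logits[i] - max_logit);
        sum_exp += e;
        probs[i] = e;
    }

    // Normalize so the entries form a probability distribution.
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] /= sum_exp;
    }

    return probs;
}

With such a helper, `softmax(tok_logits)[batch_embd[j + 1]]` is the model's probability for the token that actually follows position j in the line.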