Skip to content

Commit 599ccda

Browse files
committed
Iteratively skip the layer with the least impact
1 parent abb1d12 commit 599ccda

File tree

1 file changed

+55
-18
lines changed

1 file changed

+55
-18
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "common.h"
33
#include "llama.h"
44

5+
#include <algorithm>
56
#include <cmath>
67
#include <cstdio>
78
#include <cstring>
@@ -321,12 +322,17 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
321322
const int n_batch = params.n_batch;
322323

323324
llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);
325+
326+
const int32_t n_layers = 32;
327+
const int test_count = 15;
324328
std::vector<int32_t> layers;
325-
const int32_t n_layers = 26;
326329
layers.resize(n_layers + 1);
327330
std::iota(layers.begin(), layers.end(), 0);
328331
batch.run_layers = layers.data();
329-
int32_t skip_layer = 0;
332+
int32_t skip_layer = -1;
333+
std::vector<int32_t> skips;
334+
int32_t curr_best_layer = -1;
335+
double curr_best_ppl = -1, ref_ppl = -1;
330336

331337
int count = 0;
332338
double nll = 0.0;
@@ -337,22 +343,44 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
337343
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
338344

339345
for (int i = 0; i < n_chunk; ++i) {
340-
if (i > 0 && i % 20 == 0) {
341-
if (skip_layer >= n_layers) break;
346+
if (i > 0 && i % test_count == 0) {
347+
for (int32_t new_sl = skip_layer + 1; new_sl <= n_layers; new_sl++) {
348+
if (std::find(skips.begin(), skips.end(), new_sl) != skips.end()) continue;
349+
skip_layer = new_sl;
350+
break;
351+
}
352+
if (skip_layer >= n_layers) {
353+
if (curr_best_layer == -1) break;
354+
printf("\n\nADD SKIP %3d - ppl vs ref %.4f", curr_best_layer, curr_best_ppl - ref_ppl);
355+
if (curr_best_ppl >= ref_ppl * 5) break;
356+
skips.push_back(curr_best_layer);
357+
curr_best_layer = -1;
358+
curr_best_ppl = -1;
359+
skip_layer = -1;
360+
for (int32_t new_sl = skip_layer + 1; new_sl <= n_layers; new_sl++) {
361+
if (std::find(skips.begin(), skips.end(), new_sl) != skips.end()) continue;
362+
skip_layer = new_sl;
363+
break;
364+
}
365+
if (skip_layer == -1 || skip_layer == n_layers) break;
366+
}
342367
i = 0;
343368
count = 0;
344369
nll = 0;
345370
nll2 = 0;
346371
logit_history.clear();
347372
prob_history.clear();
348373

349-
for (int32_t i = 0, ic = 0; i < n_layers; i++) {
350-
if (i == skip_layer) continue;
374+
int32_t ic = 0;
375+
for (int32_t i = 0; i < n_layers; i++) {
376+
if (i == skip_layer || std::find(skips.begin(), skips.end(), i) != skips.end()) continue;
351377
layers[ic++] = i;
352378
}
353-
layers[n_layers - 1] = -1; // we skipped 1
354-
printf("\nSKIPPING: %d\n", skip_layer);
355-
skip_layer++;
379+
if (ic == 0) break;
380+
layers[ic] = -1;
381+
printf("\nSKIP %3d + [", skip_layer);
382+
for (const auto l : skips) printf("%d,", l);
383+
printf("] - len: %3zu, best:(%3d: %.3f)\n", skips.size() + 1, curr_best_layer, curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0);
356384
}
357385
const int start = i * n_ctx;
358386
const int end = start + n_ctx;
@@ -396,7 +424,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
396424

397425
const auto t_end = std::chrono::high_resolution_clock::now();
398426

399-
if (i == 0 && skip_layer == 0) {
427+
if (i == 0 && skip_layer < 0 && skips.empty()) {
400428
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
401429
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
402430
int total_seconds = (int)(t_total * n_chunk);
@@ -425,15 +453,24 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
425453
count += n_ctx - first - 1;
426454

427455
// perplexity is e^(average negative log-likelihood)
428-
if (params.ppl_output_type == 0) {
429-
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
430-
} else {
431-
double av = nll/count;
432-
double av2 = nll2/count - av*av;
433-
if (av2 > 0) av2 = sqrt(av2/(count-1));
434-
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
456+
// if (params.ppl_output_type == 0) {
457+
// printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
458+
// } else {
459+
// double av = nll/count;
460+
// double av2 = nll2/count - av*av;
461+
// if (av2 > 0) av2 = sqrt(av2/(count-1));
462+
// printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
463+
// }
464+
// fflush(stdout);
465+
if (skip_layer >= 0 && i + 1 == test_count) {
466+
double ppl = std::exp(nll / count);
467+
if (curr_best_layer == -1 || ppl < curr_best_ppl) {
468+
curr_best_layer = skip_layer;
469+
curr_best_ppl = ppl;
470+
}
471+
} else if (skip_layer < 0) {
472+
ref_ppl = std::exp(nll / count);
435473
}
436-
fflush(stdout);
437474
}
438475
printf("\n");
439476

0 commit comments

Comments
 (0)