Skip to content

Commit 8a569cf

Browse files
committed
perplexity anti-mode improvements
1 parent d6b44fb commit 8a569cf

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -352,13 +352,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
352352
std::vector<int32_t> extremes;
353353
extremes.resize(n_layers);
354354
std::fill(extremes.begin(), extremes.end(), 0);
355-
if (anti_mode) {
356-
// No pointing in starting with first/last layer disabled.
357-
skip_types[0] = 15;
358-
skip_types[n_layers - 1] = 15;
359-
skips.push_back(0); skips.push_back(0 + n_layers);
360-
skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
361-
}
355+
// if (anti_mode) {
356+
// // No point in starting with first/last layer disabled.
357+
// skip_types[0] = 15;
358+
// skip_types[n_layers - 1] = 15;
359+
// skips.push_back(0); skips.push_back(0 + n_layers);
360+
// skips.push_back(n_layers - 1); skips.push_back(n_layers - 1 + n_layers);
361+
// }
362362
int32_t curr_best_layer = -1, curr_best_type = 0;
363363
double curr_best_ppl = -1, ref_ppl = -1;
364364
const int32_t mask = anti_mode ? 3 : 0;
@@ -389,7 +389,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
389389
}
390390
if (skip_layer >= n_layers) {
391391
if (curr_best_layer == -1) break;
392-
if (prune_target > 0 && pass_results.size() >= prune_target * 2) {
392+
if (anti_mode || (prune_target > 0 && pass_results.size() >= prune_target * 2)) {
393393
std::sort(pass_results.begin(), pass_results.end(),
394394
[](const std::tuple<int32_t, int32_t, double> & a, const std::tuple<int32_t, int32_t, double> & b) {
395395
if (anti_mode) return std::get<2>(b) > std::get<2>(a);
@@ -399,24 +399,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
399399
const size_t num_prune = std::min(pass_results.size(), prune_target);
400400
for (size_t temp = 0, pruned = 0; temp < pass_results.size(); temp++) {
401401
int32_t lidx = std::get<0>(pass_results[temp]);
402-
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
403-
extremes[lidx] |= std::get<1>(pass_results[temp]);
404-
printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
405-
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
406402
if (anti_mode) {
407403
skip_types[lidx] |= std::get<1>(pass_results[temp]);
408404
skips.push_back(std::get<1>(pass_results[temp]) == 1 ? lidx : lidx + n_layers);
409405
}
406+
if (lidx == curr_best_layer && std::get<1>(pass_results[temp]) == curr_best_type) continue;
407+
extremes[lidx] |= std::get<1>(pass_results[temp]);
408+
printf("\nPrune[%zu]: %d (%d) - %.2f\n", pruned + 1, lidx,
409+
std::get<1>(pass_results[temp]), std::get<2>(pass_results[temp]));
410410
if (++pruned >= num_prune) break;
411411
}
412412
}
413413
pass_results.clear();
414414
printf("\n\nADD %c%3d - ppl vs ref %.4f",
415415
int(label[curr_best_type]), curr_best_layer,
416416
curr_best_ppl - ref_ppl);
417-
if (!anti_mode && curr_best_ppl > ref_ppl * 1.75) break;
418-
skip_types[curr_best_layer] += curr_best_type;
419-
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
417+
if (!anti_mode) {
418+
if (curr_best_ppl > ref_ppl * 1.75) break;
419+
skip_types[curr_best_layer] += curr_best_type;
420+
skips.push_back(curr_best_type == 1 ? curr_best_layer : curr_best_layer + n_layers);
421+
}
420422
curr_best_layer = -1;
421423
curr_best_ppl = -1;
422424
curr_best_type = 0;

0 commit comments

Comments
 (0)