Commit d3c08ea

Allow independent skipping of attention and MLP
1 parent 599ccda commit d3c08ea

2 files changed: 72 additions & 51 deletions

examples/perplexity/perplexity.cpp

Lines changed: 29 additions & 18 deletions
@@ -323,11 +323,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     llama_batch batch = llama_batch_get_one(NULL, 0, 0, 0);

-    const int32_t n_layers = 32;
+    const int32_t n_layers = 26;
     const int test_count = 15;
+    // 1 = attn, 2 = mlp, 3 = both
+    int32_t test_skip_type = 1;
     std::vector<int32_t> layers;
     layers.resize(n_layers + 1);
-    std::iota(layers.begin(), layers.end(), 0);
+    std::fill(layers.begin(), layers.end(), 0);
     batch.run_layers = layers.data();
     int32_t skip_layer = -1;
     std::vector<int32_t> skips;
@@ -342,9 +344,12 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par

     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);

+    auto test_t_start = std::chrono::high_resolution_clock::now();
     for (int i = 0; i < n_chunk; ++i) {
         if (i > 0 && i % test_count == 0) {
-            for (int32_t new_sl = skip_layer + 1; new_sl <= n_layers; new_sl++) {
+            auto test_t_end = std::chrono::high_resolution_clock::now();
+            float test_t_total = std::chrono::duration<float>(test_t_end - test_t_start).count();
+            for (int32_t new_sl = std::max(0, skip_layer + 1); new_sl <= n_layers ; new_sl++) {
                 if (std::find(skips.begin(), skips.end(), new_sl) != skips.end()) continue;
                 skip_layer = new_sl;
                 break;
@@ -371,16 +376,22 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             logit_history.clear();
             prob_history.clear();

-            int32_t ic = 0;
             for (int32_t i = 0; i < n_layers; i++) {
-                if (i == skip_layer || std::find(skips.begin(), skips.end(), i) != skips.end()) continue;
-                layers[ic++] = i;
+                if (i == skip_layer || std::find(skips.begin(), skips.end(), i) != skips.end()) {
+                    layers[i] = test_skip_type;
+                } else {
+                    layers[i] = 0;
+                }
             }
-            if (ic == 0) break;
-            layers[ic] = -1;
+            layers[n_layers] = -1;
             printf("\nSKIP %3d + [", skip_layer);
             for (const auto l : skips) printf("%d,", l);
-            printf("] - len: %3zu, best:(%3d: %.3f)\n", skips.size() + 1, curr_best_layer, curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0);
+            printf("] - len: %3zu, best:(%3d: %.3f), took %.2f sec\n",
+                skips.size() + 1,
+                curr_best_layer,
+                curr_best_ppl != -1 ? curr_best_ppl - ref_ppl : 0,
+                test_t_total);
+            test_t_start = std::chrono::high_resolution_clock::now();
         }
         const int start = i * n_ctx;
         const int end = start + n_ctx;
@@ -453,15 +464,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         count += n_ctx - first - 1;

         // perplexity is e^(average negative log-likelihood)
-        // if (params.ppl_output_type == 0) {
-        //     printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        // } else {
-        //     double av = nll/count;
-        //     double av2 = nll2/count - av*av;
-        //     if (av2 > 0) av2 = sqrt(av2/(count-1));
-        //     printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
-        // }
-        // fflush(stdout);
+        if (params.ppl_output_type == 0) {
+            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+        } else {
+            double av = nll/count;
+            double av2 = nll2/count - av*av;
+            if (av2 > 0) av2 = sqrt(av2/(count-1));
+            printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+        }
+        fflush(stdout);
         if (skip_layer >= 0 && i + 1 == test_count) {
             double ppl = std::exp(nll / count);
             if (curr_best_layer == -1 || ppl < curr_best_ppl) {
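
Note (not part of the commit): a minimal standalone sketch of the per-layer flag array the test loop above builds and hands to the model via batch.run_layers = layers.data(). The concrete values of n_layers, test_skip_type, skip_layer, and skips are hypothetical, chosen only to make the resulting layout visible; the encoding itself comes from the diff: 0 = run the layer normally, test_skip_type marks the sublayer(s) to skip (1 = attn, 2 = mlp, 3 = both), and a trailing -1 terminates the list read by the graph builder.

// Hypothetical illustration of the layers/run_layers encoding (assumed values).
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t n_layers       = 4;          // assumed tiny model, for illustration only
    const int32_t test_skip_type = 1;          // 1 = attn, 2 = mlp, 3 = both
    const int32_t skip_layer     = 2;          // layer currently under test
    const std::vector<int32_t> skips = { 0 };  // layers already selected for skipping

    std::vector<int32_t> layers(n_layers + 1, 0);
    for (int32_t i = 0; i < n_layers; i++) {
        if (i == skip_layer || std::find(skips.begin(), skips.end(), i) != skips.end()) {
            layers[i] = test_skip_type;        // skip the chosen sublayer(s) here
        } else {
            layers[i] = 0;                     // run this layer normally
        }
    }
    layers[n_layers] = -1;                     // sentinel read by the graph builder

    for (const auto v : layers) printf("%d ", v);  // prints: 1 0 1 0 -1
    printf("\n");
    return 0;
}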

llama.cpp

Lines changed: 43 additions & 33 deletions
@@ -2975,8 +2975,6 @@ static struct ggml_cgraph * llm_build_llama(

     const auto & kv_self = lctx.kv_self;

-    int32_t * run_layer = batch.run_layers;
-
     GGML_ASSERT(!!kv_self.ctx);

     const int64_t n_embd = hparams.n_embd;
@@ -3132,12 +3130,27 @@ static struct ggml_cgraph * llm_build_llama(
         }
     }

-    for (int il_ = 0; il_ < n_layer; ++il_) {
-        int il = il_;
+    int32_t * run_layer = batch.run_layers;
+    bool run_attn = false, run_mlp = false;
+    cur = inpL;
+
+    for (int il = 0; il < n_layer; ++il) {
+        run_attn = run_mlp = true;
         if (run_layer != NULL) {
-            il = *run_layer++;
-            if (il < 0) break;
+            if (*run_layer >= 0) {
+                run_attn = (*run_layer & 1) == 0;
+                run_mlp = (*run_layer & 2) == 0;
+                run_layer++;
+            } else {
+                run_layer = NULL;
+            }
+        } else if (ggml_allocr_is_measure(lctx.alloc) && il == n_layer - 1) {
+            // No idea why this is needed, but otherwise we run out of space
+            // when skipping attn or mlp (but not both) on the last layer
+            run_mlp = false;
         }
+        if (!run_attn && !run_mlp) continue;
+
         ggml_format_name(inpL, "layer_inp_%d", il);

         offload_func_t offload_func = llama_nop;
@@ -3148,10 +3161,11 @@ static struct ggml_cgraph * llm_build_llama(
         }
 #endif // GGML_USE_CUBLAS

-        struct ggml_tensor * inpSA = inpL;
+        struct ggml_tensor * inpFF = nullptr;

-        // norm
-        {
+        // self-attention
+        if (run_attn) {
+            // norm
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
@@ -3160,10 +3174,7 @@ static struct ggml_cgraph * llm_build_llama(
             cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
             offload_func(cur);
             ggml_set_name(cur, "attention_norm_0");
-        }

-        // self-attention
-        {
             // compute Q and K and RoPE them
             struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
             offload_func_kq(tmpk);
@@ -3280,25 +3291,25 @@ static struct ggml_cgraph * llm_build_llama(
                 cur);
             offload_func(cur);
             ggml_set_name(cur, "result_wo");
-        }

-        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
-        offload_func(inpFF);
-        ggml_set_name(inpFF, "inpFF");
+            inpFF = ggml_add(ctx0, cur, inpL);
+            offload_func(inpFF);
+            ggml_set_name(inpFF, "inpFF");
+        } else {
+            inpFF = inpL;
+        }

         // feed-forward network
-        {
+        if (run_mlp) {
             // norm
-            {
-                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                offload_func(cur);
-                ggml_set_name(cur, "rms_norm_1");
+            cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
+            offload_func(cur);
+            ggml_set_name(cur, "rms_norm_1");

-                // cur = cur*ffn_norm(broadcasted)
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                offload_func(cur);
-                ggml_set_name(cur, "ffn_norm");
-            }
+            // cur = cur*ffn_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+            offload_func(cur);
+            ggml_set_name(cur, "ffn_norm");

             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                 model.layers[il].w3,
@@ -3326,18 +3337,18 @@ static struct ggml_cgraph * llm_build_llama(
                 cur);
             offload_func(cur);
             ggml_set_name(cur, "result_w2");
-        }

-        cur = ggml_add(ctx0, cur, inpFF);
-        offload_func(cur);
-        ggml_set_name(cur, "inpFF_+_result_w2");
+            cur = ggml_add(ctx0, cur, inpFF);
+            offload_func(cur);
+            ggml_set_name(cur, "inpFF_+_result_w2");
+        } else {
+            cur = inpFF;
+        }

         // input for next layer
         inpL = cur;
     }

-    cur = inpL;
-
     // norm
     {
         cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
@@ -9351,7 +9362,6 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.pos) free(batch.pos);
     if (batch.seq_id) free(batch.seq_id);
     if (batch.logits) free(batch.logits);
-    if (batch.run_layers) free(batch.run_layers);
 }

 int llama_decode(
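
Usage note (a sketch, not part of the commit): with the decode logic above, a caller controls skipping through batch.run_layers. Per entry, bit 0 set skips attention and bit 1 set skips the MLP, so 0 = run both, 1 = skip attn, 2 = skip mlp, 3 = skip the whole layer; a negative entry ends per-layer control and the remaining layers run as usual. The layer count and indices below are hypothetical; only the bit semantics are taken from the diff.

// Hypothetical caller-side sketch of the run_layers flags (assumed layer count
// and indices); only the flag meanings come from the diff above.
#include <cstdint>
#include <vector>

int main() {
    const int32_t n_layers = 8;                 // assumed model depth
    std::vector<int32_t> run_layers(n_layers + 1, 0);

    run_layers[3]        = 2;                   // layer 3: keep attention, skip MLP
    run_layers[5]        = 3;                   // layer 5: skip the entire layer
    run_layers[7]        = 1;                   // layer 7: skip attention, keep MLP
    run_layers[n_layers] = -1;                  // terminator: any further layers run normally

    // The array is attached to the batch before decoding, e.g.:
    //   batch.run_layers = run_layers.data();
    // and must stay alive for the duration of the llama_decode() call.
    return 0;
}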
