@@ -2975,8 +2975,6 @@ static struct ggml_cgraph * llm_build_llama(
 
     const auto & kv_self = lctx.kv_self;
 
-    int32_t * run_layer = batch.run_layers;
-
     GGML_ASSERT(!!kv_self.ctx);
 
     const int64_t n_embd = hparams.n_embd;
@@ -3132,12 +3130,27 @@ static struct ggml_cgraph * llm_build_llama(
         }
     }
 
-    for (int il_ = 0; il_ < n_layer; ++il_) {
-        int il = il_;
+    int32_t * run_layer = batch.run_layers;
+    bool run_attn = false, run_mlp = false;
+    cur = inpL;
+
+    for (int il = 0; il < n_layer; ++il) {
+        run_attn = run_mlp = true;
         if (run_layer != NULL) {
-            il = *run_layer++;
-            if (il < 0) break;
+            if (*run_layer >= 0) {
+                run_attn = (*run_layer & 1) == 0;
+                run_mlp  = (*run_layer & 2) == 0;
+                run_layer++;
+            } else {
+                run_layer = NULL;
+            }
+        } else if (ggml_allocr_is_measure(lctx.alloc) && il == n_layer - 1) {
+            // No idea why this is needed, but otherwise we run out of space
+            // when skipping attn or mlp (but not both) on the last layer
+            run_mlp = false;
         }
+        if (!run_attn && !run_mlp) continue;
+
         ggml_format_name(inpL, "layer_inp_%d", il);
 
         offload_func_t offload_func = llama_nop;
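The new loop consumes one `int32_t` control word per layer from `batch.run_layers`: bit 0 set skips attention, bit 1 set skips the feed-forward (MLP) block, and a negative entry ends the list, after which all remaining layers run in full. A hypothetical skip plan under that encoding (the array name and values here are illustrative only, not part of the patch):

```cpp
// Illustrative skip plan: one control word per layer.
//   bit 0 set -> skip attention    bit 1 set -> skip MLP
//   negative  -> stop reading the list; later layers run fully
static int32_t run_layers_plan[] = {
    0,  // layer 0: run attention and MLP
    1,  // layer 1: skip attention, run MLP
    2,  // layer 2: run attention, skip MLP
    3,  // layer 3: skip both (the loop `continue`s past it)
    -1, // sentinel: layers 4..n_layer-1 run in full
};
```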
@@ -3148,10 +3161,11 @@ static struct ggml_cgraph * llm_build_llama(
         }
 #endif // GGML_USE_CUBLAS
 
-        struct ggml_tensor * inpSA = inpL;
+        struct ggml_tensor * inpFF = nullptr;
 
-        // norm
-        {
+        // self-attention
+        if (run_attn) {
+            // norm
             cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
             offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
@@ -3160,10 +3174,7 @@ static struct ggml_cgraph * llm_build_llama(
             cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
             offload_func(cur);
             ggml_set_name(cur, "attention_norm_0");
-        }
 
-        // self-attention
-        {
             // compute Q and K and RoPE them
             struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
             offload_func_kq(tmpk);
@@ -3280,25 +3291,25 @@ static struct ggml_cgraph * llm_build_llama(
                     cur);
             offload_func(cur);
             ggml_set_name(cur, "result_wo");
-        }
 
-        struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
-        offload_func(inpFF);
-        ggml_set_name(inpFF, "inpFF");
+            inpFF = ggml_add(ctx0, cur, inpL);
+            offload_func(inpFF);
+            ggml_set_name(inpFF, "inpFF");
+        } else {
+            inpFF = inpL;
+        }
 
         // feed-forward network
-        {
+        if (run_mlp) {
             // norm
-            {
-                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
-                offload_func(cur);
-                ggml_set_name(cur, "rms_norm_1");
+            cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
+            offload_func(cur);
+            ggml_set_name(cur, "rms_norm_1");
 
-                // cur = cur*ffn_norm(broadcasted)
-                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
-                offload_func(cur);
-                ggml_set_name(cur, "ffn_norm");
-            }
+            // cur = cur*ffn_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+            offload_func(cur);
+            ggml_set_name(cur, "ffn_norm");
 
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,
@@ -3326,18 +3337,18 @@ static struct ggml_cgraph * llm_build_llama(
                     cur);
             offload_func(cur);
             ggml_set_name(cur, "result_w2");
-        }
 
-        cur = ggml_add(ctx0, cur, inpFF);
-        offload_func(cur);
-        ggml_set_name(cur, "inpFF_+_result_w2");
+            cur = ggml_add(ctx0, cur, inpFF);
+            offload_func(cur);
+            ggml_set_name(cur, "inpFF_+_result_w2");
+        } else {
+            cur = inpFF;
+        }
 
         // input for next layer
         inpL = cur;
     }
 
-    cur = inpL;
-
     // norm
     {
         cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
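Taken together, the hunks above reshape each layer so that a skipped sub-block degenerates to its residual bypass: with attention skipped, `inpFF` is just `inpL`; with the MLP skipped, `cur` is just `inpFF`. Hoisting `cur = inpL;` above the loop keeps `cur` defined even when every layer is skipped. A toy scalar analogue of the resulting control flow, with floats standing in for ggml tensors and `attn`/`mlp` as arbitrary stand-ins (not llama.cpp functions):

```cpp
#include <cstdint>
#include <cstddef>

static float attn(float x) { return 0.50f * x; } // stand-in for the attention block
static float mlp (float x) { return 0.25f * x; } // stand-in for the feed-forward block

// Mirrors the per-layer flow after this patch, with floats in place of
// ggml tensors. cur starts as the raw input, so even a fully skipped
// model returns something well-defined.
static float run_model(float x0, int n_layer, const int32_t * run_layer) {
    float inpL = x0;
    float cur  = inpL;
    for (int il = 0; il < n_layer; ++il) {
        bool run_attn = true, run_mlp = true;
        if (run_layer != nullptr) {
            if (*run_layer >= 0) {
                run_attn = (*run_layer & 1) == 0;
                run_mlp  = (*run_layer & 2) == 0;
                run_layer++;
            } else {
                run_layer = nullptr; // sentinel: remaining layers run fully
            }
        }
        if (!run_attn && !run_mlp) continue;                  // layer is a no-op
        float inpFF = run_attn ? inpL  + attn(inpL)  : inpL;  // residual add or bypass
        cur         = run_mlp  ? inpFF + mlp(inpFF)  : inpFF; // residual add or bypass
        inpL        = cur;                                    // input for next layer
    }
    return cur;
}
```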
@@ -9351,7 +9362,6 @@ void llama_batch_free(struct llama_batch batch) {
     if (batch.pos)    free(batch.pos);
     if (batch.seq_id) free(batch.seq_id);
     if (batch.logits) free(batch.logits);
-    if (batch.run_layers) free(batch.run_layers);
 }
 
 int llama_decode(
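With the `free(batch.run_layers)` above removed, the skip plan appears to become caller-owned. Under that assumption, and assuming this patch adds the `int32_t * run_layers` field to `struct llama_batch`, a hypothetical helper could attach static storage directly:

```cpp
#include <cstdint>
#include "llama.h"

// Hypothetical helper, assuming run_layers is a caller-owned field of
// struct llama_batch after this change: llama_batch_free() no longer
// frees it, so pointing it at static storage is safe.
static void set_skip_plan(struct llama_batch * batch) {
    static int32_t plan[] = { 0, 1, 2, 3, -1 }; // encoding as in the layer loop above
    batch->run_layers = plan;
}
```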