@@ -4315,6 +4315,7 @@ static struct ggml_tensor * llm_build_kqv(
4315
4315
const llama_model & model,
4316
4316
const llama_hparams & hparams,
4317
4317
const llama_kv_cache & kv,
4318
+ struct ggml_cgraph * graph,
4318
4319
struct ggml_tensor * wo,
4319
4320
struct ggml_tensor * wo_b,
4320
4321
struct ggml_tensor * q_cur,
@@ -4393,6 +4394,8 @@ static struct ggml_tensor * llm_build_kqv(
4393
4394
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
4394
4395
cb(cur, "kqv_merged_cont", il);
4395
4396
4397
+ ggml_build_forward_expand(graph, cur);
4398
+
4396
4399
cur = ggml_mul_mat(ctx, wo, cur);
4397
4400
if (wo_b) {
4398
4401
cb(cur, "kqv_wo", il);
@@ -4405,6 +4408,44 @@ static struct ggml_tensor * llm_build_kqv(
4405
4408
return cur;
4406
4409
}
4407
4410
4411
+ static struct ggml_tensor * llm_build_kv(
4412
+ struct ggml_context * ctx,
4413
+ const llama_model & model,
4414
+ const llama_hparams & hparams,
4415
+ const llama_kv_cache & kv,
4416
+ struct ggml_cgraph * graph,
4417
+ struct ggml_tensor * wo,
4418
+ struct ggml_tensor * wo_b,
4419
+ struct ggml_tensor * k_cur,
4420
+ struct ggml_tensor * v_cur,
4421
+ struct ggml_tensor * q_cur,
4422
+ struct ggml_tensor * kq_mask,
4423
+ int64_t n_ctx,
4424
+ int32_t n_tokens,
4425
+ int32_t kv_head,
4426
+ int32_t n_kv,
4427
+ float max_alibi_bias,
4428
+ float kq_scale,
4429
+ const llm_build_cb & cb,
4430
+ int il) {
4431
+
4432
+ // these nodes are added to the graph together so that they are not reordered
4433
+ // by doing so, the number of splits in the graph is reduced
4434
+ ggml_build_forward_expand(graph, k_cur);
4435
+ ggml_build_forward_expand(graph, v_cur);
4436
+ ggml_build_forward_expand(graph, q_cur);
4437
+
4438
+ llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
4439
+
4440
+ struct ggml_tensor * cur;
4441
+ cur = llm_build_kqv(ctx, model, hparams, kv, graph,
4442
+ wo, wo_b,
4443
+ q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
4444
+ cb(cur, "kqv_out", il);
4445
+
4446
+ return cur;
4447
+ }
4448
+
4408
4449
struct llm_build_context {
4409
4450
const llama_model & model;
4410
4451
const llama_hparams & hparams;
@@ -4562,12 +4603,6 @@ struct llm_build_context {
4562
4603
cb(Vcur, "Vcur", il);
4563
4604
}
4564
4605
4565
- // these nodes are added to the graph together so that they are not reordered
4566
- // by doing so, the number of splits in the graph is reduced
4567
- ggml_build_forward_expand(gf, Qcur);
4568
- ggml_build_forward_expand(gf, Kcur);
4569
- ggml_build_forward_expand(gf, Vcur);
4570
-
4571
4606
Qcur = ggml_rope_custom(
4572
4607
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4573
4608
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
@@ -4582,11 +4617,9 @@ struct llm_build_context {
4582
4617
);
4583
4618
cb(Kcur, "Kcur", il);
4584
4619
4585
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
4586
-
4587
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
4620
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
4588
4621
model.layers[il].wo, model.layers[il].bo,
4589
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
4622
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
4590
4623
cb(cur, "kqv_out", il);
4591
4624
}
4592
4625
@@ -4763,14 +4796,13 @@ struct llm_build_context {
4763
4796
cb(Qcur, "Qcur", il);
4764
4797
cb(Kcur, "Kcur", il);
4765
4798
4766
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
4767
4799
4768
4800
// apply ALiBi for 13B model
4769
4801
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
4770
4802
4771
- cur = llm_build_kqv (ctx0, model, hparams, kv_self,
4803
+ cur = llm_build_kv (ctx0, model, hparams, kv_self, gf ,
4772
4804
model.layers[il].wo, NULL,
4773
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
4805
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
4774
4806
cb(cur, "kqv_out", il);
4775
4807
}
4776
4808
@@ -4892,11 +4924,9 @@ struct llm_build_context {
4892
4924
);
4893
4925
cb(Kcur, "Kcur", il);
4894
4926
4895
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
4896
-
4897
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
4927
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
4898
4928
model.layers[il].wo, NULL,
4899
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
4929
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
4900
4930
cb(cur, "kqv_out", il);
4901
4931
}
4902
4932
@@ -4993,11 +5023,9 @@ struct llm_build_context {
4993
5023
4994
5024
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
4995
5025
4996
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
4997
-
4998
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5026
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
4999
5027
model.layers[il].wo, model.layers[il].bo,
5000
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5028
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5001
5029
cb(cur, "kqv_out", il);
5002
5030
}
5003
5031
@@ -5200,12 +5228,9 @@ struct llm_build_context {
5200
5228
);
5201
5229
cb(Vcur, "Vcur", il);
5202
5230
5203
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5204
-
5205
- // TODO: not tested, could be broken
5206
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5231
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5207
5232
model.layers[il].wo, model.layers[il].bo,
5208
- Q, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5233
+ Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5209
5234
cb(cur, "kqv_out", il);
5210
5235
}
5211
5236
@@ -5292,11 +5317,9 @@ struct llm_build_context {
5292
5317
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5293
5318
cb(Qcur, "Qcur", il);
5294
5319
5295
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5296
-
5297
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5320
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5298
5321
model.layers[il].wo, NULL,
5299
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5322
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5300
5323
cb(cur, "kqv_out", il);
5301
5324
}
5302
5325
@@ -5390,11 +5413,9 @@ struct llm_build_context {
5390
5413
5391
5414
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5392
5415
5393
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5394
-
5395
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5416
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5396
5417
model.layers[il].wo, model.layers[il].bo,
5397
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5418
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5398
5419
cb(cur, "kqv_out", il);
5399
5420
}
5400
5421
@@ -5485,11 +5506,9 @@ struct llm_build_context {
5485
5506
5486
5507
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
5487
5508
5488
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5489
-
5490
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5509
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5491
5510
model.layers[il].wo, NULL,
5492
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5511
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5493
5512
cb(cur, "kqv_out", il);
5494
5513
}
5495
5514
@@ -5597,11 +5616,9 @@ struct llm_build_context {
5597
5616
);
5598
5617
cb(Kcur, "Kcur", il);
5599
5618
5600
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5601
-
5602
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5619
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5603
5620
model.layers[il].wo, NULL,
5604
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5621
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5605
5622
cb(cur, "kqv_out", il);
5606
5623
}
5607
5624
@@ -5714,11 +5731,9 @@ struct llm_build_context {
5714
5731
);
5715
5732
cb(Kcur, "Kcur", il);
5716
5733
5717
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5718
-
5719
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5734
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5720
5735
model.layers[il].wo, NULL,
5721
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5736
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5722
5737
cb(cur, "kqv_out", il);
5723
5738
}
5724
5739
@@ -5837,11 +5852,9 @@ struct llm_build_context {
5837
5852
);
5838
5853
cb(Kcur, "Kcur", il);
5839
5854
5840
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5841
-
5842
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5855
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5843
5856
model.layers[il].wo, model.layers[il].bo,
5844
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5857
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
5845
5858
cb(cur, "kqv_out", il);
5846
5859
}
5847
5860
@@ -5966,11 +5979,9 @@ struct llm_build_context {
5966
5979
);
5967
5980
cb(Kcur, "Kcur", il);
5968
5981
5969
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
5970
-
5971
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
5982
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
5972
5983
model.layers[il].wo, model.layers[il].bo,
5973
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f, cb, il);
5984
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f, cb, il);
5974
5985
cb(cur, "kqv_out", il);
5975
5986
}
5976
5987
@@ -6071,11 +6082,9 @@ struct llm_build_context {
6071
6082
ext_factor, attn_factor, beta_fast, beta_slow);
6072
6083
cb(Kcur, "Kcur", il);
6073
6084
6074
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
6075
-
6076
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
6085
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6077
6086
model.layers[il].wo, NULL,
6078
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6087
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6079
6088
cb(cur, "kqv_out", il);
6080
6089
}
6081
6090
struct ggml_tensor * sa_out = cur;
@@ -6172,11 +6181,9 @@ struct llm_build_context {
6172
6181
6173
6182
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
6174
6183
6175
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
6176
-
6177
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
6184
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6178
6185
model.layers[il].wo, model.layers[il].bo,
6179
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6186
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6180
6187
cb(cur, "kqv_out", il);
6181
6188
}
6182
6189
@@ -6283,11 +6290,9 @@ struct llm_build_context {
6283
6290
);
6284
6291
cb(Kcur, "Kcur", il);
6285
6292
6286
- llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
6287
-
6288
- cur = llm_build_kqv(ctx0, model, hparams, kv_self,
6293
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
6289
6294
model.layers[il].wo, model.layers[il].bo,
6290
- Qcur, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6295
+ Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head , n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
6291
6296
cb(cur, "kqv_out", il);
6292
6297
}
6293
6298
@@ -6355,6 +6360,14 @@ static struct ggml_cgraph * llama_build_graph(
6355
6360
ggml_set_name(cur, name);
6356
6361
}
6357
6362
6363
+
6364
+ if (!lctx.cparams.offload_kqv) {
6365
+ if (strcmp(name, "kqv_merged_cont") == 0) {
6366
+ // all nodes between the KV store and the attention output are run on the CPU
6367
+ ggml_backend_sched_set_node_backend(lctx.sched, cur, lctx.backend_cpu);
6368
+ }
6369
+ }
6370
+
6358
6371
//
6359
6372
// allocate input tensors and set input data
6360
6373
//
0 commit comments