@@ -313,7 +313,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
         if (cparams.causal_attn) {
             const int64_t n_kv = ubatch->n_tokens;
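For context on how this input is consumed: with no KV cache the KQ mask is built over the batch tokens themselves (`n_kv = ubatch->n_tokens`), and when `cparams.causal_attn` is set each token may only attend to itself and earlier tokens. Below is a minimal, self-contained sketch of that masking rule; it is illustrative only, and the real `set_input` additionally restricts attention to tokens of the same sequence and pads the rows.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of a causal KQ mask over a batch of n_tokens (no KV cache):
// entry (i, j) is 0 when query i may attend to key j (j <= i), else -inf.
// Illustrative only -- the real code also checks sequence ids and pads rows.
int main() {
    const int n_tokens = 4;
    std::vector<float> kq_mask(n_tokens * n_tokens);

    for (int i = 0; i < n_tokens; ++i) {       // query position
        for (int j = 0; j < n_tokens; ++j) {   // key position
            kq_mask[i * n_tokens + j] = (j <= i) ? 0.0f : -INFINITY;
        }
    }

    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_tokens; ++j) {
            printf("%6s", std::isinf(kq_mask[i * n_tokens + j]) ? "-inf" : "0");
        }
        printf("\n");
    }
}
```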
@@ -400,7 +400,7 @@ void llm_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
@@ -523,9 +523,7 @@ void llm_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_attn_dec::set_input(const llama_ubatch * ubatch) {
-    inp_kv_self->set_input(ubatch);
-
+void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     if (cross_kq_mask) {
         const int64_t n_enc    = cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch->n_tokens;
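The cross-attention mask pairs each decoder token in the ubatch (query) with the `n_enc` encoder outputs (keys), and unlike the self-attention case there is no causal constraint: valid entries are 0 and only padding rows are masked out. A hedged sketch of that layout, assuming a single sequence and ignoring the per-sequence checks the real `set_input` performs:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Sketch of a cross-attention KQ mask: every real decoder token may attend to
// every encoder output, so valid entries are 0 and only padding rows are -inf.
// Shapes and the padding value are illustrative.
int main() {
    const int n_enc       = 5; // encoder outputs (keys)
    const int n_tokens    = 3; // decoder tokens in this ubatch (queries)
    const int n_tokens_pad = 4; // padded row count

    std::vector<float> mask(n_enc * n_tokens_pad, -INFINITY);
    for (int i = 0; i < n_tokens; ++i) {
        for (int j = 0; j < n_enc; ++j) {
            mask[i * n_enc + j] = 0.0f; // real tokens see all encoder positions
        }
    }

    for (int i = 0; i < n_tokens_pad; ++i) {
        for (int j = 0; j < n_enc; ++j) {
            printf("%6s", std::isinf(mask[i * n_enc + j]) ? "-inf" : "0");
        }
        printf("\n");
    }
}
```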
@@ -1066,7 +1064,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     auto & cur = inp->s_copy;

     cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    //cb(cur, "inp_s_copy", -1);
     ggml_set_input(cur);

     res->add_input(std::move(inp));
@@ -1084,7 +1081,6 @@ ggml_tensor * llm_graph_context::build_inp_s_mask() const {
     auto & cur = inp->s_mask;

     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    //cb(cur, "inp_s_mask", -1);
     ggml_set_input(cur);

     res->add_input(std::move(inp));
@@ -1151,15 +1147,11 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     cb(pos_bucket_1d, "pos_bucket_1d", -1);

     ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
-    cb(pos_bias, "pos_bias", -1);

     pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
-    cb(pos_bias, "pos_bias", -1);
+    pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
+    pos_bias = ggml_cont(ctx0, pos_bias);

-    pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
-    cb(pos_bias, "pos_bias", -1);
-
-    pos_bias = ggml_cont(ctx0, pos_bias);
     cb(pos_bias, "pos_bias", -1);

     return pos_bias;
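`build_pos_bias` is the T5-style relative attention bias path: a learned bias table (`attn_rel_b`) indexed by relative-position bucket is gathered with `ggml_get_rows`, then reshaped and permuted so it can be added to the KQ scores per head. A rough, self-contained sketch of the underlying gather; the names, shapes, and bucket function here are illustrative, not the ggml tensor layout:

```cpp
#include <cstdio>
#include <vector>

// Sketch of the gather behind build_pos_bias(): for every (key, query) pair we
// look up a learned bias for its relative-position bucket, per attention head.
// That bias is later added to the attention scores before softmax.
int main() {
    const int n_head = 2, n_buckets = 4, n_kv = 3, n_tokens = 3;

    // attn_rel_b: one learned bias per (bucket, head) -- illustrative values
    std::vector<float> attn_rel_b(n_buckets * n_head);
    for (int b = 0; b < n_buckets; ++b)
        for (int h = 0; h < n_head; ++h)
            attn_rel_b[b * n_head + h] = 0.1f * b + 0.01f * h;

    // pos_bucket: a bucket index for every (kv, token) pair (toy bucket function)
    std::vector<int> pos_bucket(n_kv * n_tokens);
    for (int t = 0; t < n_tokens; ++t)
        for (int k = 0; k < n_kv; ++k)
            pos_bucket[t * n_kv + k] = (t - k + n_buckets) % n_buckets;

    // gather: pos_bias[h][t][k] = attn_rel_b[pos_bucket[t][k]][h]
    for (int h = 0; h < n_head; ++h) {
        printf("head %d:\n", h);
        for (int t = 0; t < n_tokens; ++t) {
            for (int k = 0; k < n_kv; ++k) {
                const int bucket = pos_bucket[t * n_kv + k];
                printf("%5.2f ", attn_rel_b[bucket * n_head + h]);
            }
            printf("\n");
        }
    }
}
```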
@@ -1257,26 +1249,21 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     return cur;
 }

-llm_graph_input_attn_base * llm_graph_context::build_attn_inp_base(
-        bool causal,
-        bool swa) const {
-    auto inp = std::make_unique<llm_graph_input_attn_base>(hparams, cparams);
+llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
+    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);

     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    GGML_UNUSED(causal);
-    GGML_UNUSED(swa);
-
     inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->kq_mask);

     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;

-    return (llm_graph_input_attn_base *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }

 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_base * inp,
+        llm_graph_input_attn_no_cache * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
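The mask above is allocated with its second dimension rounded up via `GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)`, so attention kernels (notably the flash-attention path, which also receives the F16 cast shown in the hunk) can assume the row count is a fixed multiple. A small sketch of the round-up-to-multiple arithmetic, assuming the pad value is a power of two as in ggml; the pad size used here is illustrative, not necessarily `GGML_KQ_MASK_PAD`'s value:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Round x up to the next multiple of n, assuming n is a power of two
// (the same shape of arithmetic as ggml's GGML_PAD macro).
static int64_t pad_to(int64_t x, int64_t n) {
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const int64_t pad = 32; // illustrative pad size
    for (int64_t n_tokens : {1, 7, 32, 33, 100}) {
        printf("n_tokens = %3lld -> padded rows = %3lld\n",
               (long long) n_tokens, (long long) pad_to(n_tokens, pad));
    }
}
```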
@@ -1324,12 +1311,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-llm_graph_input_attn_kv_self * llm_graph_context::build_attn_inp_kv_self(
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
         bool causal,
         bool swa) const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);

-    auto inp = std::make_unique<llm_graph_input_attn_kv_self>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);

     const auto n_kv = kv_self->n;

@@ -1353,11 +1340,11 @@ llm_graph_input_attn_kv_self * llm_graph_context::build_attn_inp_kv_self(
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }

-    return (llm_graph_input_attn_kv_self *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }

 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_self * inp,
+        llm_graph_input_attn_kv_unified * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1490,12 +1477,8 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-llm_graph_input_attn_dec * llm_graph_context::build_attn_inp_dec(
-        bool causal,
-        bool swa) const {
-    auto * inp_kv_self = build_attn_inp_kv_self(causal, swa);
-
-    auto inp = std::make_unique<llm_graph_input_attn_dec>(inp_kv_self, cross);
+llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);

     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;

@@ -1504,11 +1487,11 @@ llm_graph_input_attn_dec * llm_graph_context::build_attn_inp_dec(

     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;

-    return (llm_graph_input_attn_dec *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
 }

 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_dec * inp,
+        llm_graph_input_attn_cross * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,