
Commit bcee6dd

graph : clean-up
ggml-ci
1 parent: c326074

File tree: 3 files changed (+86, -124 lines)


src/llama-graph.cpp (17 additions, 34 deletions)
@@ -313,7 +313,7 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     if (kq_mask) {
         if (cparams.causal_attn) {
             const int64_t n_kv = ubatch->n_tokens;
@@ -400,7 +400,7 @@ void llm_graph_input_attn_base::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn) {
@@ -523,9 +523,7 @@ void llm_graph_input_attn_kv_self::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_dec::set_input(const llama_ubatch * ubatch) {
-    inp_kv_self->set_input(ubatch);
-
+void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     if (cross_kq_mask) {
         const int64_t n_enc = cross_kq_mask->ne[0];
         const int64_t n_tokens = ubatch->n_tokens;
@@ -1066,7 +1064,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
     auto & cur = inp->s_copy;
 
     cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    //cb(cur, "inp_s_copy", -1);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1084,7 +1081,6 @@ ggml_tensor * llm_graph_context::build_inp_s_mask() const {
     auto & cur = inp->s_mask;
 
     cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-    //cb(cur, "inp_s_mask", -1);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
@@ -1151,15 +1147,11 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     cb(pos_bucket_1d, "pos_bucket_1d", -1);
 
     ggml_tensor * pos_bias = ggml_get_rows(ctx0, attn_rel_b, pos_bucket_1d);
-    cb(pos_bias, "pos_bias", -1);
 
     pos_bias = ggml_reshape_3d(ctx0, pos_bias, pos_bias->ne[0], pos_bucket->ne[0], pos_bucket->ne[1]);
-    cb(pos_bias, "pos_bias", -1);
+    pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
+    pos_bias = ggml_cont   (ctx0, pos_bias);
 
-    pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
-    cb(pos_bias, "pos_bias", -1);
-
-    pos_bias = ggml_cont(ctx0, pos_bias);
     cb(pos_bias, "pos_bias", -1);
 
     return pos_bias;
@@ -1257,26 +1249,21 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     return cur;
 }
 
-llm_graph_input_attn_base * llm_graph_context::build_attn_inp_base(
-        bool causal,
-        bool swa) const {
-    auto inp = std::make_unique<llm_graph_input_attn_base>(hparams, cparams);
+llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
+    auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    GGML_UNUSED(causal);
-    GGML_UNUSED(swa);
-
     inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
-    return (llm_graph_input_attn_base *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_base * inp,
+        llm_graph_input_attn_no_cache * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1324,12 +1311,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_self * llm_graph_context::build_attn_inp_kv_self(
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
         bool causal,
        bool swa) const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_self>(hparams, cparams, kv_self);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
@@ -1353,11 +1340,11 @@ llm_graph_input_attn_kv_self * llm_graph_context::build_attn_inp_kv_self(
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_self *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_self * inp,
+        llm_graph_input_attn_kv_unified * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
@@ -1490,12 +1477,8 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_dec * llm_graph_context::build_attn_inp_dec(
-        bool causal,
-        bool swa) const {
-    auto * inp_kv_self = build_attn_inp_kv_self(causal, swa);
-
-    auto inp = std::make_unique<llm_graph_input_attn_dec>(inp_kv_self, cross);
+llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
+    auto inp = std::make_unique<llm_graph_input_attn_cross>(cross);
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
@@ -1504,11 +1487,11 @@ llm_graph_input_attn_dec * llm_graph_context::build_attn_inp_dec(
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
 
-    return (llm_graph_input_attn_dec *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_cross *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_dec * inp,
+        llm_graph_input_attn_cross * inp,
         ggml_cgraph * gf,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
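As an aside on the build_attn_inp_no_cache() hunk above: the KQ mask is sized as [n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)], i.e. its second dimension is the token count rounded up to a padding multiple. Below is a minimal standalone sketch of that arithmetic; the helper name, the example token counts, and the padding value of 32 are assumptions made for illustration only, the real macro and constant are defined in the ggml/llama headers.

// Illustrative sketch only (not llama.cpp code): round-up padding as used when sizing the KQ mask.
#include <cstdint>
#include <cstdio>

// stand-in for GGML_PAD(x, n): round x up to the next multiple of n
static int64_t pad_to_multiple(int64_t x, int64_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int64_t kq_mask_pad = 32; // assumed padding constant, see GGML_KQ_MASK_PAD
    for (int64_t n_tokens : {1, 17, 32, 33}) {
        // without a KV cache, n_kv == n_tokens, so the mask shape is
        // [n_tokens, pad_to_multiple(n_tokens, kq_mask_pad)]
        std::printf("n_tokens=%2lld -> padded mask rows=%lld\n",
                    (long long) n_tokens,
                    (long long) pad_to_multiple(n_tokens, kq_mask_pad));
    }
    return 0;
}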

src/llama-graph.h (20 additions, 42 deletions)
@@ -28,10 +28,6 @@ enum llm_graph_type {
     LLM_GRAPH_TYPE_DECODER,
 };
 
-//
-// llm_build
-//
-
 enum llm_ffn_op_type {
     LLM_FFN_SILU,
     LLM_FFN_GELU,
@@ -105,20 +101,18 @@ class llm_graph_input_pos : public llm_graph_input_i {
     const int64_t n_pos_per_token = 1;
 };
 
-// I32 [n_batch, n_batch]
 class llm_graph_input_pos_bucket : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
     virtual ~llm_graph_input_pos_bucket() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * pos_bucket = nullptr;
+    ggml_tensor * pos_bucket = nullptr; // I32 [n_batch, n_batch]
 
     const llama_hparams & hparams;
 };
 
-// I32 [n_kv, n_batch]
 class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
@@ -128,7 +122,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * pos_bucket = nullptr;
+    ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
     const llama_kv_cache_unified * kv_self;
@@ -176,33 +170,30 @@ class llm_graph_input_cls : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
-// I32 [kv_size]
 class llm_graph_input_s_copy : public llm_graph_input_i {
 public:
     llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_copy() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_copy;
+    ggml_tensor * s_copy; // I32 [kv_size]
 
     const llama_kv_cache_unified * kv_self;
 };
 
-// F32 [1, n_kv]
 class llm_graph_input_s_mask : public llm_graph_input_i {
 public:
     llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_s_mask() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * s_mask;
+    ggml_tensor * s_mask; // F32 [1, n_kv]
 
     const llama_kv_cache_unified * kv_self;
 };
 
-// F32 [n_embd, n_outputs_enc]
 class llm_graph_input_cross_embd : public llm_graph_input_i {
 public:
     llm_graph_input_cross_embd(
@@ -211,18 +202,18 @@ class llm_graph_input_cross_embd : public llm_graph_input_i {
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * cross_embd;
+    ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc]
 
     const llama_cross * cross;
 };
 
-class llm_graph_input_attn_base : public llm_graph_input_i {
+class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_base(const llama_hparams & hparams, const llama_cparams & cparams) :
+    llm_graph_input_attn_no_cache(const llama_hparams & hparams, const llama_cparams & cparams) :
         hparams(hparams),
         cparams(cparams) {
     }
-    ~llm_graph_input_attn_base() = default;
+    ~llm_graph_input_attn_no_cache() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -235,17 +226,17 @@ class llm_graph_input_attn_base : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_attn_kv_self : public llm_graph_input_i {
+class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_self(
+    llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
             const llama_kv_cache_unified * kv_self) :
         hparams(hparams),
         cparams(cparams),
         kv_self(kv_self) {
     }
-    ~llm_graph_input_attn_kv_self() = default;
+    ~llm_graph_input_attn_kv_unified() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -263,24 +254,18 @@ class llm_graph_input_attn_kv_self : public llm_graph_input_i {
     const llama_kv_cache_unified * kv_self;
 };
 
-class llm_graph_input_attn_dec : public llm_graph_input_i {
+class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_dec(
-            llm_graph_input_attn_kv_self * inp_kv_self,
-            const llama_cross * cross) : inp_kv_self(inp_kv_self), cross(cross) {}
-    ~llm_graph_input_attn_dec() = default;
+    llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
+    ~llm_graph_input_attn_cross() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    ggml_tensor * get_kq_mask() const { return inp_kv_self->get_kq_mask(); }
-    ggml_tensor * get_kq_mask_swa() const { return inp_kv_self->get_kq_mask_swa(); }
     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 
     ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
     ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
 
-    llm_graph_input_attn_kv_self * inp_kv_self = nullptr;
-
     const llama_cross * cross = nullptr;
 };
 
@@ -511,13 +496,10 @@ struct llm_graph_context {
             bool v_trans,
             float kq_scale) const;
 
-    // no memory
-    llm_graph_input_attn_base * build_attn_inp_base(
-            bool causal,
-            bool swa) const;
+    llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_base * inp,
+            llm_graph_input_attn_no_cache * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
@@ -528,13 +510,12 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    // kv cache (unified)
-    llm_graph_input_attn_kv_self * build_attn_inp_kv_self(
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
             bool causal,
             bool swa) const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_self * inp,
+            llm_graph_input_attn_kv_unified * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
@@ -545,13 +526,10 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    // enc-dec cross attention
-    llm_graph_input_attn_dec * build_attn_inp_dec(
-            bool causal,
-            bool swa) const;
+    llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_dec * inp,
+            llm_graph_input_attn_cross * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
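To make the structural change in this header easier to see: before, the decoder attention input (llm_graph_input_attn_dec) wrapped a KV-self input and forwarded its masks; after, the cache-backed self-attention input (llm_graph_input_attn_kv_unified) and the cross-attention input (llm_graph_input_attn_cross) are independent graph inputs that a model builder requests separately. The sketch below is a toy, self-contained mirror of that layout, with stand-in names and placeholder bodies, not llama.cpp code.

// Toy mirror of the refactor: two independent graph inputs instead of one wrapper that owned the other.
#include <memory>
#include <vector>

struct input_i {                          // stand-in for llm_graph_input_i
    virtual ~input_i() = default;
    virtual void set_input() = 0;
};

struct input_attn_kv_unified : input_i {  // stand-in for llm_graph_input_attn_kv_unified
    void set_input() override { /* fill self_kq_mask from the KV cache */ }
};

struct input_attn_cross : input_i {       // stand-in for llm_graph_input_attn_cross
    void set_input() override { /* fill cross_kq_mask from the encoder output */ }
};

struct graph_result {                     // stand-in for the result object that owns all inputs
    std::vector<std::unique_ptr<input_i>> inputs;
    input_i * add_input(std::unique_ptr<input_i> inp) {
        inputs.push_back(std::move(inp));
        return inputs.back().get();
    }
};

int main() {
    graph_result res;
    // a decoder with cross-attention now registers both inputs explicitly,
    // instead of one "dec" input that constructed and forwarded the other
    auto * self_inp  = res.add_input(std::make_unique<input_attn_kv_unified>());
    auto * cross_inp = res.add_input(std::make_unique<input_attn_cross>());
    self_inp->set_input();
    cross_inp->set_input();
    return 0;
}

This also explains why build_attn_inp_cross() takes no causal/swa flags in the diff above: the cross input no longer owns a KV-self input, so those flags remain only on build_attn_inp_kv_unified().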
