@@ -2082,7 +2082,7 @@ struct llama_context {
    struct ggml_tensor * inp_mean;   // F32 [n_batch, n_batch]
    struct ggml_tensor * inp_cls;    // I32 [n_batch]
    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
    struct ggml_tensor * inp_s_seq;  // I32 [kv_size, n_batch]

#ifdef GGML_USE_MPI
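The shape change to F32 [1, kv_size] matters because the mask is applied to the recurrent states with ggml_mul, which broadcasts its second operand along dimension 0 when that dimension is 1, so a single mask value per KV cell can zero out an entire state row. A minimal CPU-only sketch of that broadcast (standalone example, not part of the patch; the sizes are arbitrary):

#include "ggml.h"

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int d_state = 4; // size of one recurrent state row (made-up)
    const int n_kv    = 3; // number of KV cells (made-up)

    struct ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, n_kv); // F32 [d_state, n_kv]
    struct ggml_tensor * mask   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1,       n_kv); // F32 [1, n_kv]

    for (int i = 0; i < d_state*n_kv; ++i) {
        ((float *) states->data)[i] = 1.0f;
    }
    ((float *) mask->data)[0] = 1.0f; // keep the state of cell 0
    ((float *) mask->data)[1] = 0.0f; // clear the state of cell 1
    ((float *) mask->data)[2] = 1.0f; // keep the state of cell 2

    // mask is broadcast along dim 0: one mask value scales a whole state row
    struct ggml_tensor * masked = ggml_mul(ctx, states, mask);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, masked);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    // row 1 of `masked` is now all zeros; rows 0 and 2 are unchanged
    ggml_free(ctx);
    return 0;
}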
@@ -5518,6 +5518,9 @@ struct llm_build_context {
        lctx.inp_K_shift = nullptr;
        lctx.inp_mean    = nullptr;
        lctx.inp_cls     = nullptr;
+        lctx.inp_s_copy  = nullptr;
+        lctx.inp_s_mask  = nullptr;
+        lctx.inp_s_seq   = nullptr;
    }

    void free() {
@@ -5559,14 +5562,14 @@ struct llm_build_context {

        GGML_ASSERT(kv_self.recurrent);

-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        struct ggml_tensor * state_copy = build_inp_s_copy();

        for (int il = 0; il < n_layer; ++il) {
            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
            struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states  = ggml_get_rows(ctx0,  ssm_states, lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states  = ggml_get_rows(ctx0,  ssm_states, state_copy);

            // TODO: name the intermediate tensors with cb()

@@ -5665,6 +5668,27 @@ struct llm_build_context {
        return lctx.inp_cls;
    }

+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
    struct ggml_cgraph * build_llama() {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

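These helpers only create the input tensors and mark them with ggml_set_input; their contents are written on the host later, once the graph has been allocated (by llama_set_inputs and llama_set_s_copy). As a rough illustration of the data each input is meant to carry (toy code, not the llama.cpp implementation; the cell bookkeeping is deliberately reduced to "sequence s lives in KV cell s"):

#include <cstdint>
#include <vector>

static void fill_recurrent_inputs_toy(
        const std::vector<int32_t> & cell_src,    // per cell: which cell its state is gathered from
        const std::vector<bool>    & cell_in_use, // per cell: does it hold a live sequence state?
        const std::vector<int32_t> & token_seq,   // per batch token: its sequence id (== its cell here)
        std::vector<int32_t> & s_copy,            // I32 [kv_size]          -> inp_s_copy
        std::vector<float>   & s_mask,            // F32 [1, kv_size]       -> inp_s_mask
        std::vector<int32_t> & s_seq) {           // I32 [kv_size, n_batch] -> inp_s_seq
    const size_t n_kv     = cell_src.size();
    const size_t n_tokens = token_seq.size();

    s_copy.resize(n_kv);
    s_mask.resize(n_kv);
    s_seq .assign(n_kv*n_tokens, -1);             // -1 used as padding in this toy version

    for (size_t i = 0; i < n_kv; ++i) {
        s_copy[i] = cell_src[i];                  // row gathered by ggml_get_rows in the s_copy graph
        s_mask[i] = cell_in_use[i] ? 1.0f : 0.0f; // 0.0f clears the state row of an unused cell
    }

    for (size_t j = 0; j < n_tokens; ++j) {
        s_seq[j*n_kv + 0] = token_seq[j];         // cell(s) read/written by token j; rest stays -1
    }
}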
@@ -8148,12 +8172,8 @@ struct llm_build_context {
        // {n_embd, n_tokens}
        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-        struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        struct ggml_tensor * state_seq  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
-        lctx.inp_s_mask = state_mask;
-        lctx.inp_s_seq  = state_seq;
-        ggml_set_input(state_mask);
-        ggml_set_input(state_seq);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq  = build_inp_s_seq();

        for (int il = 0; il < n_layer; ++il) {
            // (ab)using the KV cache to store the states
@@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
    }

-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
        const int64_t n_tokens = batch.n_tokens;

        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
@@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
        "non-causal attention with generative models is not supported"
    );

-    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-    if (cparams.causal_attn) {
-        const int64_t n_kv     = kv_self.n;
-        const int64_t n_tokens = batch.n_tokens;
+    if (lctx.inp_KQ_mask) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv     = kv_self.n;
+            const int64_t n_tokens = batch.n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_pos    pos    = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        f = 0.0f;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0.0f;
+                        }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                    }
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                }
            }
-        }
-    } else {
-        // when using kv cache, the mask needs to match the kv cache size
-        const int64_t n_tokens = batch.n_tokens;
-        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+        } else {
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_tokens; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        if (batch.seq_id[i][s] == seq_id) {
-                            f = 0.0f;
-                            break;
+                    for (int i = 0; i < n_tokens; ++i) {
+                        float f = -INFINITY;
+                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                            if (batch.seq_id[i][s] == seq_id) {
+                                f = 0.0f;
+                                break;
+                            }
                        }
-                    }

-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                }
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }

-                for (int i = n_tokens; i < n_stride; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    }
                }
            }
        }
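For reference, the causal branch above fills an [n_kv, n_tokens] mask where row j corresponds to batch token j and column i to KV cell i: 0.0f where the token may attend to the cell, -INFINITY otherwise, at the flat index data[h*(n_kv*n_tokens) + j*n_kv + i]. A small self-contained example with made-up positions shows the resulting layout:

#include <cstdio>
#include <cmath>
#include <vector>

int main() {
    const int n_kv     = 4;
    const int n_tokens = 2;

    const int cell_pos[n_kv]      = { 3, 4, 5, 6 }; // positions already stored in the KV cache
    const int batch_pos[n_tokens] = { 5, 6 };       // positions of the tokens being decoded
    // (all cells and tokens assumed to belong to the same sequence)

    std::vector<float> data(n_kv*n_tokens);

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            // same rule as the causal branch: a token may not attend to
            // cells whose position is greater than its own
            data[j*n_kv + i] = cell_pos[i] > batch_pos[j] ? -INFINITY : 0.0f;
        }
    }

    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            printf("%8.1f", data[j*n_kv + i]);
        }
        printf("\n");
    }
    // row 0 (pos 5) masks out the cell at pos 6; row 1 (pos 6) sees all four cells
    return 0;
}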
@@ -9272,11 +9294,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
    }

    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
        {
+            ggml_backend_sched_reset(lctx.sched);
+
            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);

+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);

            need_reserve = true;
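The point of the reordering is that inp_s_copy is now created inside llama_build_graph_s_copy and only receives a backing buffer when the scheduler allocates the graph, so llama_set_s_copy must run after ggml_backend_sched_alloc_graph and before the compute. A standalone sketch of the same build, allocate, set-inputs, compute ordering, using the CPU backend and ggml_gallocr instead of llama.cpp's scheduler (names and sizes are illustrative only):

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

int main() {
    // graph-only context: tensor data is allocated later by the allocator
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 1. build: create the graph and its input tensors (cf. build_inp_s_copy)
    struct ggml_tensor * states = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4); // [d_state, n_kv]
    struct ggml_tensor * copy   = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);    // [n_kv]
    ggml_set_input(states);
    ggml_set_input(copy);

    struct ggml_tensor * shuffled = ggml_get_rows(ctx, states, copy);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, shuffled);

    // 2. allocate: only now do `states` and `copy` point at real memory
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_gallocr_t galloc  = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
    ggml_gallocr_alloc_graph(galloc, gf);

    // 3. set inputs: writing before the allocation above would hit an unallocated tensor
    const float   state_data[8*4] = { 0 };          // placeholder state rows
    const int32_t copy_data[4]    = { 3, 2, 1, 0 }; // which row each state is gathered from
    ggml_backend_tensor_set(states, state_data, 0, sizeof(state_data));
    ggml_backend_tensor_set(copy,   copy_data,  0, sizeof(copy_data));

    // 4. compute
    ggml_backend_graph_compute(backend, gf);

    ggml_gallocr_free(galloc);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}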