Commit 3e06fca

llama : fix Mamba inference for pipeline parallelism

1 parent 1ac668e
1 file changed: llama.cpp (+80 additions, -54 deletions)
@@ -2082,7 +2082,7 @@ struct llama_context {
     struct ggml_tensor * inp_mean;   // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;    // I32 [n_batch]
     struct ggml_tensor * inp_s_copy; // I32 [kv_size]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
     struct ggml_tensor * inp_s_seq;  // I32 [kv_size, n_batch]

 #ifdef GGML_USE_MPI
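
Note: the comment changes because inp_s_mask is now created as a 2D tensor (see build_inp_s_mask() further down, which calls ggml_new_tensor_2d with ne = {1, n_kv}). A {1, n_kv} shape is the convenient one for state masking, since ggml_mul broadcasts its second operand over the first whenever each of its dimensions divides the corresponding one. A minimal sketch of that use, not code from this commit (d_state is a placeholder size):

    // states: one d_state-wide column per KV cell; mask: one value per cell
    struct ggml_tensor * states = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, d_state, n_kv);
    struct ggml_tensor * mask   = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1,       n_kv);
    // each cell's single mask value is repeated across its d_state elements,
    // so columns whose mask is 0.0f are cleared
    states = ggml_mul(ctx0, states, mask);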
@@ -5518,6 +5518,9 @@ struct llm_build_context {
         lctx.inp_K_shift = nullptr;
         lctx.inp_mean = nullptr;
         lctx.inp_cls = nullptr;
+        lctx.inp_s_copy = nullptr;
+        lctx.inp_s_mask = nullptr;
+        lctx.inp_s_seq = nullptr;
     }

     void free() {
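
Note: these nullptr initializations make the recurrent-state inputs optional, so llama_set_inputs can test which input tensors the current graph actually created instead of assuming they all exist; the if (batch.pos && lctx.inp_pos) change below is the same idea. A hedged sketch of the pattern (ids is a placeholder host buffer, not a name from this diff):

    // fill an input only if the graph builder created it
    if (lctx.inp_s_copy) {
        ggml_backend_tensor_set(lctx.inp_s_copy, ids, 0,
                kv_self.size*ggml_element_size(lctx.inp_s_copy));
    }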
@@ -5559,14 +5562,14 @@ struct llm_build_context {

         GGML_ASSERT(kv_self.recurrent);

-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        struct ggml_tensor * state_copy = build_inp_s_copy();

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
             struct ggml_tensor * ssm_states  = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);

-            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
-            ssm_states  = ggml_get_rows(ctx0, ssm_states,  lctx.inp_s_copy);
+            conv_states = ggml_get_rows(ctx0, conv_states, state_copy);
+            ssm_states  = ggml_get_rows(ctx0, ssm_states,  state_copy);

             // TODO: name the intermediate tensors with cb()

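Note: ggml_get_rows is what moves the recurrent states around. state_copy holds one source cell index per KV cell, so the two gathers rewrite every cell's conv and ssm state from its source row inside the compute graph itself, rather than with host-side copies; that is what lets the backend scheduler run the shuffle on whichever device owns the states. A sketch of the gather semantics (d and kv_size are placeholder sizes):

    struct ggml_tensor * src = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, d, kv_size);
    struct ggml_tensor * ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_size);
    // row i of out is row ids[i] of src; ids[i] == i leaves a state in place
    struct ggml_tensor * out = ggml_get_rows(ctx0, src, ids);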
@@ -5665,6 +5668,27 @@ struct llm_build_context {
         return lctx.inp_cls;
     }

+    struct ggml_tensor * build_inp_s_copy() {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size);
+        cb(lctx.inp_s_copy, "inp_s_copy", -1);
+        ggml_set_input(lctx.inp_s_copy);
+        return lctx.inp_s_copy;
+    }
+
+    struct ggml_tensor * build_inp_s_mask() {
+        lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
+        cb(lctx.inp_s_mask, "inp_s_mask", -1);
+        ggml_set_input(lctx.inp_s_mask);
+        return lctx.inp_s_mask;
+    }
+
+    struct ggml_tensor * build_inp_s_seq() {
+        lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
+        cb(lctx.inp_s_seq, "inp_s_seq", -1);
+        ggml_set_input(lctx.inp_s_seq);
+        return lctx.inp_s_seq;
+    }
+
     struct ggml_cgraph * build_llama() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -8148,12 +8172,8 @@ struct llm_build_context {
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-        struct ggml_tensor * state_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv);
-        struct ggml_tensor * state_seq  = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
-        lctx.inp_s_mask = state_mask;
-        lctx.inp_s_seq  = state_seq;
-        ggml_set_input(state_mask);
-        ggml_set_input(state_seq);
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+        struct ggml_tensor * state_seq  = build_inp_s_seq();

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
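
Note: with the helpers, creating, naming, and flagging these inputs lives in one place (build_s_copy above gets the same treatment through build_inp_s_copy()). The cb() call the helpers add is what gives the tensors their "inp_s_*" names in graph dumps; the default build callback reduces to roughly this sketch (a simplification, not the exact implementation):

    // name a graph tensor; il < 0 means it is not tied to a layer
    static void cb(struct ggml_tensor * cur, const char * name, int il) {
        if (il >= 0) {
            ggml_format_name(cur, "%s-%d", name, il);
        } else {
            ggml_set_name(cur, name);
        }
    }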
@@ -8508,7 +8528,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }

-    if (batch.pos) {
+    if (batch.pos && lctx.inp_pos) {
         const int64_t n_tokens = batch.n_tokens;

         ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
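
Note: this guard matters for recurrent models. A Mamba graph never builds a position input, so lctx.inp_pos (now reset to nullptr for graphs that omit it) stays null even when the batch carries positions, and the previously unconditional ggml_backend_tensor_set would have written through a null tensor.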
@@ -8519,61 +8539,63 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         "non-causal attention with generative models is not supported"
     );

-    // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
-    if (cparams.causal_attn) {
-        const int64_t n_kv     = kv_self.n;
-        const int64_t n_tokens = batch.n_tokens;
+    if (lctx.inp_KQ_mask) {
+        // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
+        if (cparams.causal_attn) {
+            const int64_t n_kv     = kv_self.n;
+            const int64_t n_tokens = batch.n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        // For causal attention, use only the previous KV cells
-        // of the correct sequence for each token of the batch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_pos    pos    = batch.pos[j];
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            // For causal attention, use only the previous KV cells
+            // of the correct sequence for each token of the batch.
+            // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_pos    pos    = batch.pos[j];
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_kv; ++i) {
-                    float f;
-                    if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
-                        f = -INFINITY;
-                    } else {
-                        f = 0.0f;
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
+                            f = -INFINITY;
+                        } else {
+                            f = 0.0f;
+                        }
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                 }
             }
-        }
-    } else {
-        // when using kv cache, the mask needs to match the kv cache size
-        const int64_t n_tokens = batch.n_tokens;
-        const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
+        } else {
+            // when using kv cache, the mask needs to match the kv cache size
+            const int64_t n_tokens = batch.n_tokens;
+            const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;

-        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
+            assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-        float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data = (float *) lctx.inp_KQ_mask->data;

-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                const llama_seq_id seq_id = batch.seq_id[j][0];
+            for (int h = 0; h < 1; ++h) {
+                for (int j = 0; j < n_tokens; ++j) {
+                    const llama_seq_id seq_id = batch.seq_id[j][0];

-                for (int i = 0; i < n_tokens; ++i) {
-                    float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[i]; ++s) {
-                        if (batch.seq_id[i][s] == seq_id) {
-                            f = 0.0f;
-                            break;
+                    for (int i = 0; i < n_tokens; ++i) {
+                        float f = -INFINITY;
+                        for (int s = 0; s < batch.n_seq_id[i]; ++s) {
+                            if (batch.seq_id[i][s] == seq_id) {
+                                f = 0.0f;
+                                break;
+                            }
                         }
-                    }

-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
-                }
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
+                    }

-                for (int i = n_tokens; i < n_stride; ++i) {
-                    data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    for (int i = n_tokens; i < n_stride; ++i) {
+                        data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
+                    }
                 }
             }
         }
     }
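
Note: apart from the new if (lctx.inp_KQ_mask) guard, this block only gains one level of indentation; recurrent models build no KQ mask input, so the fill has to be skipped for them. The layout the loops produce is one row per batch token, written at data[j*n_kv + i] (or j*n_stride + i in the non-causal branch). A small worked example with illustrative values:

    // causal mask, one sequence, 3 tokens at positions 0,1,2, n_kv = 3
    // (0.0f = may attend, -INFINITY = masked):
    //   j=0: {    0, -INF, -INF }
    //   j=1: {    0,    0, -INF }
    //   j=2: {    0,    0,    0 }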
@@ -9272,11 +9294,15 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }

     if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        llama_set_s_copy(lctx);
-
         {
+            ggml_backend_sched_reset(lctx.sched);
+
             ggml_cgraph * gf = llama_build_graph_s_copy(lctx);

+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+
+            llama_set_s_copy(lctx);
+
             llama_graph_compute(lctx, gf, lctx.cparams.n_threads);

             need_reserve = true;
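
Note: this reordering is the core of the pipeline-parallelism fix. Under ggml_backend_sched, an input tensor only receives backing memory when the scheduler allocates the graph, so llama_set_s_copy() must run after ggml_backend_sched_alloc_graph(); previously it wrote the copy indices before the graph was even built. The added ggml_backend_sched_reset() first drops the previous graph's allocations. The general order, as a hedged sketch (build_graph and set_inputs are placeholders; the ggml_backend_sched_* calls are the real API):

    ggml_backend_sched_reset(sched);              // 1. forget the previous graph's splits
    struct ggml_cgraph * gf = build_graph();      // 2. build: creates unallocated input tensors
    ggml_backend_sched_alloc_graph(sched, gf);    // 3. allocate: inputs get real buffers
    set_inputs();                                 // 4. now ggml_backend_tensor_set is safe
    ggml_backend_sched_graph_compute(sched, gf);  // 5. run across the scheduled backends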
