Skip to content

Commit eb589d5

Browse files
committed
llama : avoid copies for simple batch splits
1 parent 61200ef commit eb589d5

File tree

1 file changed

+52
-29
lines changed

1 file changed

+52
-29
lines changed

llama.cpp

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3143,36 +3143,51 @@ struct llama_sbatch {
31433143
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
31443144
// NOTE: loops are separated for cache-friendliness
31453145
if (batch->token) {
3146-
for (size_t i = 0; i < length; ++i) {
3147-
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
3146+
if (ubatch.equal_seqs) {
3147+
for (size_t i = 0; i < length; ++i) {
3148+
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
3149+
}
3150+
} else {
3151+
// simple split
3152+
ubatch.token = batch->token + seq.offset;
31483153
}
31493154
} else {
31503155
ubatch.token = nullptr;
31513156
}
31523157
if (batch->embd) {
3153-
for (size_t i = 0; i < length; ++i) {
3154-
memcpy(
3155-
ubatch.embd + n_embd * (ubatch.n_tokens + i),
3156-
batch->embd + n_embd * ids[seq.offset + i],
3157-
n_embd * sizeof(float)
3158-
);
3158+
if (ubatch.equal_seqs) {
3159+
for (size_t i = 0; i < length; ++i) {
3160+
memcpy(
3161+
ubatch.embd + n_embd * (ubatch.n_tokens + i),
3162+
batch->embd + n_embd * ids[seq.offset + i],
3163+
n_embd * sizeof(float)
3164+
);
3165+
}
3166+
} else {
3167+
// simple split
3168+
ubatch.embd = batch->embd + seq.offset;
31593169
}
31603170
} else {
31613171
ubatch.embd = nullptr;
31623172
}
31633173
// from here on, the else branches are deprecated;
31643174
// they are helpers for smoother batch API transition
31653175
if (batch->pos) {
3166-
for (size_t i = 0; i < length; ++i) {
3167-
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
3176+
if (ubatch.equal_seqs) {
3177+
for (size_t i = 0; i < length; ++i) {
3178+
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
3179+
}
3180+
} else {
3181+
// simple split
3182+
ubatch.pos = batch->pos + seq.offset;
31683183
}
31693184
} else {
31703185
for (size_t i = 0; i < length; ++i) {
31713186
llama_pos bi = ids[seq.offset + i];
31723187
ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
31733188
}
31743189
}
3175-
if (seq.n_seq_id > 0) {
3190+
if (ubatch.equal_seqs) {
31763191
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
31773192
if (seq.seq_id) {
31783193
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
@@ -3181,9 +3196,10 @@ struct llama_sbatch {
31813196
ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
31823197
}
31833198
} else {
3199+
// simple split
31843200
if (batch->n_seq_id) {
31853201
for (size_t i = 0; i < length; ++i) {
3186-
ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]];
3202+
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
31873203
}
31883204
} else {
31893205
for (size_t i = 0; i < length; ++i) {
@@ -3192,7 +3208,7 @@ struct llama_sbatch {
31923208
}
31933209
if (batch->seq_id) {
31943210
for (size_t i = 0; i < length; ++i) {
3195-
ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]];
3211+
ubatch.seq_id = batch->seq_id + seq.offset;
31963212
}
31973213
} else {
31983214
for (size_t i = 0; i < length; ++i) {
@@ -3201,11 +3217,19 @@ struct llama_sbatch {
32013217
}
32023218
}
32033219
if (batch->logits) {
3204-
for (size_t i = 0; i < length; ++i) {
3205-
size_t id = ids[seq.offset + i];
3206-
int8_t is_output = batch->logits[id];
3207-
ubatch.output[ubatch.n_tokens + i] = is_output;
3208-
if (is_output) { out_ids.push_back(id); }
3220+
if (ubatch.equal_seqs) {
3221+
for (size_t i = 0; i < length; ++i) {
3222+
size_t id = ids[seq.offset + i];
3223+
int8_t is_output = batch->logits[id];
3224+
ubatch.output[ubatch.n_tokens + i] = is_output;
3225+
if (is_output) { out_ids.push_back(id); }
3226+
}
3227+
} else {
3228+
// simple split
3229+
ubatch.output = batch->logits + seq.offset;
3230+
for (size_t i = 0; i < length; ++i) {
3231+
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
3232+
}
32093233
}
32103234
} else if (logits_all) {
32113235
for (size_t i = 0; i < length; ++i) {
@@ -3222,26 +3246,25 @@ struct llama_sbatch {
32223246
}
32233247
}
32243248
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
3225-
ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1;
3249+
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
32263250
}
32273251
ubatch.n_tokens += length;
3228-
ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits
3252+
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
32293253
seq.offset += length;
32303254
seq.length -= length;
32313255
n_tokens -= length;
32323256
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
32333257
}
32343258

3235-
// legacy split, unknown number of sequences of unequal lengths
3236-
llama_ubatch split_slice(size_t n_ubatch) {
3259+
// simple split, unknown number of sequences of unequal lengths
3260+
llama_ubatch split_simple(size_t n_ubatch) {
32373261
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
32383262
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
32393263
ubatch.equal_seqs = false;
32403264
if (!seq.empty()) {
32413265
llama_sbatch_seq & s = seq[0];
32423266
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
32433267
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
3244-
// TODO: reduce copies
32453268
add_seq_to_ubatch(ubatch, s, length);
32463269
}
32473270
return ubatch;
@@ -3254,7 +3277,7 @@ struct llama_sbatch {
32543277
if (!seq.empty()) {
32553278
size_t length = 0;
32563279
size_t n_tokens_in_ubatch = 0;
3257-
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
3280+
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
32583281
// smallest first, because it's easier to split this way;
32593282
// starting from the end to pop in constant time.
32603283
for (size_t i = seq.size(); i-- > 0;) {
@@ -3282,13 +3305,13 @@ struct llama_sbatch {
32823305
if (!seq.empty()) {
32833306
llama_sbatch_seq & s = seq[seq.size() - 1];
32843307
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
3285-
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
3308+
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
32863309
add_seq_to_ubatch(ubatch, s, length);
32873310
}
32883311
return ubatch;
32893312
}
32903313

3291-
void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
3314+
void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
32923315
GGML_ASSERT(batch.n_tokens >= 0);
32933316
this->batch = &batch;
32943317
this->n_embd = n_embd;
@@ -3302,7 +3325,7 @@ struct llama_sbatch {
33023325
for (size_t i = 0; i < n_tokens; ++i) {
33033326
ids[i] = i;
33043327
}
3305-
if (legacy_split) {
3328+
if (simple_split) {
33063329
seq.resize(1);
33073330
llama_sbatch_seq & s = seq[0];
33083331
s.n_seq_id = 0;
@@ -13737,7 +13760,7 @@ static int llama_decode_internal(
1373713760
}
1373813761

1373913762
lctx.sbatch.from_batch(batch_all, n_embd,
13740-
/* legacy_split */ rs_self.size == 0,
13763+
/* simple_split */ rs_self.size == 0,
1374113764
/* logits_all */ n_outputs == n_tokens_all);
1374213765

1374313766
// reserve output buffer
@@ -13749,7 +13772,7 @@ static int llama_decode_internal(
1374913772
while (lctx.sbatch.n_tokens > 0) {
1375013773
// TODO: deprecate simple splits in favor of equal splits
1375113774
// For now, only use equal splits for recurrent or hybrid model architectures
13752-
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
13775+
llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
1375313776
const uint32_t n_tokens = u_batch.n_tokens;
1375413777

1375513778
// count the outputs in this u_batch

0 commit comments

Comments
 (0)