@@ -3143,36 +3143,51 @@ struct llama_sbatch {
         GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
         // NOTE: loops are separated for cache-friendliness
         if (batch->token) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.token = batch->token + seq.offset;
             }
         } else {
             ubatch.token = nullptr;
         }
         if (batch->embd) {
-            for (size_t i = 0; i < length; ++i) {
-                memcpy(
-                    ubatch.embd + n_embd * (ubatch.n_tokens + i),
-                    batch->embd + n_embd * ids[seq.offset + i],
-                    n_embd * sizeof(float)
-                );
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    memcpy(
+                        ubatch.embd + n_embd * (ubatch.n_tokens + i),
+                        batch->embd + n_embd * ids[seq.offset + i],
+                        n_embd * sizeof(float)
+                    );
+                }
+            } else {
+                // simple split
+                ubatch.embd = batch->embd + n_embd * seq.offset;
             }
         } else {
             ubatch.embd = nullptr;
         }
         // from here on, the else branches are deprecated;
         // they are helpers for smoother batch API transition
         if (batch->pos) {
-            for (size_t i = 0; i < length; ++i) {
-                ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
+                }
+            } else {
+                // simple split
+                ubatch.pos = batch->pos + seq.offset;
             }
         } else {
             for (size_t i = 0; i < length; ++i) {
                 llama_pos bi = ids[seq.offset + i];
                 ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1);
             }
         }
-        if (seq.n_seq_id > 0) {
+        if (ubatch.equal_seqs) {
             ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
             if (seq.seq_id) {
                 ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
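
The hunk above applies the same two-path pattern to tokens, embeddings, and positions: when `ubatch.equal_seqs` is set, values are gathered through the `ids` permutation into the ubatch's own storage; in the simple split, `ids` is the identity, so a slice of the original batch array can be aliased with plain pointer arithmetic instead of copied. A minimal standalone sketch of the idea (hypothetical names, not the llama.cpp API):

```cpp
#include <cstddef>
#include <cstdint>

// Gather through a permutation (equal_seqs) vs. zero-copy aliasing
// (simple split, where ids[i] == i and a slice is just a pointer offset).
static const int32_t * gather_or_alias(
        const int32_t * src, int32_t * dst, const size_t * ids,
        size_t offset, size_t length, bool equal_seqs) {
    if (equal_seqs) {
        for (size_t i = 0; i < length; ++i) {
            dst[i] = src[ids[offset + i]]; // copy: ids may reorder tokens across sequences
        }
        return dst;
    }
    // simple split: the slice is contiguous in the source, so alias it
    return src + offset;
}
```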
@@ -3181,9 +3196,10 @@ struct llama_sbatch {
                 ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id;
             }
         } else {
+            // simple split
             if (batch->n_seq_id) {
                 for (size_t i = 0; i < length; ++i) {
-                    ubatch.n_seq_id[ubatch.n_seqs + i] = batch->n_seq_id[ids[seq.offset + i]];
+                    ubatch.n_seq_id = batch->n_seq_id + seq.offset;
                 }
             } else {
                 for (size_t i = 0; i < length; ++i) {
@@ -3192,7 +3208,7 @@ struct llama_sbatch {
             }
             if (batch->seq_id) {
                 for (size_t i = 0; i < length; ++i) {
-                    ubatch.seq_id[ubatch.n_seqs + i] = batch->seq_id[ids[seq.offset + i]];
+                    ubatch.seq_id = batch->seq_id + seq.offset;
                 }
             } else {
                 for (size_t i = 0; i < length; ++i) {
@@ -3201,11 +3217,19 @@ struct llama_sbatch {
             }
         }
         if (batch->logits) {
-            for (size_t i = 0; i < length; ++i) {
-                size_t id = ids[seq.offset + i];
-                int8_t is_output = batch->logits[id];
-                ubatch.output[ubatch.n_tokens + i] = is_output;
-                if (is_output) { out_ids.push_back(id); }
+            if (ubatch.equal_seqs) {
+                for (size_t i = 0; i < length; ++i) {
+                    size_t id = ids[seq.offset + i];
+                    int8_t is_output = batch->logits[id];
+                    ubatch.output[ubatch.n_tokens + i] = is_output;
+                    if (is_output) { out_ids.push_back(id); }
+                }
+            } else {
+                // simple split
+                ubatch.output = batch->logits + seq.offset;
+                for (size_t i = 0; i < length; ++i) {
+                    if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
+                }
             }
         } else if (logits_all) {
             for (size_t i = 0; i < length; ++i) {
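
Even in the simple-split branch, the output flags cannot simply be aliased and forgotten: `out_ids` still has to record which batch-level positions produce output, so one pass over the slice remains. A small worked example under assumed values (`seq.offset == 4`, `length == 3`):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    const int8_t logits[8] = {0, 0, 0, 0, 1, 0, 1, 0}; // per-token output flags
    const int8_t * output  = logits + 4;  // ubatch.output = batch->logits + seq.offset
    std::vector<size_t> out_ids;
    for (size_t i = 0; i < 3; ++i) {
        if (output[i] != 0) { out_ids.push_back(4 + i); }
    }
    // out_ids now holds {4, 6}: the batch indices of the output tokens
}
```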
@@ -3222,26 +3246,25 @@ struct llama_sbatch {
             }
         }
         if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
-            ubatch.n_seq_tokens = seq.n_seq_id > 0 ? length : 1;
+            ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
         }
         ubatch.n_tokens += length;
-        ubatch.n_seqs += seq.n_seq_id > 0 ? 1 : length; // virtual sequences for legacy splits
+        ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
         seq.offset += length;
         seq.length -= length;
         n_tokens -= length;
         GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
     }
 
-    // legacy split, unknown number of sequences of unequal lengths
-    llama_ubatch split_slice(size_t n_ubatch) {
+    // simple split, unknown number of sequences of unequal lengths
+    llama_ubatch split_simple(size_t n_ubatch) {
         n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
         llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
         ubatch.equal_seqs = false;
         if (!seq.empty()) {
             llama_sbatch_seq & s = seq[0];
             size_t length = s.length < n_ubatch ? s.length : n_ubatch;
             GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
-            // TODO: reduce copies
             add_seq_to_ubatch(ubatch, s, length);
         }
         return ubatch;
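
The accounting at the end of `add_seq_to_ubatch` preserves the invariant `n_tokens == n_seq_tokens * n_seqs` in both modes: an equal split adds one sequence of `length` tokens, while a simple split adds `length` single-token virtual sequences. A sketch of the two shapes, with assumed values:

```cpp
#include <cassert>
#include <cstddef>

int main() {
    const size_t length = 5; // assumed slice length added to the ubatch

    // equal split: one sequence of `length` tokens
    size_t n_seq_tokens = length;
    size_t n_seqs       = 1;
    assert(length == n_seq_tokens * n_seqs);

    // simple split: `length` virtual sequences of one token each
    n_seq_tokens = 1;
    n_seqs       = length;
    assert(length == n_seq_tokens * n_seqs); // GGML_ASSERT in the real code
}
```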
@@ -3254,7 +3277,7 @@ struct llama_sbatch {
         if (!seq.empty()) {
             size_t length = 0;
             size_t n_tokens_in_ubatch = 0;
-            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with legacy splits
+            GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
             // smallest first, because it's easier to split this way;
             // starting from the end to pop in constant time.
             for (size_t i = seq.size(); i-- > 0;) {
@@ -3282,13 +3305,13 @@ struct llama_sbatch {
         if (!seq.empty()) {
             llama_sbatch_seq & s = seq[seq.size() - 1];
             size_t length = s.length < n_ubatch ? s.length : n_ubatch;
-            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with legacy splits
+            GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
             add_seq_to_ubatch(ubatch, s, length);
         }
         return ubatch;
     }
 
-    void from_batch(const llama_batch & batch, const size_t n_embd, const bool legacy_split = false, const bool logits_all = false) {
+    void from_batch(const llama_batch & batch, const size_t n_embd, const bool simple_split = false, const bool logits_all = false) {
         GGML_ASSERT(batch.n_tokens >= 0);
         this->batch = &batch;
         this->n_embd = n_embd;
@@ -3302,7 +3325,7 @@ struct llama_sbatch {
         for (size_t i = 0; i < n_tokens; ++i) {
             ids[i] = i;
         }
-        if (legacy_split) {
+        if (simple_split) {
             seq.resize(1);
             llama_sbatch_seq & s = seq[0];
             s.n_seq_id = 0;
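
For the simple-split path, `from_batch` collapses the whole batch into a single pseudo-sequence whose `n_seq_id == 0` is the sentinel that `split_simple` asserts on. A hypothetical mirror of the fields involved (not the actual llama.cpp types; offset and length as `add_seq_to_ubatch` consumes them):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of the llama_sbatch_seq fields used by the simple split.
struct sbatch_seq_sketch {
    int32_t   n_seq_id; // 0 marks a simple split; split_simple asserts this
    int32_t * seq_id;   // unused in the simple split
    size_t    offset;   // consumed front-to-back by add_seq_to_ubatch
    size_t    length;   // tokens remaining in the pseudo-sequence
};

int main() {
    const size_t n_tokens = 32; // assumed batch size
    std::vector<sbatch_seq_sketch> seq(1); // mirrors seq.resize(1) in from_batch
    sbatch_seq_sketch & s = seq[0];
    s.n_seq_id = 0;        // sentinel: not an equal-seqs sequence
    s.seq_id   = nullptr;
    s.offset   = 0;
    s.length   = n_tokens; // the pseudo-sequence spans the whole batch
}
```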
@@ -13737,7 +13760,7 @@ static int llama_decode_internal(
     }
 
     lctx.sbatch.from_batch(batch_all, n_embd,
-        /* legacy_split */ rs_self.size == 0,
+        /* simple_split */ rs_self.size == 0,
         /* logits_all */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -13749,7 +13772,7 @@ static int llama_decode_internal(
     while (lctx.sbatch.n_tokens > 0) {
         // TODO: deprecate slice splits in favor of equal splits
         // For now, only use equal splits for recurrent or hybrid model architectures
-        llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_slice(n_ubatch);
+        llama_ubatch u_batch = (rs_self.size > 0) ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
         const uint32_t n_tokens = u_batch.n_tokens;
 
         // count the outputs in this u_batch
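
Note on the dispatch above: the split strategy is chosen once per batch. Architectures with a recurrent state cache (`rs_self.size > 0`) take `split_equal`, presumably so that every sequence in a ubatch has the same length and lines up with its state slot; everything else takes `split_simple`, which skips the gather copies by aliasing the batch arrays directly.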