Skip to content

Commit 2d527fd

Author: FMayran

Commit message: adapting the previous fix to the syntax used by other fields of the ubatch

1 parent 6df6b98 · commit 2d527fd

File tree: 3 files changed (+67, -64 lines)

src/llama-batch.cpp

Lines changed: 53 additions & 48 deletions
@@ -257,6 +257,8 @@ bool llama_batch_allocr::init(
            continue;
        }

+       //@fmayran: these checks don't make sense with models using position encoding such as Qwen VL, because the position stored in the KV cache can jump around (it is not even always increasing).
+       //it is not enough to let them be repeating. Within an image embedding, arbitrary jumps are expected.
        //const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
        //
        //if (p0 >= 0) {
@@ -370,37 +372,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t

     auto udata = std::make_shared<llama_ubatch::data_t>();

-    udata->token     .resize(n_tokens);
-    udata->embd      .clear();
-    udata->pos       .resize(n_tokens);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token               .resize(n_tokens);
+    udata->embd                .clear();
+    udata->pos                 .resize(n_tokens);
+    udata->n_seq_id            .resize(n_tokens);
+    udata->seq_id              .resize(n_tokens);
+    udata->seq_id_unq          .resize(0);
+    udata->seq_idx             .resize(LLAMA_MAX_SEQ, -1);
+    udata->output              .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);

     for (uint32_t s = 0; s < n_seqs; ++s) {
         udata->seq_idx[s] = s;
         udata->seq_id_unq.push_back(s);
     }

     llama_ubatch res {
-        /*.b_equal_seqs   =*/ true,
-        /*.n_tokens       =*/ n_tokens,
-        /*.n_seq_tokens   =*/ n_seq_tokens,
-        /*.n_seqs         =*/ n_seqs,
-        /*.n_seqs_unq     =*/ n_seqs,
-
-        /*.token          =*/ udata->token.data(),
-        /*.embd           =*/ nullptr,
-        /*.pos            =*/ udata->pos.data(),
-        /*.n_seq_id       =*/ udata->n_seq_id.data(),
-        /*.seq_id         =*/ udata->seq_id.data(),
-        /*.seq_id_unq     =*/ udata->seq_id_unq.data(),
-        /*.seq_idx        =*/ udata->seq_idx.data(),
-        /*.output         =*/ udata->output.data(),
-        /*.data           =*/ std::move(udata),
-        /*.kv_position_of_token=*/ {},
+        /*.b_equal_seqs         =*/ true,
+        /*.n_tokens             =*/ n_tokens,
+        /*.n_seq_tokens         =*/ n_seq_tokens,
+        /*.n_seqs               =*/ n_seqs,
+        /*.n_seqs_unq           =*/ n_seqs,
+
+        /*.token                =*/ udata->token.data(),
+        /*.embd                 =*/ nullptr,
+        /*.pos                  =*/ udata->pos.data(),
+        /*.n_seq_id             =*/ udata->n_seq_id.data(),
+        /*.seq_id               =*/ udata->seq_id.data(),
+        /*.seq_id_unq           =*/ udata->seq_id_unq.data(),
+        /*.seq_idx              =*/ udata->seq_idx.data(),
+        /*.output               =*/ udata->output.data(),
+        /*.kv_position_of_token =*/ udata->kv_position_of_token.data(),
+        /*.data                 =*/ std::move(udata),
     };

     return res;
@@ -662,14 +665,15 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
     const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_cur;

-    udata->token     .resize(n_tokens);
-    udata->embd      .resize(n_embd_all);
-    udata->pos       .resize(n_pos_all);
-    udata->n_seq_id  .resize(n_tokens);
-    udata->seq_id    .resize(n_tokens);
-    udata->seq_id_unq.resize(0);
-    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
-    udata->output    .resize(n_tokens);
+    udata->token               .resize(n_tokens);
+    udata->embd                .resize(n_embd_all);
+    udata->pos                 .resize(n_pos_all);
+    udata->n_seq_id            .resize(n_tokens);
+    udata->seq_id              .resize(n_tokens);
+    udata->seq_id_unq          .resize(0);
+    udata->seq_idx             .resize(LLAMA_MAX_SEQ, -1);
+    udata->output              .resize(n_tokens);
+    udata->kv_position_of_token.resize(n_tokens, -1);

     seq_set_t seq_set_unq;

@@ -707,22 +711,23 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     }

     llama_ubatch res {
-        /*.b_equal_seqs   =*/ equal_seqs,
-        /*.n_tokens       =*/ n_tokens,
-        /*.n_seq_tokens   =*/ n_tokens/n_seqs,
-        /*.n_seqs         =*/ n_seqs,
-        /*.n_seqs_unq     =*/ (uint32_t) udata->seq_id_unq.size(),
-
-        /*.token          =*/ batch.token ? udata->token.data() : nullptr,
-        /*.embd           =*/ batch.embd ? udata->embd.data() : nullptr,
-        /*.pos            =*/ udata->pos.data(),
-        /*.n_seq_id       =*/ udata->n_seq_id.data(),
-        /*.seq_id         =*/ udata->seq_id.data(),
-        /*.seq_id_unq     =*/ udata->seq_id_unq.data(),
-        /*.seq_idx        =*/ udata->seq_idx.data(),
-        /*.output         =*/ udata->output.data(),
-        /*.data           =*/ std::move(udata),
-        /*.kv_position_of_token=*/ {},
+        /*.b_equal_seqs         =*/ equal_seqs,
+        /*.n_tokens             =*/ n_tokens,
+        /*.n_seq_tokens         =*/ n_tokens/n_seqs,
+        /*.n_seqs               =*/ n_seqs,
+        /*.n_seqs_unq           =*/ (uint32_t) udata->seq_id_unq.size(),
+
+        /*.token                =*/ batch.token ? udata->token.data() : nullptr,
+        /*.embd                 =*/ batch.embd ? udata->embd.data() : nullptr,
+        /*.pos                  =*/ udata->pos.data(),
+        /*.n_seq_id             =*/ udata->n_seq_id.data(),
+        /*.seq_id               =*/ udata->seq_id.data(),
+        /*.seq_id_unq           =*/ udata->seq_id_unq.data(),
+        /*.seq_idx              =*/ udata->seq_idx.data(),
+        /*.output               =*/ udata->output.data(),
+        /*.kv_position_of_token =*/ udata->kv_position_of_token.data(),
+        /*.data                 =*/ std::move(udata),
     };

     if (debug > 0) {
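
The llama-batch.cpp changes above follow the convention already used by the other ubatch fields: each field is owned by a std::vector inside llama_ubatch::data_t, and consumers only see it through a raw pointer stored in llama_ubatch. As an illustrative sketch of that owning-vector / non-owning-pointer pattern, which kv_position_of_token now adopts (sized to n_tokens and filled with -1, meaning "not yet placed in the KV cache"), see below; ubatch_demo and make_ubatch_demo are invented names for this note, not llama.cpp API.

    // Owning-vector / non-owning-pointer sketch of the ubatch field convention.
    // ubatch_demo and make_ubatch_demo are illustrative, not part of llama.cpp.
    #include <cstdint>
    #include <memory>
    #include <vector>

    struct ubatch_demo {
        struct data_t {
            std::vector<int32_t> kv_position_of_token; // owning storage
        };

        int32_t * kv_position_of_token; // non-owning view into data->kv_position_of_token
        std::shared_ptr<data_t> data;   // keeps the storage alive as long as the ubatch exists
    };

    ubatch_demo make_ubatch_demo(uint32_t n_tokens) {
        auto udata = std::make_shared<ubatch_demo::data_t>();

        // same convention as the other fields: size to n_tokens, fill with -1
        // meaning "this token has not been placed in the KV cache yet"
        udata->kv_position_of_token.resize(n_tokens, -1);

        ubatch_demo res {
            /*.kv_position_of_token =*/ udata->kv_position_of_token.data(),
            /*.data                 =*/ std::move(udata),
        };

        return res;
    }

Braced-init-list elements are evaluated left to right, so the .data() call happens before std::move(udata) transfers ownership; the vector itself lives on the heap inside data_t, so the pointer stays valid for the lifetime of the returned ubatch.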

src/llama-batch.h

Lines changed: 11 additions & 10 deletions
@@ -30,15 +30,16 @@ struct llama_ubatch {
    // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
    //          used for extracting sequence pooled embeddings

-    //                          // size               | idx | val
-    llama_token  *  token;      // [n_tokens]         | i   | id, token
-    float        *  embd;       // [n_embd, n_tokens] | i   | embd
-    llama_pos    *  pos;        // [n_tokens]         | i   | pos
-    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
-    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
-    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
-    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
-    int8_t       *  output;     // [n_tokens]         | i   | -
+    //                                       // size               | idx | val
+    llama_token  *  token;                   // [n_tokens]         | i   | id, token
+    float        *  embd;                    // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;                     // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;                // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;                  // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq;              // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;                 // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;                  // [n_tokens]         | i   | -
+    int32_t      *  kv_position_of_token;    // [n_tokens]         | i   | kv position where the token was inserted

    struct data_t {
        std::vector<llama_token> token;
@@ -49,11 +50,11 @@
        std::vector<llama_seq_id> seq_id_unq;
        std::vector<int32_t>      seq_idx;
        std::vector<int8_t>       output;
+       std::vector<int32_t>      kv_position_of_token; // when pushed to the kv cache, where the token was pushed (used for causal masking)
    };

    // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
    std::shared_ptr<data_t> data;
-   mutable std::vector<int32_t> kv_position_of_token; // when pushed to the kv cache, where is the token pushed (used for causal masking)
};

// a helper for sanitizing, fulfilling and splitting a batch
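
With this header change, kv_position_of_token is no longer a mutable vector living on llama_ubatch itself (resized inside the KV cache code); it is owned by data_t and read through a plain int32_t * just like token, pos and output. A minimal reading-side sketch, assuming the llama_ubatch definition above is in scope; dump_kv_positions_demo and the printing are invented for illustration only:

    #include <cstdint>
    #include <cstdio>
    // assumes "llama-batch.h" (the llama_ubatch shown above) has been included

    static void dump_kv_positions_demo(const llama_ubatch & ub) {
        for (uint32_t i = 0; i < ub.n_tokens; ++i) {
            const int32_t kv_pos = ub.kv_position_of_token[i]; // -1 until apply_ubatch places the token
            if (kv_pos < 0) {
                std::printf("token %u: not yet in the KV cache\n", i);
            } else {
                std::printf("token %u: KV cell %d, pos %d\n", i, kv_pos, (int) ub.pos[i]);
            }
        }
    }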

src/llama-kv-cache.cpp

Lines changed: 3 additions & 6 deletions
@@ -875,9 +875,6 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &

    assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size());

-   ubatch.kv_position_of_token.clear();//clear first, to ensure that all values will be filled with -1
-   ubatch.kv_position_of_token.resize(ubatch.n_tokens, -1);
-
    for (uint32_t s = 0; s < sinfo.n_stream(); ++s) {
        for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
            const uint32_t i = s*sinfo.size() + ii;
@@ -898,7 +895,7 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
            }

            cells.pos_set(idx, ubatch.pos[i]);
-           ubatch.kv_position_of_token[i] = (int32_t)idx;
+           ubatch.kv_position_of_token[i] = (int32_t)idx; //set the position in the kv cache as a property for this token (needed for proper causal masking)

            for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
                cells.seq_add(idx, ubatch.seq_id[i][s]);
@@ -1219,8 +1216,8 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u

        std::fill(data, data + ggml_nelements(dst), -INFINITY);

-       std::vector<int32_t> map_kv_to_batch(n_kv, -1);
-       for (size_t i = 0; i < ubatch->kv_position_of_token.size(); ++i) //invert the batch -> kv position map into a kv -> batch position map
+       std::vector<int32_t> map_kv_to_batch(n_kv, -1); //for each token in the cache, either (-1) or the position in the current ubatch
+       for (uint32_t i = 0; i < n_tokens; ++i)         //invert the batch -> kv position map into a kv -> batch position map
        {
            if (ubatch->kv_position_of_token[i] != -1)
                map_kv_to_batch[ubatch->kv_position_of_token[i]] = i;
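
The set_input_kq_mask hunk inverts the per-token map ("where did this ubatch token land in the cache") into a per-cell map ("which ubatch token occupies this cell"), so masking can be decided from batch order rather than from positions, which for Qwen VL-style image embeddings can jump around. Below is a self-contained sketch of that inversion plus one hypothetical way such a map could drive a causal mask; mask_demo, the masking rule, and the omission of sequence-membership and empty-cell handling are assumptions made for this illustration, not the actual set_input_kq_mask logic.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Returns an [n_tokens x n_kv] additive mask: 0.0f where attention is allowed,
    // -INFINITY where it is blocked. Illustrative only.
    std::vector<float> mask_demo(const std::vector<int32_t> & kv_position_of_token, // [n_tokens], -1 if not in the cache
                                 uint32_t n_kv) {
        const uint32_t n_tokens = (uint32_t) kv_position_of_token.size();

        // invert the batch -> kv position map into a kv -> batch position map, as in the diff above
        std::vector<int32_t> map_kv_to_batch(n_kv, -1);
        for (uint32_t i = 0; i < n_tokens; ++i) {
            if (kv_position_of_token[i] != -1) {
                map_kv_to_batch[kv_position_of_token[i]] = i;
            }
        }

        // hypothetical causal rule based on batch order: a token may attend to any cell
        // that is not part of the current ubatch, and to cells of ubatch tokens that come
        // no later than itself in the batch. (The real set_input_kq_mask also takes
        // sequence membership and empty cells into account; omitted here for brevity.)
        std::vector<float> mask(n_tokens * n_kv, -INFINITY);
        for (uint32_t i = 0; i < n_tokens; ++i) {
            for (uint32_t j = 0; j < n_kv; ++j) {
                const int32_t jb = map_kv_to_batch[j]; // batch index of the token in cell j, or -1
                if (jb == -1 || (uint32_t) jb <= i) {
                    mask[i*n_kv + j] = 0.0f;
                }
            }
        }

        return mask;
    }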
