Skip to content

Commit 2189fd3

Browse files
authored
mtmd : fix batch_view for m-rope (#13397)
* mtmd : fix batch_view for m-rope * nits : fix comment
1 parent 3f96aef commit 2189fd3

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

tools/mtmd/mtmd.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -554,14 +554,19 @@ struct decode_embd_batch {
554554
llama_batch get_view(int offset, int n_tokens) {
555555
llama_pos * pos_ptr;
556556
pos_view.clear();
557-
pos_view.resize(n_tokens * n_pos_per_embd);
557+
pos_view.reserve(n_tokens * n_pos_per_embd);
558558
if (n_pos_per_embd > 1) {
559559
// mrope
560560
// for example, with layout of src: 1234...1234...1234...1234...
561561
// offset 2 will give us dst: 34...34...34...34...
562562
for (int i = 0; i < n_pos_per_embd; i++) {
563-
auto src = pos.begin() + i * batch.n_tokens + offset;
564-
pos_view.insert(pos_view.end(), src, src + n_tokens);
563+
// assume n_tokens is less than or equal to batch.n_tokens
564+
// batch.n_tokens is number of **total** tokens
565+
// n_tokens is number of viewed token
566+
size_t src_idx = i * batch.n_tokens + offset;
567+
pos_view.insert(pos_view.end(),
568+
pos.data() + src_idx,
569+
pos.data() + src_idx + n_tokens);
565570
}
566571
pos_ptr = pos_view.data();
567572
} else {

0 commit comments

Comments
 (0)