From a511731c363349eda1d5debd9e7afb792631d920 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 9 May 2025 09:38:46 +0200 Subject: [PATCH 1/2] mtmd : fix batch_view for m-rope --- tools/mtmd/mtmd.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 5d18e8929b31f..33ce5926d12a2 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -554,14 +554,19 @@ struct decode_embd_batch { llama_batch get_view(int offset, int n_tokens) { llama_pos * pos_ptr; pos_view.clear(); - pos_view.resize(n_tokens * n_pos_per_embd); + pos_view.reserve(n_tokens * n_pos_per_embd); if (n_pos_per_embd > 1) { // mrope // for example, with layout of src: 1234...1234...1234...1234... // offset 2 will give us dst: 34...34...34...34... for (int i = 0; i < n_pos_per_embd; i++) { - auto src = pos.begin() + i * batch.n_tokens + offset; - pos_view.insert(pos_view.end(), src, src + n_tokens); + // assume n_tokens < batch.n_tokens + // batch.n_tokens is number of **total** tokens + // n_tokens is number of viewed token + size_t src_idx = i * batch.n_tokens + offset; + pos_view.insert(pos_view.end(), + pos.data() + src_idx, + pos.data() + src_idx + n_tokens); } pos_ptr = pos_view.data(); } else { From 0340281f2bb1f286454753e68e68493d49dbada9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 9 May 2025 09:41:13 +0200 Subject: [PATCH 2/2] nits : fix comment --- tools/mtmd/mtmd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 33ce5926d12a2..2fecf08a44e94 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -560,7 +560,7 @@ struct decode_embd_batch { // for example, with layout of src: 1234...1234...1234...1234... // offset 2 will give us dst: 34...34...34...34... for (int i = 0; i < n_pos_per_embd; i++) { - // assume n_tokens < batch.n_tokens + // assume n_tokens is less than or equal to batch.n_tokens // batch.n_tokens is number of **total** tokens // n_tokens is number of viewed token size_t src_idx = i * batch.n_tokens + offset;