Commit a9efdbb

qwen2vl: fix mrope position
1 parent e8827a6 commit a9efdbb

4 files changed: +14 −9 lines changed


examples/llava/qwen2vl-cli.cpp

Lines changed: 6 additions & 6 deletions
```diff
@@ -68,7 +68,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
 
         float * batch_embd = image_embed->embed+i*n_embd;
         auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0);
-        llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval);
+        llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval * 4);
 
         if (llama_decode_ext(ctx_llama, batch.get())) {
             LOG_ERR("%s : failed to eval\n", __func__);
@@ -91,18 +91,18 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         }
 
         // TODO: add mrope pos ids somewhere else
-        int n_tokens = n_eval;
-        pos.resize(n_tokens * 4);
+        pos.resize(n_eval * 4);
         std::fill(pos.begin(), pos.end(), 0);
-        for (int j = 0; j < n_tokens * 3; j ++) {
-            pos[j] = *st_pos_id + (j % n_tokens);
+        for (int j = 0; j < n_eval * 3; j ++) {
+            pos[j] = *st_pos_id + (j % n_eval);
         }
 
         llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1));
         for (int j = 0; j < n_eval; j++) {
            llama_token token = tokens[i + j];
-           batch.add_text(token, pos[j], 0, false);
+           batch.add_text(token, 0, 0, false); // position is set in the next step
        }
+       llama_batch_ext_set_pos(batch.get(), pos.data(), pos.size());
        llama_batch_ext_set_output_last(batch.get());
 
        if (llama_decode_ext(ctx_llama, batch.get())) {
```
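For context, here is a minimal standalone sketch (not part of the commit) of the M-RoPE position layout that the patched loop above produces for a run of plain text tokens: the buffer holds `n_eval * 4` values, the four position streams are stored back to back, the first three advance linearly, and the fourth stays zero. The concrete values of `n_eval` and `st_pos_id` are illustrative.

```cpp
// Illustration only: mirrors the pos-filling loop from eval_tokens above.
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t; // same underlying type llama.cpp uses for positions

int main() {
    const int n_eval    = 3;  // tokens in this micro-batch (example value)
    const int st_pos_id = 10; // starting position id (example value)

    // 4 position streams per token, stored stream after stream:
    // [t0..t2, h0..h2, w0..w2, e0..e2]; the 4th stream stays 0 for text tokens.
    std::vector<llama_pos> pos(n_eval * 4, 0);
    for (int j = 0; j < n_eval * 3; j++) {
        pos[j] = st_pos_id + (j % n_eval);
    }

    for (int d = 0; d < 4; d++) {
        for (int j = 0; j < n_eval; j++) {
            printf("%d ", pos[d * n_eval + j]);
        }
        printf("\n");
    }
    // prints:
    // 10 11 12
    // 10 11 12
    // 10 11 12
    // 0 0 0
    return 0;
}
```

This is also why the first hunk now passes `n_eval * 4` to `llama_batch_ext_set_pos`: with the old `n_eval` argument, only the first of the four streams was copied into the batch.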

src/llama-batch.cpp

Lines changed: 3 additions & 2 deletions
```diff
@@ -1,4 +1,5 @@
 #include "llama-batch.h"
+#include "llama-graph.h"
 
 #include <cstring>
 #include <algorithm>
@@ -356,7 +357,7 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
         batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
     }
 
-    batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc);
+    batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc * MAX_POS_PER_TOKEN);
     batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc);
     batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
     for (int i = 0; i < n_tokens_alloc; ++i) {
@@ -390,7 +391,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
 }
 
 int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) {
-    if ((size_t) batch->n_tokens != n_pos) {
+    if ((size_t) batch->n_tokens * MAX_POS_PER_TOKEN < n_pos) {
         return -1;
     }
     memcpy(batch->pos, pos, n_pos * sizeof(llama_pos));
```
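A hedged, self-contained restatement of the relaxed bounds check (same arithmetic as `llama_batch_ext_set_pos` above, minus the llama.cpp types; `check_n_pos` is a hypothetical helper introduced only for this sketch): since the pos buffer is now always allocated for `MAX_POS_PER_TOKEN` positions per token, callers may pass either one position per token or up to four.

```cpp
// Illustration only: reproduces the new size check from llama_batch_ext_set_pos.
#include <cstdio>

constexpr int MAX_POS_PER_TOKEN = 4; // same value as in src/llama-graph.h below

static int check_n_pos(int n_tokens, size_t n_pos) {
    // old check: n_tokens != n_pos       -> rejected multi-position callers
    // new check: reject only if n_pos exceeds what the batch has room for
    if ((size_t) n_tokens * MAX_POS_PER_TOKEN < n_pos) {
        return -1;
    }
    return 0;
}

int main() {
    printf("%d\n", check_n_pos(3, 3));  //  0: one position per token (regular RoPE)
    printf("%d\n", check_n_pos(3, 12)); //  0: four positions per token (M-RoPE)
    printf("%d\n", check_n_pos(3, 16)); // -1: more than n_tokens * MAX_POS_PER_TOKEN
    return 0;
}
```

The check only rejects overruns because the allocation in `llama_batch_ext_init_impl` above now always reserves `MAX_POS_PER_TOKEN` slots per token.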

src/llama-graph.cpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -603,7 +603,9 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
 }
 
 int64_t llm_graph_context::n_pos_per_token() const {
-    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+    constexpr int64_t n_pos_per_token_qwen2vl = 4;
+    static_assert(n_pos_per_token_qwen2vl <= MAX_POS_PER_TOKEN);
+    return arch == LLM_ARCH_QWEN2VL ? n_pos_per_token_qwen2vl : 1;
 }
 
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
```

src/llama-graph.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -10,6 +10,8 @@
 #include <set>
 #include <functional>
 
+#define MAX_POS_PER_TOKEN 4
+
 struct ggml_cgraph;
 struct ggml_context;
 struct ggml_tensor;
```
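Taken together, the new `MAX_POS_PER_TOKEN` constant and the `static_assert` in `n_pos_per_token()` keep allocation and consumption in sync: the graph never asks for more positions per token than any batch has room for. Below is a rough sketch of that invariant as standalone code; the free function `n_pos_per_token(bool)` is a hypothetical stand-in for the real `llm_graph_context` member, and only the names and the value 4 are borrowed from this diff.

```cpp
// Illustration only: the sizing invariant behind MAX_POS_PER_TOKEN.
#include <cassert>
#include <cstdint>

#define MAX_POS_PER_TOKEN 4 // as defined in src/llama-graph.h above

// Stand-in for llm_graph_context::n_pos_per_token(): Qwen2-VL uses 4
// position components (M-RoPE), every other architecture uses 1.
static int64_t n_pos_per_token(bool is_qwen2vl) {
    constexpr int64_t n_pos_per_token_qwen2vl = 4;
    static_assert(n_pos_per_token_qwen2vl <= MAX_POS_PER_TOKEN, "must fit the batch allocation");
    return is_qwen2vl ? n_pos_per_token_qwen2vl : 1;
}

int main() {
    const int64_t n_tokens = 8; // arbitrary batch size

    // text-only architectures: 1 position per token, fits the allocation
    assert(n_tokens * n_pos_per_token(false) <= n_tokens * MAX_POS_PER_TOKEN);
    // Qwen2-VL: 4 positions per token, still within n_tokens * MAX_POS_PER_TOKEN
    assert(n_tokens * n_pos_per_token(true)  <= n_tokens * MAX_POS_PER_TOKEN);
    return 0;
}
```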
