Commit d2b2031 (1 parent: 5fa9e63)

llama : (mrope) allow using normal 1D position for text token (#13138)

* llama : (mrope) use normal position for text token
* rm n_pos_per_embd from llm_graph_input_attn_temp

File tree: 3 files changed, +24 -22 lines
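In short: with M-RoPE (used by Qwen2-VL), each embedding takes four position values instead of one. Before this change, callers had to pack those 4D positions themselves even for plain text, where only the first dimension carries information. The commit moves that expansion into llm_graph_input_pos::set_input, so batches of text tokens can use normal 1D positions again, and it drops the position multiplier from llm_graph_input_attn_temp, which never needed it.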

examples/llava/qwen2vl-cli.cpp (0 additions & 8 deletions)
```diff
@@ -92,20 +92,12 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
 
 static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past, int * st_pos_id) {
     int N = (int) tokens.size();
-    std::vector<llama_pos> pos;
     for (int i = 0; i < N; i += n_batch) {
         int n_eval = (int) tokens.size() - i;
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
         auto batch = llama_batch_get_one(&tokens[i], n_eval);
-        // TODO: add mrope pos ids somewhere else
-        pos.resize(batch.n_tokens * 4);
-        std::fill(pos.begin(), pos.end(), 0);
-        for (int j = 0; j < batch.n_tokens * 3; j ++) {
-            pos[j] = *st_pos_id + (j % batch.n_tokens);
-        }
-        batch.pos = pos.data();
 
         if (llama_decode(ctx_llama, batch)) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
```
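With the caller-side workaround removed, eval_tokens now just forwards the batch from llama_batch_get_one and lets the library expand the 1D text positions into the 4D M-RoPE layout (see the first llama-graph.cpp hunk below).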

src/llama-graph.cpp (19 additions & 7 deletions)
```diff
@@ -55,7 +55,18 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
 
-        ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos));
+        if (ubatch->token && n_pos_per_embd > 1) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the other dimensions are all 0, they are unused for text tokens
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd, 0);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[i] = ubatch->pos[i];
+            }
+            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
+        } else {
+            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
+        }
     }
 }
 
```
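For intuition, here is the new text-token path in isolation. A minimal standalone sketch (the helper name and the main driver are ours, not part of the commit): the buffer is laid out as n_pos_per_embd consecutive sections of n_tokens values, the first section carries the real 1D positions, and the rest stay zero because they are unused for text tokens.

```cpp
// Standalone sketch of the 1D -> 4D position expansion that
// llm_graph_input_pos::set_input now performs for text tokens.
// expand_text_positions is an illustrative name; only the layout matches the commit.
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_pos = int32_t; // matches the typedef in llama.h

// Expand n_tokens 1D positions into an n_pos_per_embd-section buffer:
// section 0 holds the real positions, the remaining sections stay 0.
static std::vector<llama_pos> expand_text_positions(const std::vector<llama_pos> & pos, int64_t n_pos_per_embd) {
    const int64_t n_tokens = (int64_t) pos.size();
    std::vector<llama_pos> pos_data(n_tokens * n_pos_per_embd, 0);
    for (int64_t i = 0; i < n_tokens; ++i) {
        pos_data[i] = pos[i]; // copy the first dimension
    }
    return pos_data;
}

int main() {
    // four text tokens at positions 10..13, M-RoPE uses 4 values per embedding
    const auto pos_data = expand_text_positions({10, 11, 12, 13}, 4);
    for (size_t i = 0; i < pos_data.size(); ++i) {
        printf("%d%s", pos_data[i], (i + 1) % 4 == 0 ? "\n" : " ");
    }
    // prints one section per row: 10 11 12 13, then three all-zero rows
    return 0;
}
```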

```diff
@@ -71,7 +82,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
             ) * f_attn_temp_scale + 1.0;
         }
 
-        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
     }
 }
 
```
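The attention-temperature scale is one float per token, so the upload size is now just n_tokens elements. In practice the removed n_pos_per_token factor was always 1 here anyway: temperature tuning is used by llama4, while only Qwen2-VL has a per-token multiplier of 4.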

```diff
@@ -592,7 +603,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }
 
-int64_t llm_graph_context::n_pos_per_token() const {
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }
 
```

```diff
@@ -1018,11 +1029,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());
 
     auto & cur = inp->pos;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
```
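Note that the allocation here stays in sync with set_input above: the pos tensor holds n_tokens*n_pos_per_embd() I32 values regardless of which branch fills it.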
```diff
@@ -1031,11 +1042,12 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
 
     auto & cur = inp->attn_scale;
 
-    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
+    // this needs to be 1x1xN for broadcasting
+    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
```
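On the 1x1xN comment: in ggml, an operand whose leading dimensions are 1 broadcasts across the matching dimensions of the other operand, so one scale value per token can multiply that token's whole activation. A plain C++ stand-in for that broadcast (toy sizes, no ggml; the real graph multiplies an activation tensor whose third dimension is n_tokens by this [1, 1, n_tokens] input):

```cpp
// Illustration of the 1x1xN broadcast: one scale per token is applied
// across all n_embd elements of that token's activation row.
#include <cstdio>
#include <vector>

int main() {
    const int n_embd   = 4; // toy sizes, not the real hparams
    const int n_tokens = 3;

    std::vector<float> act(n_embd * n_tokens, 1.0f);    // one row of n_embd values per token
    std::vector<float> attn_scale = {1.0f, 1.5f, 2.0f}; // one scale value per token

    // broadcast: attn_scale[t] multiplies every element of token t's row
    for (int t = 0; t < n_tokens; ++t) {
        for (int e = 0; e < n_embd; ++e) {
            act[t * n_embd + e] *= attn_scale[t];
        }
    }

    for (int t = 0; t < n_tokens; ++t) {
        printf("token %d scaled to %.1f\n", t, act[t * n_embd]);
    }
    return 0;
}
```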

src/llama-graph.h (5 additions & 7 deletions)
```diff
@@ -90,29 +90,27 @@ class llm_graph_input_embd : public llm_graph_input_i {
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {}
+    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t n_pos_per_token = 1;
+    const int64_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
     virtual ~llm_graph_input_attn_temp() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
 
-    const int64_t n_pos_per_token = 1;
-
     const uint32_t n_attn_temp_floor_scale;
     const float f_attn_temp_scale;
 };
@@ -419,7 +417,7 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_token() const;
+    int64_t n_pos_per_embd() const;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
```