
Commit 9cd16a3

rm n_pos_per_embd from llm_graph_input_attn_temp
1 parent: bd310ff

2 files changed: 5 additions, 7 deletions


src/llama-graph.cpp

Lines changed: 3 additions & 3 deletions
@@ -82,7 +82,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
             ) * f_attn_temp_scale + 1.0;
         }
 
-        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_embd*ggml_element_size(attn_scale));
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
     }
 }
 
@@ -1042,12 +1042,12 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_embd(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
 
     auto & cur = inp->attn_scale;
 
     // this need to be 1x1xN for broadcasting
-    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_embd());
+    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
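For context: the diff only changes how much data is uploaded, not the per-token value itself. Below is a minimal, standalone sketch of the computation that set_input performs, assuming the llama4-style temperature-tuning formula used in llama-graph.cpp (only its tail is visible in the context lines above). The helper name make_attn_scale and the use of std::vector for the positions are illustrative, not part of the commit; the point is that one scale is produced per token, so the buffer and the 1x1xn_tokens input tensor now line up exactly, with no n_pos_per_embd factor.

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring llm_graph_input_attn_temp::set_input:
// one attention-temperature scale per token position.
static std::vector<float> make_attn_scale(const std::vector<int64_t> & pos,
                                          uint32_t n_attn_temp_floor_scale,
                                          float    f_attn_temp_scale) {
    std::vector<float> attn_scale(pos.size(), 0.0f);
    for (size_t i = 0; i < pos.size(); ++i) {
        // assumed llama4 formula: the scale grows with the log of the floored position bucket
        attn_scale[i] = std::log(
            std::floor((pos[i] + 1.0f) / n_attn_temp_floor_scale) + 1.0
        ) * f_attn_temp_scale + 1.0;
    }
    // exactly n_tokens floats: this is what ggml_backend_tensor_set uploads after the change,
    // i.e. n_tokens*ggml_element_size(attn_scale) bytes for an F32 tensor
    return attn_scale;
}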

src/llama-graph.h

Lines changed: 2 additions & 4 deletions
@@ -103,16 +103,14 @@ class llm_graph_input_pos : public llm_graph_input_i {
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(int64_t n_pos_per_embd, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_pos_per_embd(n_pos_per_embd), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
     virtual ~llm_graph_input_attn_temp() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
 
-    const int64_t n_pos_per_embd = 1;
-
     const uint32_t n_attn_temp_floor_scale;
     const float f_attn_temp_scale;
 };
