
Commit e1a502c

kv_cache : provide rope factors
ggml-ci
1 parent 10ea682 commit e1a502c

File tree: 5 files changed (+55 −23 lines)

src/llama-context.cpp

Lines changed: 3 additions & 3 deletions

@@ -1729,7 +1729,7 @@ llama_context_kv_self::llama_context_kv_self(
 
     const auto & hparams = model.hparams;
 
-    kv_self = std::make_unique<llama_kv_cache_unified>(hparams);
+    kv_self.reset(static_cast<llama_kv_cache_unified *>(model.create_memory()));
 
     LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
 
@@ -1885,7 +1885,7 @@ llm_graph_result_ptr llama_context_kv_self::build_kv_self_shift(
     const int64_t n_head_kv    = hparams.n_head_kv(il);
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-    ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq(), il);
+    ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
     ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self->k_l[il],
 
@@ -2665,7 +2665,7 @@ llama_context_recurrent::llama_context_recurrent(
 
     const auto & hparams = model.hparams;
 
-    kv_self = std::make_unique<llama_kv_cache_recurrent>(hparams);
+    kv_self.reset(static_cast<llama_kv_cache_recurrent *>(model.create_memory()));
 
     LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx);
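
Both constructor changes follow the same pattern: instead of constructing the concrete cache itself, the context asks the model to build it via create_memory() and then takes ownership of the returned base-interface pointer, downcasting it to the concrete type it expects. A minimal sketch of that hand-off, using made-up stand-in types (Memory, UnifiedCache, Model are illustrative, not the actual llama.cpp classes):

    #include <memory>

    struct Memory { virtual ~Memory() = default; };   // stand-in for llama_memory_i
    struct UnifiedCache : Memory { };                  // stand-in for llama_kv_cache_unified

    struct Model {
        // factory: picks the concrete cache type and returns it through the base interface
        Memory * create_memory() const { return new UnifiedCache(); }
    };

    int main() {
        Model model;
        std::unique_ptr<UnifiedCache> kv_self;

        // the context knows which concrete cache it works with, so it downcasts
        // the factory result and takes ownership of it
        kv_self.reset(static_cast<UnifiedCache *>(model.create_memory()));
    }

The downcast is only safe as long as create_memory() returns the matching cache type for this context; that invariant is maintained by the arch switch shown in the src/llama-model.cpp diff below.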

src/llama-kv-cache.cpp

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 
 static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
 
-llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams) : hparams(hparams) {
+llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
 }
 
 bool llama_kv_cache_unified::init(
src/llama-kv-cache.h

Lines changed: 12 additions & 2 deletions

@@ -6,9 +6,9 @@
 
 #include "ggml-cpp.h"
 
+#include <functional>
 #include <set>
 #include <vector>
-#include <algorithm>
 
 struct llama_cparams;
 struct llama_hparams;
 
@@ -62,7 +62,15 @@ struct llama_kv_cache_slot_info {
 // TODO: add notion of max sequences
 class llama_kv_cache_unified : public llama_kv_cache {
 public:
-    llama_kv_cache_unified(const llama_hparams & hparams);
+    // can be used to query data from the model if needed
+    struct callbacks {
+        std::function<ggml_tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
+    };
+
+    llama_kv_cache_unified(
+            const llama_hparams & hparams,
+            callbacks cbs);
+
     virtual ~llama_kv_cache_unified() = default;
 
     // TODO: become constructor
 
@@ -129,6 +137,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
 
     const llama_hparams & hparams;
 
+    callbacks cbs;
+
     bool has_shift = false;
     bool do_defrag = false;
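
The new callbacks struct decouples the cache from llama_model: the cache only needs a std::function it can call to obtain per-layer RoPE frequency factors. A hedged sketch of wiring up and invoking such a callback, with simplified stand-in types (RopeTensor and the 4096 threshold are invented for illustration; the real callback compares against hparams.n_ctx_orig_yarn, and the recurrent path installs nullptr, so a consumer that might see an empty function should check before calling):

    #include <cstdint>
    #include <functional>
    #include <iostream>

    struct RopeTensor { const char * name; };   // stand-in for ggml_tensor

    // mirrors the shape of llama_kv_cache_unified::callbacks
    struct Callbacks {
        std::function<RopeTensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
    };

    int main() {
        static RopeTensor rope_long  = { "rope_long"  };
        static RopeTensor rope_short = { "rope_short" };

        Callbacks cbs;
        cbs.get_rope_factors = [](uint32_t n_ctx_per_seq, int /*il*/) -> RopeTensor * {
            // illustrative threshold only
            return n_ctx_per_seq > 4096 ? &rope_long : &rope_short;
        };

        if (cbs.get_rope_factors) {                                        // empty on the recurrent path
            std::cout << cbs.get_rope_factors(8192, 0)->name << "\n";      // prints "rope_long"
        }
    }

In this commit the callback is installed by llama_model::create_memory() (see src/llama-model.cpp below), which is what allows build_rope_factors to move off the model's public interface.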

src/llama-model.cpp

Lines changed: 36 additions & 15 deletions

@@ -3824,7 +3824,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
@@ -3998,7 +3998,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
@@ -6156,7 +6156,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 struct ggml_tensor* attn_norm_output = build_norm(inpL,
                     model.layers[il].attn_norm,
 
@@ -6879,7 +6879,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
-            struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
             // norm
             cur = build_norm(inpL,
 
@@ -7801,7 +7801,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
@@ -8715,7 +8715,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
@@ -9872,7 +9872,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                struct ggml_tensor * rope_factors = model.build_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
 
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
 
@@ -10682,17 +10682,38 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context {
     }
 };
 
-ggml_tensor * llama_model::build_rope_factors(uint32_t n_ctx_per_seq, int il) const {
-    // choose long/short freq factors based on the context size
-    if (layers[il].rope_freqs != nullptr) {
-        return layers[il].rope_freqs;
-    }
+llama_memory_i * llama_model::create_memory() const {
+    llama_memory_i * res;
 
-    if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
-        return layers[il].rope_long;
+    switch (arch) {
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_MAMBA:
+            {
+                res = new llama_kv_cache_recurrent(hparams, {
+                    /*.get_rope_factors =*/ nullptr
+                });
+            } break;
+        default:
+            {
+                res = new llama_kv_cache_unified(hparams, {
+                    /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) {
+                        // choose long/short freq factors based on the context size
+                        if (layers[il].rope_freqs != nullptr) {
+                            return layers[il].rope_freqs;
+                        }
+
+                        if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) {
+                            return layers[il].rope_long;
+                        }
+
+                        return layers[il].rope_short;
+                    }
+                });
+            }
     }
 
-    return layers[il].rope_short;
+    return res;
 }
 
 llm_graph_result_ptr llama_model::build_graph(
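
create_memory() is now the single place that decides which memory implementation a model gets: recurrent architectures (Mamba and the RWKV6 variants) get a recurrent cache with no rope-factors callback, and everything else gets the unified cache with a lambda that captures the model and selects frequency factors by context length. A simplified sketch of that factory-plus-callback pattern, using invented stand-in types rather than the real llama.cpp classes:

    #include <cstdint>
    #include <functional>
    #include <memory>
    #include <utility>

    enum class Arch { LLAMA, MAMBA, RWKV6 };

    struct Tensor { };                                // stand-in for ggml_tensor
    struct Memory { virtual ~Memory() = default; };   // stand-in for llama_memory_i

    struct Callbacks {
        std::function<Tensor * (uint32_t n_ctx_per_seq, int il)> get_rope_factors;
    };

    struct UnifiedCache : Memory {
        Callbacks cbs;
        explicit UnifiedCache(Callbacks cbs) : cbs(std::move(cbs)) {}
    };

    struct RecurrentCache : Memory {
        Callbacks cbs;
        explicit RecurrentCache(Callbacks cbs) : cbs(std::move(cbs)) {}
    };

    struct Model {
        Arch     arch       = Arch::LLAMA;
        uint32_t n_ctx_orig = 8192;        // plays the role of hparams.n_ctx_orig_yarn
        Tensor * rope_long  = nullptr;     // per-layer tensors in the real code
        Tensor * rope_short = nullptr;

        Memory * create_memory() const {
            switch (arch) {
                case Arch::MAMBA:
                case Arch::RWKV6:
                    // recurrent caches never shift RoPE, so no factors callback is installed
                    return new RecurrentCache({ /*.get_rope_factors =*/ nullptr });
                default:
                    return new UnifiedCache({
                        /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int /*il*/) {
                            // pick long or short factors based on the per-sequence context size
                            return n_ctx_per_seq > n_ctx_orig ? rope_long : rope_short;
                        }
                    });
            }
        }
    };

    int main() {
        Model model;
        std::unique_ptr<Memory> mem(model.create_memory());   // unified cache for Arch::LLAMA
    }

Because the lookup lives behind the callback, llama_kv_cache_unified no longer needs any knowledge of llama_model, and the graph builders reach the same logic through the cache's cbs member instead of calling the model directly.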

src/llama-model.h

Lines changed: 3 additions & 2 deletions

@@ -2,8 +2,9 @@
 
 #include "llama.h"
 #include "llama-arch.h"
-#include "llama-hparams.h"
 #include "llama-graph.h"
+#include "llama-hparams.h"
+#include "llama-memory.h"
 #include "llama-vocab.h"
 
 #include <memory>
 
@@ -366,7 +367,7 @@ struct llama_model {
     const struct ggml_tensor * get_tensor(const char * name) const;
 
     // TODO: move this to new llm_arch_model_i interface
-    ggml_tensor * build_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+    llama_memory_i * create_memory() const; // TODO: params
 
     // TODO: move this to new llm_arch_model_i interface
     llm_graph_result_ptr build_graph(

0 commit comments
