Skip to content

Commit fdc0e47

Browse files
refactor llama_ngram_cache_update
1 parent dd1b905 commit fdc0e47

File tree

6 files changed

+19
-17
lines changed

6 files changed

+19
-17
lines changed

common/ngram-cache.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,18 @@
66
#include <fstream>
77

88
void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
9-
std::vector<llama_token> & inp, int nnew, bool print_progress) {
9+
llama_token * inp_data, int inp_size, int nnew, bool print_progress) {
1010
const int64_t t_start_ms = ggml_time_ms();
11-
const int64_t inp_size = inp.size();
1211

1312
const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
1413
int64_t n_done = 0;
1514

1615
for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
17-
const int64_t i_start = std::max(inp_size - nnew, ngram_size);
16+
const int64_t i_start = std::max((int64_t)(inp_size - nnew), ngram_size);
1817
for (int64_t i = i_start; i < inp_size; ++i) {
1918
const int64_t ngram_start = i - ngram_size;
20-
llama_ngram ngram(&inp[ngram_start], ngram_size);
21-
const llama_token token = inp[i];
19+
llama_ngram ngram(inp_data + ngram_start, ngram_size);
20+
const llama_token token = inp_data[i];
2221

2322
llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
2423
if (part_it == ngram_cache.end()) {

common/ngram-cache.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash
6464
// In order to get correct results inp_data can ONLY BE APPENDED TO.
6565
// Changes in the middle need a complete rebuild.
6666
void llama_ngram_cache_update(
67-
llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
67+
llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, llama_token * inp_data, int inp_size, int nnew, bool print_progress);
6868

6969
// Try to draft tokens from ngram caches.
7070
// inp: the tokens generated so far.

examples/lookup/lookup-create.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ int main(int argc, char ** argv){
3434

3535

3636
llama_ngram_cache ngram_cache;
37-
llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
37+
llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp.data(), inp.size(), inp.size(), true);
3838
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
3939

4040
llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);

examples/lookup/lookup-stats.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ int main(int argc, char ** argv){
101101

102102
{
103103
const int64_t t_start_draft_us = ggml_time_us();
104-
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
104+
llama_ngram_cache_update(
105+
ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output.data(), pseudo_output.size(), 1, false);
105106
t_draft_us += ggml_time_us() - t_start_draft_us;
106107
}
107108
}
@@ -111,7 +112,8 @@ int main(int argc, char ** argv){
111112
pseudo_output.push_back(inp_slice[pseudo_output.size()]);
112113
{
113114
const int64_t t_start_draft_us = ggml_time_us();
114-
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
115+
llama_ngram_cache_update(
116+
ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output.data(), pseudo_output.size(), 1, false);
115117
t_draft_us += ggml_time_us() - t_start_draft_us;
116118
}
117119
}

examples/lookup/lookup.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ int main(int argc, char ** argv){
5353
{
5454
// Fill up context ngram cache with tokens from user input:
5555
const int64_t t_start_draft_us = ggml_time_us();
56-
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
56+
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp.data(), inp.size(), inp.size(), false);
5757

5858
if (!params.lookup_cache_static.empty()) {
5959
if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) {
@@ -153,7 +153,7 @@ int main(int argc, char ** argv){
153153
{
154154
// Update context ngram cache with the newly accepted token:
155155
const int64_t t_start_draft_us = ggml_time_us();
156-
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
156+
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp.data(), inp.size(), 1, false);
157157
t_draft_us += ggml_time_us() - t_start_draft_us;
158158
}
159159

@@ -179,7 +179,7 @@ int main(int argc, char ** argv){
179179
{
180180
// Update context ngram cache with the newly accepted token:
181181
const int64_t t_start_draft_us = ggml_time_us();
182-
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
182+
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp.data(), inp.size(), 1, false);
183183
t_draft_us += ggml_time_us() - t_start_draft_us;
184184
}
185185
break;

examples/server/server.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,8 @@ struct server_context {
10811081
}
10821082
for (auto slot : slots) {
10831083
memcpy(slot.context_tokens.data(), system_tokens.data(), system_tokens.size()*sizeof(llama_token));
1084-
llama_ngram_cache_update(slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, system_tokens, system_tokens.size(), false);
1084+
llama_ngram_cache_update(
1085+
slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, system_tokens.data(), system_tokens.size(), system_tokens.size(), false);
10851086
}
10861087

10871088
const int32_t n_batch = llama_n_batch(ctx);
@@ -1901,8 +1902,8 @@ struct server_context {
19011902
// this is not great and needs to be improved somehow
19021903
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
19031904
slot.context_tokens[system_tokens.size() + slot_npast] = slot.sampled;
1904-
std::vector<llama_token> tail(slot.context_tokens.begin(), slot.context_tokens.begin() + system_tokens.size() + slot_npast);
1905-
llama_ngram_cache_update(slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tail, 1, false);
1905+
llama_ngram_cache_update(
1906+
slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, slot.context_tokens.data(), system_tokens.size() + slot_npast, 1, false);
19061907

19071908
slot.n_past += 1;
19081909

@@ -2155,8 +2156,8 @@ struct server_context {
21552156

21562157
llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
21572158
slot.context_tokens[system_tokens.size() + slot_npast] = prompt_tokens[slot.n_past];
2158-
std::vector<llama_token> tail(slot.context_tokens.begin(), slot.context_tokens.begin() + slot_npast);
2159-
llama_ngram_cache_update(slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tail, 1, false);
2159+
llama_ngram_cache_update(
2160+
slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, slot.context_tokens.data(), slot_npast, 1, false);
21602161

21612162
if (slot.params.cache_prompt) {
21622163
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);

0 commit comments

Comments
 (0)