
Commit 1d516d3

refactor
1 parent fdc0e47 · commit 1d516d3


5 files changed, 12 insertions(+), 13 deletions(-)


common/ngram-cache.cpp

Lines changed: 5 additions & 6 deletions
@@ -47,8 +47,8 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
 }
 
 // Helper function to get a token from the combined, speculative sequence of inp and draft.
-static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
-    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+static llama_token get_token(const llama_token * inp_data, const int inp_size, const std::vector<llama_token> & draft, const int i) {
+    return i < inp_size ? inp_data[i] : draft[1 + i - inp_size];
 }
 
 // If sample size or percentage are below these thresholds the draft is aborted early:
@@ -139,11 +139,10 @@ static llama_token try_draft(
 }
 
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    llama_token * inp_data, int inp_size, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
 ) {
     GGML_ASSERT(draft.size() == 1);
-    const int inp_size = inp.size();
 
     if (inp_size < LLAMA_NGRAM_STATIC) {
         return;
@@ -155,7 +154,7 @@ void llama_ngram_cache_draft(
     const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
     llama_ngram ngram_static;
     for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-        ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+        ngram_static.tokens[j-ngram_start_static] = get_token(inp_data, inp_size, draft, j);
     }
     llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     llama_ngram_cache_part part_static;
@@ -169,7 +168,7 @@ void llama_ngram_cache_draft(
         const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
         llama_ngram ngram_cd;
         for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
-            ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+            ngram_cd.tokens[j-ngram_start_cd] = get_token(inp_data, inp_size, draft, j);
         }
         ngrams_cd.push_back(ngram_cd);
     }
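The pointer-and-size form of get_token keeps the old semantics: index i addresses the virtual concatenation of the context tokens followed by draft[1:], since draft[0] duplicates the last accepted context token (see the GGML_ASSERT(draft[0] == inp.back()) in lookup.cpp below). A minimal standalone sketch of that indexing convention — the demo values and the llama_token typedef are illustrative assumptions, not part of this commit:

#include <cassert>
#include <cstdint>
#include <vector>

typedef int32_t llama_token; // assumed to match llama.h

// Same logic as the refactored helper: indices [0, inp_size) read from the
// context buffer, indices >= inp_size read from the draft, skipping draft[0].
static llama_token get_token(const llama_token * inp_data, const int inp_size,
                             const std::vector<llama_token> & draft, const int i) {
    return i < inp_size ? inp_data[i] : draft[1 + i - inp_size];
}

int main() {
    const std::vector<llama_token> inp   = {10, 11, 12};
    const std::vector<llama_token> draft = {12, 13, 14}; // draft[0] == inp.back()

    assert(get_token(inp.data(), inp.size(), draft, 2) == 12); // last context token
    assert(get_token(inp.data(), inp.size(), draft, 3) == 13); // first drafted token
    assert(get_token(inp.data(), inp.size(), draft, 4) == 14); // second drafted token
    return 0;
}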

common/ngram-cache.h

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ void llama_ngram_cache_update(
 // nc_dynamic: ngram cache based on previous user generations.
 // nc_static: ngram cache generated from a large text corpus, used for validation.
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    llama_token * inp_data, int inp_size, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);
 
 // Save an ngram cache to a file.
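With the pointer-and-size signature, a caller can draft from any contiguous run of tokens without materializing a temporary std::vector: it passes data() and a length, as the updated call sites below do. A hedged sketch of the two patterns (the buffer and cache names are illustrative):

// Whole vector, as in the lookup examples below:
llama_ngram_cache_draft(inp.data(), inp.size(), draft, n_draft,
                        LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, nc_context, nc_dynamic, nc_static);

// Prefix [0, n_past) of a longer-lived buffer, as in the server example below:
llama_ngram_cache_draft(context_tokens.data(), n_past, draft, n_draft,
                        LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, nc_context, nc_dynamic, nc_static);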

examples/lookup/lookup-stats.cpp

Lines changed: 3 additions & 1 deletion
@@ -82,7 +82,9 @@ int main(int argc, char ** argv){
 
         {
             const int64_t t_start_draft_us = ggml_time_us();
-            llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+            llama_ngram_cache_draft(
+                pseudo_output.data(), pseudo_output.size(), draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
+                ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
             t_draft_us += ggml_time_us() - t_start_draft_us;
         }
 

examples/lookup/lookup.cpp

Lines changed: 2 additions & 1 deletion
@@ -201,7 +201,8 @@ int main(int argc, char ** argv){
         GGML_ASSERT(draft[0] == inp.back());
         const int64_t t_start_draft_us = ggml_time_us();
 
-        llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+        llama_ngram_cache_draft(
+            inp.data(), inp.size(), draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
 
         for (size_t i = 1; i < draft.size(); ++i) {
             llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);

examples/server/server.cpp

Lines changed: 1 addition & 4 deletions
@@ -1920,13 +1920,10 @@ struct server_context {
                 continue;
             }
 
-            const int32_t tail_start = std::max(slot.n_past - LLAMA_NGRAM_MAX, 0);
-            std::vector<llama_token> context_tail(slot.context_tokens.begin() + tail_start, slot.context_tokens.begin() + slot.n_past);
-
             slot.draft.clear();
             slot.draft.push_back(slot.context_tokens[slot.n_past - 1]);
             llama_ngram_cache_draft(
-                context_tail, slot.draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, slot.nc_context, nc_dynamic, nc_static);
+                slot.context_tokens.data(), slot.n_past, slot.draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, slot.nc_context, nc_dynamic, nc_static);
 
             for (int j = 1; j < (int)slot.draft.size(); ++j) {
                 llama_batch_add(batch, slot.draft[j], system_tokens.size() + slot.n_past, {slot.id + 1}, true);
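Why dropping the context_tail copy is safe, as far as these hunks show: llama_ngram_cache_draft computes its ngram start positions backwards from inp_size (see ngram_start_static and ngram_start_cd above), so it only ever reads the trailing tokens of whatever buffer it receives. Passing the prefix [0, slot.n_past) of slot.context_tokens therefore yields the same ngrams as the old tail copy while skipping a heap allocation on every draft call. A hedged sketch of that equivalence, with illustrative token values and a local ngram_max standing in for LLAMA_NGRAM_MAX:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

typedef int32_t llama_token; // assumed to match llama.h

int main() {
    // Stand-in for slot.context_tokens, of which n_past tokens are accepted context.
    const std::vector<llama_token> context_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
    const int n_past    = 6; // only [0, n_past) is valid context
    const int ngram_max = 4; // stand-in for LLAMA_NGRAM_MAX

    // Old call site: copy the last ngram_max valid tokens into a temporary.
    const int tail_start = std::max(n_past - ngram_max, 0);
    const std::vector<llama_token> context_tail(
        context_tokens.begin() + tail_start, context_tokens.begin() + n_past);

    // New call site: the same trailing tokens, read in place via data() + n_past.
    for (int k = 0; k < (int)context_tail.size(); ++k) {
        assert(context_tail[k] == context_tokens[tail_start + k]);
    }
    return 0;
}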
