Skip to content

Commit 87e5656

Browse files
Server: enable lookup decoding
1 parent 40f74e4 commit 87e5656

File tree

8 files changed

+192
-59
lines changed

8 files changed

+192
-59
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
800800
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
801801
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
802802

803-
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
803+
server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp common/stb_image.h ggml.o llama.o ngram-cache.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
804804
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
805805
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
806806

common/ngram-cache.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,11 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen
216216

217217
}
218218

219-
llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
219+
bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename) {
220220
std::ifstream hashmap_file(filename, std::ios::binary);
221221
if (!hashmap_file) {
222-
throw std::ifstream::failure("Unable to open file " + filename);
222+
return false;
223223
}
224-
llama_ngram_cache ngram_cache;
225224

226225
llama_ngram ngram;
227226
int32_t ntokens;
@@ -251,7 +250,7 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
251250
}
252251
GGML_ASSERT(hashmap_file.eof());
253252

254-
return ngram_cache;
253+
return true;
255254
}
256255

257256
void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {

common/ngram-cache.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ void llama_ngram_cache_draft(
8484
void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename);
8585

8686
// Load an ngram cache saved with llama_ngram_cache_save.
87+
// ngram_cache: the ngram cache to load the data into.
8788
// filename: the path from which to load the ngram cache.
8889
// returns: true if filename could be opened and its contents were loaded into ngram_cache, false if the file could not be opened.
89-
llama_ngram_cache llama_ngram_cache_load(std::string & filename);
90+
bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename);
9091

9192
// Merge two ngram caches.
9293
// ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.

examples/lookup/lookup-merge.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ int main(int argc, char ** argv){
3333
}
3434

3535
fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
36-
llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
36+
llama_ngram_cache ngram_cache_merged;
37+
GGML_ASSERT(llama_ngram_cache_load(ngram_cache_merged, args[0]));
3738

3839
for (size_t i = 1; i < args.size()-1; ++i) {
3940
fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
40-
llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
41+
llama_ngram_cache ngram_cache;
42+
GGML_ASSERT(llama_ngram_cache_load(ngram_cache, args[i]));
4143

4244
llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
4345
}

examples/lookup/lookup-stats.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,15 @@ int main(int argc, char ** argv){
4747
const int64_t t_start_draft_us = ggml_time_us();
4848

4949
if (!params.lookup_cache_static.empty()) {
50-
try {
51-
ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
52-
} catch (std::ifstream::failure const &) {
50+
if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) {
5351
fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
5452
exit(1);
5553
}
5654
}
5755

5856
if (!params.lookup_cache_dynamic.empty()) {
59-
try {
60-
ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
61-
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
57+
// If the dynamic lookup cache doesn't exist it will be created at the end of the program:
58+
llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic);
6259
}
6360

6461
t_draft_flat_us += ggml_time_us() - t_start_draft_us;

examples/lookup/lookup.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,18 +57,15 @@ int main(int argc, char ** argv){
5757
llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
5858

5959
if (!params.lookup_cache_static.empty()) {
60-
try {
61-
ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
62-
} catch (std::ifstream::failure const &) {
60+
if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) {
6361
fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
6462
exit(1);
6563
}
6664
}
6765

6866
if (!params.lookup_cache_dynamic.empty()) {
69-
try {
70-
ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
71-
} catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
67+
// If the dynamic lookup cache doesn't exist it will be created at the end of the program:
68+
llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic);
7269
}
7370

7471
t_draft_flat_us += ggml_time_us() - t_start_draft_us;

examples/server/bench/bench.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def main(args_in: list[str] | None = None) -> None:
4545
parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True)
4646
parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
4747
parser.add_argument("--duration", type=str, help="Bench scenario", required=True)
48+
parser.add_argument("--draft", type=int, help="Max. number of additional tokens to draft for lookup decoding", required=False, default=5)
49+
parser.add_argument("-lcs", "--lookup-cache-static", type=str, help="Path to optional static lookup cache to use.", required=False, default=None)
50+
parser.add_argument("-lcd", "--lookup-cache-dynamic", type=str, help="Path to optional dynamic lookup cache to use. Will be overwritten upon server shutdown.", required=False, default=None)
4851

4952
args = parser.parse_args(args_in)
5053

@@ -269,6 +272,11 @@ def start_server_background(args):
269272
server_args.append('--cont-batching')
270273
server_args.append('--metrics')
271274
server_args.extend(['--log-format', "text"])
275+
server_args.extend(['--draft', args.draft])
276+
if args.lookup_cache_static is not None:
277+
server_args.extend(['--lookup-cache-static', args.lookup_cache_static])
278+
if args.lookup_cache_dynamic is not None:
279+
server_args.extend(['--lookup-cache-dynamic', args.lookup_cache_dynamic])
272280
args = [str(arg) for arg in [server_path, *server_args]]
273281
print(f"bench: starting server with: {' '.join(args)}")
274282
pkwargs = {

0 commit comments

Comments
 (0)