Skip to content

Commit 2c70d79

Browse files
lookup-merge
1 parent 49e794f commit 2c70d79

File tree

3 files changed

+71
-1
lines changed

3 files changed

+71
-1
lines changed

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,8 @@ lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
746746
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
747747
$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
748748
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
749+
$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-merge.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp)
750+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp) -o lookup-merge $(LDFLAGS)
749751
$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
750752
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)
751753

examples/lookup/lookup-create.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include "common.h"
2-
#include "common/common.h"
32
#include "ggml.h"
43
#include "llama.h"
54

examples/lookup/lookup-merge.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include "common.h"
2+
#include "common/common.h"
3+
#include "ggml.h"
4+
#include "llama.h"
5+
6+
#include <cstdint>
7+
#include <cstdio>
8+
#include <fstream>
9+
#include <iostream>
10+
#include <string>
11+
#include <unordered_map>
12+
#include <vector>
13+
14+
static void print_usage() {
15+
fprintf(stderr, "Merges multiple lookup cache files into a single one.\n");
16+
fprintf(stderr, "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n");
17+
}
18+
19+
int main(int argc, char ** argv){
20+
if (argc < 3) {
21+
print_usage();
22+
exit(1);
23+
}
24+
25+
std::vector<std::string> args;
26+
args.resize(argc-1);
27+
for (int i = 0; i < argc-1; ++i) {
28+
args[i] = argv[i+1];
29+
if (args[i] == "-h" || args[i] == "--help") {
30+
print_usage();
31+
exit(0);
32+
}
33+
}
34+
35+
std::vector<llama_ngram_cache> ngram_cache_merged;
36+
ngram_cache_merged.push_back(llama_ngram_cache_load(args[0]));
37+
38+
for (size_t i = 1; i < args.size()-1; ++i) {
39+
fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
40+
llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
41+
42+
for (std::pair<uint64_t, llama_ngram_cache_part> ngram_part : ngram_cache) {
43+
const uint64_t ngram = ngram_part.first;
44+
llama_ngram_cache_part part = ngram_part.second;
45+
46+
llama_ngram_cache::iterator part_merged_it = ngram_cache_merged[0].find(ngram);
47+
if (part_merged_it == ngram_cache_merged[0].end()) {
48+
ngram_cache_merged[0].emplace(ngram, part);
49+
continue;
50+
}
51+
52+
for (std::pair<llama_token, int32_t> token_count : part) {
53+
const llama_token token = token_count.first;
54+
const int32_t count = token_count.second;
55+
56+
llama_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
57+
if (token_count_merged_it == part_merged_it->second.end()) {
58+
part_merged_it->second.emplace(token, count);
59+
continue;
60+
} else {
61+
token_count_merged_it->second += count;
62+
}
63+
}
64+
}
65+
}
66+
67+
fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
68+
llama_ngram_cache_save(ngram_cache_merged, args.back());
69+
}

0 commit comments

Comments
 (0)