Commit 863a440

Replace trie with linear search
1 parent 4fc3776 commit 863a440

2 files changed: +33 -165 lines changed

llama-util.h

Lines changed: 0 additions & 162 deletions
@@ -544,166 +544,4 @@ struct llama_ctx_buffer {
 typedef llama_buffer llama_ctx_buffer;
 #endif
 
-struct llama_trie_node {
-    llama_trie_node(): is_terminator(false) {}
-
-    std::unordered_map<char, std::unique_ptr<llama_trie_node>> children;
-    bool is_terminator;
-};
-
-// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass
-// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
-struct llama_trie {
-public:
-    llama_trie(): root_(new llama_trie_node()) {}
-
-    void add(const std::string & word) {
-        if (word.empty()) {
-            return;
-        }
-
-        llama_trie_node *ref = root_.get();
-        for (char c : word) {
-            if (ref->children.find(c) == ref->children.end()) {
-                ref->children[c].reset(new llama_trie_node());
-            }
-            ref = ref->children[c].get();
-        }
-        ref->is_terminator = true;
-    }
-
-    // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
-    // Note that this trie will match the longest possible word first!
-    std::vector<size_t> split(const std::string & text) const {
-        std::map<size_t, llama_trie_node*> states;
-        std::vector<size_t> offsets{0};
-
-        size_t skip = 0;
-        for (size_t current = 0; current < text.size(); current++) {
-            char current_char = text[current];
-            if (skip > 0 && current < skip) {
-                // Prevents the lookahead for matching twice
-                // like extra_id_100 and id_100
-                continue;
-            }
-
-            // Whenever we found a match, we need to drop everything
-            // this is a greedy algorithm, it will match on the first found token
-            bool reset = false;
-
-            // In this case, we already have partial matches (But unfinished)
-            for (auto state = states.begin(); state != states.end(); ) {
-                size_t start = state->first;
-                llama_trie_node *trie_pointer = state->second;
-                if (trie_pointer->is_terminator) {
-                    // This is a final match, we need to reset and
-                    // store the results in `offsets`.
-
-                    // Lookahead to match longest first
-                    // Important in case of extra_id_1 vs extra_id_100
-                    // Here we are also actively looking for other earlier partial
-                    // matches
-                    // "[CLS]", "L", we need to match CLS even if L is special
-                    size_t end = 0;
-                    for (const auto & look : states) {
-                        size_t lookstart = look.first;
-                        llama_trie_node *looktrie_pointer = look.second;
-                        size_t lookahead_index = 0;
-                        if (lookstart > start) {
-                            // This partial match is later, we can stop looking
-                            break;
-                        }
-                        if (lookstart < start) {
-                            // This partial match is earlier, the trie pointer
-                            // was already updated, so index is + 1
-                            lookahead_index = current + 1;
-                            end = current + 1;
-                        } else {
-                            // Here lookstart == start and
-                            //      looktrie_pointer == trie_pointer
-                            // It wasn't updated yet so indices are current ones
-                            lookahead_index = current;
-                            end = current;
-                        }
-                        char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0';
-                        if (looktrie_pointer->is_terminator) {
-                            start = lookstart;
-                            end = lookahead_index;
-                            skip = lookahead_index;
-                        }
-
-                        auto looktrie_pointer_it = looktrie_pointer->children.find(next_char);
-                        while (looktrie_pointer_it != looktrie_pointer->children.end()) {
-                            looktrie_pointer = looktrie_pointer_it->second.get();
-                            lookahead_index++;
-                            if (looktrie_pointer->is_terminator) {
-                                start = lookstart;
-                                end = lookahead_index;
-                                skip = lookahead_index;
-                            }
-
-                            if (lookahead_index == text.size()) {
-                                // End of string
-                                break;
-                            }
-                            next_char = text[lookahead_index];
-                            looktrie_pointer_it = looktrie_pointer->children.find(next_char);
-                        }
-                    }
-
-                    offsets.push_back(start);
-                    offsets.push_back(end);
-                    reset = true;
-                    break;
-                }
-
-                auto trie_pointer_it = trie_pointer->children.find(current_char);
-                if (trie_pointer_it != trie_pointer->children.end()) {
-                    // The current character being looked at has a match within the trie
-                    // update the pointer (it will be stored back into states later).
-                    trie_pointer = trie_pointer_it->second.get();
-                    states[start] = trie_pointer;
-                    ++state;
-                } else {
-                    // The new character has not match in the trie, we need
-                    // to stop keeping track of this partial match.
-                    state = states.erase(state);
-                }
-            }
-
-            if (reset) {
-                // Clear the full start (we found a real match)
-                states.clear();
-            }
-
-            // If this character is a starting character within the trie
-            // start keeping track of this partial match.
-            auto children_it = root_->children.find(current_char);
-            if (current >= skip && children_it != root_->children.end()) {
-                states[current] = children_it->second.get();
-            }
-        }
-
-        // We have a cut at the end with states.
-        for (const auto & state : states) {
-            size_t start = state.first;
-            llama_trie_node *trie_pointer = state.second;
-            if (trie_pointer->is_terminator) {
-                // This is a final match, we need to reset and
-                // store the results in `offsets`.
-                size_t end = text.size();
-                offsets.push_back(start);
-                offsets.push_back(end);
-                break;
-            }
-        }
-
-        offsets.push_back(text.size());
-        return offsets;
-    }
-
-private:
-    std::unique_ptr<llama_trie_node> root_;
-};
-
 #endif
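
For reference, here is a minimal sketch (not part of this commit) of how the removed llama_trie was used: special tokens are registered up front with add(), and split() returns boundary offsets that alternate between ordinary text and matched tokens, preferring the longest match. The token strings and the expected offsets below are illustrative assumptions based on the deleted code.

#include <vector>
#include "llama-util.h" // pre-commit version of the header that still defines llama_trie

int main() {
    llama_trie trie;
    trie.add("<extra_id_1>");
    trie.add("<extra_id_100>");

    // Longest match first: "<extra_id_100>" is preferred over "<extra_id_1>".
    std::vector<size_t> offsets = trie.split("foo<extra_id_100>bar");
    // Expected: {0, 3, 17, 20} -- "foo" spans [0, 3), the special token [3, 17), "bar" [17, 20).
    return offsets.size() == 4 ? 0 : 1;
}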

llama.cpp

Lines changed: 33 additions & 3 deletions
@@ -278,12 +278,10 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_score> id_to_token;
 
-    llama_trie special_token_trie;
     std::unordered_map<token, id> special_token_to_id;
     size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
-        special_token_trie.add(word);
         special_token_to_id[word] = token_id;
 
         if (max_special_token_length < word.size()) {
@@ -2090,6 +2088,38 @@ struct llama_tokenizer {
     llama_sp_bigram::queue work_queue_;
 };
 
+static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
+    std::vector<size_t> offsets{0};
+    size_t start = 0;
+
+    while (start < text.size()) {
+        size_t max_end = start;
+        const std::string * max_delimiter = nullptr;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = start + delimiter.size();
+            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
+                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
+                    max_end = end;
+                    max_delimiter = &delimiter;
+                }
+            }
+        }
+
+        if (max_delimiter != nullptr) {
+            offsets.push_back(start);
+            offsets.push_back(max_end);
+            start = max_end;
+        } else {
+            start++;
+        }
+    }
+
+    offsets.push_back(text.size());
+    return offsets;
+}
+
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
@@ -2107,7 +2137,7 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
         return output;
     }
 
-    std::vector<size_t> offsets = vocab.special_token_trie.split(text);
+    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
     size_t start = 0;
     for (size_t end : offsets) {
        if (start >= end) {
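
For illustration, a standalone sketch (not part of the commit) that mirrors the new linear-search splitter with a stubbed-down vocabulary; toy_vocab and split_special_tokens are hypothetical stand-ins for llama_vocab and llama_split_special_tokens, used only so the example compiles on its own.

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for llama_vocab: only the member the splitter reads.
struct toy_vocab {
    std::unordered_map<std::string, int> special_token_to_id;
};

static std::vector<size_t> split_special_tokens(const toy_vocab & vocab, const std::string & text) {
    std::vector<size_t> offsets{0};
    size_t start = 0;

    while (start < text.size()) {
        size_t max_end = start;
        const std::string * max_delimiter = nullptr;

        // Linear scan over every special token; keep the longest one that matches at `start`.
        for (const auto & mit : vocab.special_token_to_id) {
            const std::string & delimiter = mit.first;
            size_t end = start + delimiter.size();
            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
                    max_end = end;
                    max_delimiter = &delimiter;
                }
            }
        }

        if (max_delimiter != nullptr) {
            // Record the match as a (start, end) pair and continue after it.
            offsets.push_back(start);
            offsets.push_back(max_end);
            start = max_end;
        } else {
            start++;
        }
    }

    offsets.push_back(text.size());
    return offsets;
}

int main() {
    toy_vocab vocab;
    vocab.special_token_to_id["<extra_id_1>"]   = 1;
    vocab.special_token_to_id["<extra_id_100>"] = 2;

    // Same output shape as the removed trie: 0, then (start, end) pairs, then text.size().
    for (size_t off : split_special_tokens(vocab, "foo<extra_id_100>bar")) {
        std::cout << off << ' '; // prints: 0 3 17 20
    }
    std::cout << '\n';
    return 0;
}

The change trades the trie's single-pass matching for a simpler scan over all special tokens at every position, which is easier to follow and cheap when the special-token set is small.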
