Commit 465cadd: Refactor special tokens tokenization
Parent commit: ada6cce

1 file changed: llama.cpp (29 additions, 53 deletions)
@@ -279,14 +279,9 @@ struct llama_vocab {
     std::vector<token_score> id_to_token;
 
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
         special_token_to_id[word] = token_id;
-
-        if (max_special_token_length < word.size()) {
-            max_special_token_length = word.size();
-        }
     }
 };
 
@@ -2088,38 +2083,6 @@ struct llama_tokenizer {
     llama_sp_bigram::queue work_queue_;
 };
 
-static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
-    std::vector<size_t> offsets{0};
-    size_t start = 0;
-
-    while (start < text.size()) {
-        size_t max_end = start;
-        const std::string * max_delimiter = nullptr;
-
-        for (const auto & mit : vocab.special_token_to_id) {
-            const std::string & delimiter = mit.first;
-            size_t end = start + delimiter.size();
-            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
-                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
-                    max_end = end;
-                    max_delimiter = &delimiter;
-                }
-            }
-        }
-
-        if (max_delimiter != nullptr) {
-            offsets.push_back(start);
-            offsets.push_back(max_end);
-            start = max_end;
-        } else {
-            start++;
-        }
-    }
-
-    offsets.push_back(text.size());
-    return offsets;
-}
-
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
@@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
         return output;
     }
 
-    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
-    size_t start = 0;
-    for (size_t end : offsets) {
-        if (start >= end) {
-            continue;
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
+
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
+            }
         }
 
-        const char *part = text.c_str() + start;
-        size_t part_len = end - start;
-        if (vocab.max_special_token_length < part_len) {
-            tokenizer.tokenize(part, part_len, output);
-        } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
-            if (token_it != vocab.special_token_to_id.end()) {
-                output.push_back(token_it->second);
-            } else {
-                tokenizer.tokenize(part, part_len, output);
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
             }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
         }
-        start = end;
     }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }
+
     return output;
 }
 
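To make the new control flow easier to follow outside the diff, here is a minimal, self-contained sketch of the longest-match scan that this commit inlines into llama_tokenize. It is not the llama.cpp code: split_with_special_tokens, tokenize_ordinary, and the <FIM_PREFIX>/<FIM_SUFFIX> tokens and ids are placeholders invented for the example, and ordinary spans are returned as raw substrings instead of being run through the SentencePiece tokenizer.

// Standalone sketch of the matching loop added by this commit (assumptions
// noted above). The real code emits llama_vocab::id values and calls
// llama_tokenizer::tokenize on the ordinary spans.
#include <cstdint>
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

using id = int32_t;

// Stand-in for the regular tokenizer: one "piece" per ordinary span.
static void tokenize_ordinary(const std::string & text, size_t offset, size_t len,
                              std::vector<std::string> & out) {
    out.push_back(text.substr(offset, len));
}

static std::vector<std::string> split_with_special_tokens(
        const std::unordered_map<std::string, id> & special_token_to_id,
        const std::string & text) {
    std::vector<std::string> output;

    size_t delim_start    = 0; // position currently probed for a special token
    size_t last_delim_end = 0; // end of the last special token that was emitted

    while (delim_start < text.size()) {
        size_t delim_end = 0;
        id token_id = -1;

        // Try every registered special token at this position, keep the longest match.
        for (const auto & mit : special_token_to_id) {
            const std::string & delimiter = mit.first;
            size_t end = delim_start + delimiter.size();
            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
                if (token_id == -1 || end > delim_end) {
                    token_id  = mit.second;
                    delim_end = end;
                }
            }
        }

        if (token_id != -1) {
            // Flush the ordinary text before the match, then emit the special token.
            if (last_delim_end < delim_start) {
                tokenize_ordinary(text, last_delim_end, delim_start - last_delim_end, output);
            }
            output.push_back("<special:" + std::to_string(token_id) + ">");
            delim_start    = delim_end;
            last_delim_end = delim_end;
        } else {
            delim_start++;
        }
    }

    // Flush any ordinary text after the final special token.
    if (last_delim_end < text.size()) {
        tokenize_ordinary(text, last_delim_end, text.size() - last_delim_end, output);
    }
    return output;
}

int main() {
    // Hypothetical special tokens; real ids depend on the loaded vocabulary.
    std::unordered_map<std::string, id> special = {
        {"<FIM_PREFIX>", 32001},
        {"<FIM_SUFFIX>", 32002},
    };

    for (const auto & piece : split_with_special_tokens(special, "<FIM_PREFIX>int main() {<FIM_SUFFIX>")) {
        std::printf("%s\n", piece.c_str());
    }
    return 0;
}

Scanning every registered special token at each byte position and keeping the longest match is what lets the commit drop both max_special_token_length and the separate llama_split_special_tokens pass: the match is resolved on the spot, and the ordinary text between matches is handed to the regular tokenizer in one contiguous span.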