@@ -279,14 +279,9 @@ struct llama_vocab {
     std::vector<token_score> id_to_token;
 
     std::unordered_map<token, id> special_token_to_id;
-    size_t max_special_token_length = 0;
 
     void add_special_token(const token & word, id token_id) {
         special_token_to_id[word] = token_id;
-
-        if (max_special_token_length < word.size()) {
-            max_special_token_length = word.size();
-        }
     }
 };
 
@@ -2088,38 +2083,6 @@ struct llama_tokenizer {
     llama_sp_bigram::queue work_queue_;
 };
 
-static std::vector<size_t> llama_split_special_tokens(const llama_vocab & vocab, const std::string & text) {
-    std::vector<size_t> offsets{0};
-    size_t start = 0;
-
-    while (start < text.size()) {
-        size_t max_end = start;
-        const std::string * max_delimiter = nullptr;
-
-        for (const auto & mit : vocab.special_token_to_id) {
-            const std::string & delimiter = mit.first;
-            size_t end = start + delimiter.size();
-            if (end <= text.size() && text.compare(start, delimiter.size(), delimiter) == 0) {
-                if (max_delimiter == nullptr || delimiter.size() > max_delimiter->size()) {
-                    max_end = end;
-                    max_delimiter = &delimiter;
-                }
-            }
-        }
-
-        if (max_delimiter != nullptr) {
-            offsets.push_back(start);
-            offsets.push_back(max_end);
-            start = max_end;
-        } else {
-            start++;
-        }
-    }
-
-    offsets.push_back(text.size());
-    return offsets;
-}
-
 static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
     llama_tokenizer tokenizer(vocab);
     std::vector<llama_vocab::id> output;
@@ -2137,27 +2100,40 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
         return output;
     }
 
-    std::vector<size_t> offsets = llama_split_special_tokens(vocab, text);
-    size_t start = 0;
-    for (size_t end : offsets) {
-        if (start >= end) {
-            continue;
+    size_t delim_start = 0;
+    size_t last_delim_end = 0;
+
+    while (delim_start < text.size()) {
+        size_t delim_end = 0;
+        llama_vocab::id token_id = -1;
+
+        for (const auto & mit : vocab.special_token_to_id) {
+            const std::string & delimiter = mit.first;
+            size_t end = delim_start + delimiter.size();
+            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
+                if (token_id == -1 || end > delim_end) {
+                    token_id = mit.second;
+                    delim_end = end;
+                }
+            }
         }
 
-        const char *part = text.c_str() + start;
-        size_t part_len = end - start;
-        if (vocab.max_special_token_length < part_len) {
-            tokenizer.tokenize(part, part_len, output);
-        } else {
-            auto token_it = vocab.special_token_to_id.find(std::string(part, part_len));
-            if (token_it != vocab.special_token_to_id.end()) {
-                output.push_back(token_it->second);
-            } else {
-                tokenizer.tokenize(part, part_len, output);
+        if (token_id != -1) {
+            if (last_delim_end < delim_start) {
+                tokenizer.tokenize(text.c_str() + last_delim_end, delim_start - last_delim_end, output);
             }
+            output.push_back(token_id);
+            delim_start = delim_end;
+            last_delim_end = delim_end;
+        } else {
+            delim_start++;
         }
-        start = end;
     }
+
+    if (last_delim_end < text.size()) {
+        tokenizer.tokenize(text.c_str() + last_delim_end, text.size() - last_delim_end, output);
+    }
+
     return output;
 }
 
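For illustration only, here is a minimal standalone sketch of the splitting strategy the new loop in llama_tokenize implements: at each position, try every registered special token, keep the longest match, emit its id, and pass the text in between to the ordinary tokenizer. The Vocab struct, split_with_special_tokens, fallback_tokenize, and the example special tokens and ids are hypothetical stand-ins, not part of this commit.

// Sketch of the greedy longest-match special-token splitting (hypothetical names).
#include <cstdio>
#include <string>
#include <unordered_map>
#include <vector>

struct Vocab {
    std::unordered_map<std::string, int> special_token_to_id;  // stand-in for llama_vocab
};

// Stand-in for tokenizer.tokenize(): just records the raw text span.
static void fallback_tokenize(const std::string & part, std::vector<std::string> & out) {
    out.push_back("text:\"" + part + "\"");
}

static std::vector<std::string> split_with_special_tokens(const Vocab & vocab, const std::string & text) {
    std::vector<std::string> out;
    size_t delim_start    = 0;  // current scan position
    size_t last_delim_end = 0;  // end of the last emitted special token

    while (delim_start < text.size()) {
        size_t delim_end = 0;
        int token_id = -1;

        // The longest special token that matches at delim_start wins.
        for (const auto & it : vocab.special_token_to_id) {
            const std::string & delimiter = it.first;
            size_t end = delim_start + delimiter.size();
            if (end <= text.size() && text.compare(delim_start, delimiter.size(), delimiter) == 0) {
                if (token_id == -1 || end > delim_end) {
                    token_id = it.second;
                    delim_end = end;
                }
            }
        }

        if (token_id != -1) {
            // Flush the plain text between the previous match and this one.
            if (last_delim_end < delim_start) {
                fallback_tokenize(text.substr(last_delim_end, delim_start - last_delim_end), out);
            }
            out.push_back("special:" + std::to_string(token_id));
            delim_start    = delim_end;
            last_delim_end = delim_end;
        } else {
            delim_start++;  // no special token starts here, advance one byte
        }
    }

    // Flush any trailing plain text after the last special token.
    if (last_delim_end < text.size()) {
        fallback_tokenize(text.substr(last_delim_end), out);
    }
    return out;
}

int main() {
    Vocab vocab;
    vocab.special_token_to_id["<|user|>"]      = 32001;  // example values, not from the commit
    vocab.special_token_to_id["<|assistant|>"] = 32002;

    for (const auto & piece : split_with_special_tokens(vocab, "<|user|>Hello<|assistant|>Hi")) {
        std::printf("%s\n", piece.c_str());
    }
    // Prints: special:32001, text:"Hello", special:32002, text:"Hi"
    return 0;
}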