@@ -281,6 +281,15 @@ struct llama_vocab {
281
281
llama_trie special_token_trie;
282
282
std::unordered_map<token, id> special_token_to_id;
283
283
size_t max_special_token_length = 0 ;
284
+
285
+ void add_special_token (const token & word, id token_id) {
286
+ special_token_trie.add (word);
287
+ special_token_to_id[word] = token_id;
288
+
289
+ if (max_special_token_length < word.size ()) {
290
+ max_special_token_length = word.size ();
291
+ }
292
+ }
284
293
};
285
294
286
295
struct llama_model {
@@ -624,15 +633,8 @@ struct llama_file_loader {
624
633
for (uint32_t i = 0 ; i < vocab_sp; i++) {
625
634
llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
626
635
const auto & word = vocab.id_to_token [token_id].tok ;
627
- if (word.empty ()) {
628
- continue ;
629
- }
630
-
631
- vocab.special_token_trie .add (word);
632
- vocab.special_token_to_id [word] = token_id;
633
-
634
- if (vocab.max_special_token_length < word.size ()) {
635
- vocab.max_special_token_length = word.size ();
636
+ if (!word.empty ()) {
637
+ vocab.add_special_token (word, token_id);
636
638
}
637
639
}
638
640
}
@@ -4263,6 +4265,10 @@ llama_token llama_token_nl() {
4263
4265
return 13 ;
4264
4266
}
4265
4267
4268
+ void llama_add_special_token (struct llama_model * model, const char * token, llama_token token_id) {
4269
+ model->vocab .add_special_token (token, token_id);
4270
+ }
4271
+
4266
4272
struct llama_timings llama_get_timings (struct llama_context * ctx) {
4267
4273
struct llama_timings result = {
4268
4274
/* .t_start_ms =*/ 1e-3 * ctx->t_start_us ,
0 commit comments