Skip to content

Commit d9791bb

Browse files
committed
Add C API for adding special tokens
1 parent 099119f commit d9791bb

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

llama.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,15 @@ struct llama_vocab {
281281
llama_trie special_token_trie;
282282
std::unordered_map<token, id> special_token_to_id;
283283
size_t max_special_token_length = 0;
284+
285+
void add_special_token(const token & word, id token_id) {
286+
special_token_trie.add(word);
287+
special_token_to_id[word] = token_id;
288+
289+
if (max_special_token_length < word.size()) {
290+
max_special_token_length = word.size();
291+
}
292+
}
284293
};
285294

286295
struct llama_model {
@@ -624,15 +633,8 @@ struct llama_file_loader {
624633
for (uint32_t i = 0; i < vocab_sp; i++) {
625634
llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
626635
const auto & word = vocab.id_to_token[token_id].tok;
627-
if (word.empty()) {
628-
continue;
629-
}
630-
631-
vocab.special_token_trie.add(word);
632-
vocab.special_token_to_id[word] = token_id;
633-
634-
if (vocab.max_special_token_length < word.size()) {
635-
vocab.max_special_token_length = word.size();
636+
if (!word.empty()) {
637+
vocab.add_special_token(word, token_id);
636638
}
637639
}
638640
}
@@ -4263,6 +4265,10 @@ llama_token llama_token_nl() {
42634265
return 13;
42644266
}
42654267

4268+
void llama_add_special_token(struct llama_model * model, const char * token, llama_token token_id) {
4269+
model->vocab.add_special_token(token, token_id);
4270+
}
4271+
42664272
struct llama_timings llama_get_timings(struct llama_context * ctx) {
42674273
struct llama_timings result = {
42684274
/*.t_start_ms =*/ 1e-3 * ctx->t_start_us,

llama.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,11 @@ extern "C" {
373373
LLAMA_API llama_token llama_token_eos(); // end-of-sentence
374374
LLAMA_API llama_token llama_token_nl(); // next-line
375375

376+
LLAMA_API void llama_add_special_token(
377+
struct llama_model * model,
378+
const char * token,
379+
llama_token token_id);
380+
376381
// Grammar
377382
//
378383
LLAMA_API struct llama_grammar * llama_grammar_init(

0 commit comments

Comments
 (0)