Skip to content

Commit 4fbd809

Browse files
authored
gguf : add special tokens metadata for FIM/Infill (#6689)
This commit adds special token metadata for Fill-In-the-Middle (FIM)/Infill to the GGUF model. The motivation for this is that there is currently support for CodeLlama, but other models now exist, like CodeGemma. The different models use different token ids for the special tokens, and this commit allows for supporting multiple models. Signed-off-by: Daniel Bevenius <[email protected]>
1 parent 7593639 commit 4fbd809

File tree

4 files changed

+83
-11
lines changed

4 files changed

+83
-11
lines changed

convert-hf-to-gguf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,6 +1221,14 @@ def set_vocab(self):
12211221
except FileNotFoundError:
12221222
self._set_vocab_llama_hf()
12231223

1224+
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
1225+
special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
1226+
special_vocab._set_special_token("prefix", 32007)
1227+
special_vocab._set_special_token("suffix", 32008)
1228+
special_vocab._set_special_token("middle", 32009)
1229+
special_vocab._set_special_token("eot", 32010)
1230+
special_vocab.add_to_gguf(self.gguf_writer)
1231+
12241232
def set_gguf_parameters(self):
12251233
super().set_gguf_parameters()
12261234
hparams = self.hparams
@@ -2240,6 +2248,13 @@ class GemmaModel(Model):
22402248

22412249
def set_vocab(self):
22422250
self._set_vocab_sentencepiece()
2251+
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
2252+
special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
2253+
special_vocab._set_special_token("prefix", 67)
2254+
special_vocab._set_special_token("suffix", 69)
2255+
special_vocab._set_special_token("middle", 68)
2256+
special_vocab._set_special_token("eot", 70)
2257+
special_vocab.add_to_gguf(self.gguf_writer)
22432258

22442259
def set_gguf_parameters(self):
22452260
hparams = self.hparams

gguf-py/gguf/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ class Tokenizer:
9090
HF_JSON = "tokenizer.huggingface.json"
9191
RWKV = "tokenizer.rwkv.world"
9292
CHAT_TEMPLATE = "tokenizer.chat_template"
93+
# FIM/Infill special tokens constants
94+
PREFIX_ID = "tokenizer.ggml.prefix_token_id"
95+
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
96+
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
97+
EOT_ID = "tokenizer.ggml.eot_token_id"
9398

9499

95100
#
@@ -885,3 +890,7 @@ def get_type(val: Any) -> GGUFValueType:
885890
KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
886891
KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
887892
KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
893+
KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID
894+
KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
895+
KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
896+
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID

gguf-py/gguf/gguf_writer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,18 @@ def add_add_space_prefix(self, value: bool) -> None:
469469
def add_chat_template(self, value: str) -> None:
470470
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
471471

472+
def add_prefix_token_id(self, id: int) -> None:
473+
self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
474+
475+
def add_suffix_token_id(self, id: int) -> None:
476+
self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
477+
478+
def add_middle_token_id(self, id: int) -> None:
479+
self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
480+
481+
def add_eot_token_id(self, id: int) -> None:
482+
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
483+
472484
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
473485
pack_prefix = ''
474486
if not skip_pack_prefix:

llama.cpp

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ enum llm_kv {
327327
LLM_KV_TOKENIZER_ADD_PREFIX,
328328
LLM_KV_TOKENIZER_HF_JSON,
329329
LLM_KV_TOKENIZER_RWKV,
330+
LLM_KV_TOKENIZER_PREFIX_ID,
331+
LLM_KV_TOKENIZER_SUFFIX_ID,
332+
LLM_KV_TOKENIZER_MIDDLE_ID,
333+
LLM_KV_TOKENIZER_EOT_ID,
330334
};
331335

332336
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@@ -399,6 +403,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
399403
{ LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" },
400404
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
401405
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
406+
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
407+
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
408+
{ LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
409+
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
402410
};
403411

404412
struct LLM_KV {
@@ -2055,10 +2063,10 @@ struct llama_vocab {
20552063
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
20562064

20572065
id linefeed_id = 13;
2058-
id special_prefix_id = 32007;
2059-
id special_middle_id = 32009;
2060-
id special_suffix_id = 32008;
2061-
id special_eot_id = 32010;
2066+
id special_prefix_id = -1;
2067+
id special_suffix_id = -1;
2068+
id special_middle_id = -1;
2069+
id special_eot_id = -1;
20622070

20632071
bool add_space_prefix = true;
20642072

@@ -4072,6 +4080,30 @@ static void llm_load_vocab(
40724080
vocab.special_cls_id = -1;
40734081
vocab.special_mask_id = -1;
40744082

4083+
// For Fill-In-the-Middle (FIM)/infill models which were converted
4084+
// prior to support of FIM special tokens in GGUF, the following
4085+
// will allow those models to continue to work. The general names
4086+
// of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
4087+
// CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
4088+
// new versions of these models have been published.
4089+
std::string gen_name;
4090+
ml.get_key(LLM_KV_GENERAL_NAME, gen_name);
4091+
std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
4092+
[](unsigned char c){ return std::tolower(c); });
4093+
if (gen_name.find("code") != std::string::npos) {
4094+
if (model.arch == LLM_ARCH_LLAMA) {
4095+
vocab.special_prefix_id = 32007;
4096+
vocab.special_suffix_id = 32008;
4097+
vocab.special_middle_id = 32009;
4098+
vocab.special_eot_id = 32010;
4099+
} else if (model.arch == LLM_ARCH_GEMMA) {
4100+
vocab.special_prefix_id = 67;
4101+
vocab.special_suffix_id = 69;
4102+
vocab.special_middle_id = 68;
4103+
vocab.special_eot_id = 70;
4104+
}
4105+
}
4106+
40754107
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
40764108
if (add_space_prefix_keyidx != -1) {
40774109
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
@@ -4185,13 +4217,17 @@ static void llm_load_vocab(
41854217
// special tokens
41864218
{
41874219
const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
4188-
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
4189-
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
4190-
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
4191-
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
4192-
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
4193-
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
4194-
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
4220+
{ LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
4221+
{ LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
4222+
{ LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
4223+
{ LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
4224+
{ LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
4225+
{ LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
4226+
{ LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
4227+
{ LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
4228+
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
4229+
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
4230+
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
41954231
};
41964232
for (const auto & it : special_token_types) {
41974233
const std::string & key = kv(std::get<0>(it));

0 commit comments

Comments
 (0)