Skip to content

Commit dac3e4f

Browse files
fairydreaming authored and sszymczy committed
Stop the generation when <|eom_id|> token is encountered - needed for Llama 3.1 tool call support (ggml-org#8858)
* gguf-py, llama : add constants and methods related to Llama-3.1 <|eom_id|> token

* llama : find Llama-3.1 <|eom_id|> token id during vocab loading

* llama-vocab : add Llama-3.1 <|eom_id|> token to the set of tokens stopping the generation

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 5133773 commit dac3e4f

File tree

5 files changed

+27
-1
lines changed

5 files changed

+27
-1
lines changed

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ class Tokenizer:
161161
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
162162
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
163163
EOT_ID = "tokenizer.ggml.eot_token_id"
164+
EOM_ID = "tokenizer.ggml.eom_token_id"
164165

165166
class Adapter:
166167
TYPE = "adapter.type"
@@ -1327,3 +1328,4 @@ def get_type(val: Any) -> GGUFValueType:
13271328
KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
13281329
KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
13291330
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
1331+
KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,9 @@ def add_middle_token_id(self, id: int) -> None:
828828
def add_eot_token_id(self, id: int) -> None:
829829
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
830830

831+
def add_eom_token_id(self, id: int) -> None:
832+
self.add_uint32(Keys.Tokenizer.EOM_ID, id)
833+
831834
def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
832835
pack_prefix = ''
833836
if not skip_pack_prefix:

src/llama-vocab.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,8 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
14441444
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
14451445
return token != -1 && (
14461446
token == llama_token_eos_impl(vocab) ||
1447-
token == llama_token_eot_impl(vocab)
1447+
token == llama_token_eot_impl(vocab) ||
1448+
token == llama_token_eom_impl(vocab)
14481449
);
14491450
}
14501451

@@ -1500,6 +1501,10 @@ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
15001501
return vocab.special_eot_id;
15011502
}
15021503

1504+
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1505+
return vocab.special_eom_id;
1506+
}
1507+
15031508
int32_t llama_tokenize_impl(
15041509
const struct llama_vocab & vocab,
15051510
const char * text,

src/llama-vocab.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct llama_vocab {
4545
id special_suffix_id = -1;
4646
id special_middle_id = -1;
4747
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
48+
id special_eom_id = -1;
4849

4950
// tokenizer flags
5051
bool tokenizer_add_space_prefix = false;
@@ -101,6 +102,7 @@ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
101102
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
102103
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
103104
llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
105+
llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
104106

105107
int32_t llama_tokenize_impl(
106108
const struct llama_vocab & vocab,

src/llama.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ enum llm_kv {
359359
LLM_KV_TOKENIZER_SUFFIX_ID,
360360
LLM_KV_TOKENIZER_MIDDLE_ID,
361361
LLM_KV_TOKENIZER_EOT_ID,
362+
LLM_KV_TOKENIZER_EOM_ID,
362363

363364
LLM_KV_ADAPTER_TYPE,
364365
LLM_KV_ADAPTER_LORA_ALPHA,
@@ -456,6 +457,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
456457
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
457458
{ LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
458459
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
460+
{ LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" },
459461

460462
{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
461463
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
@@ -5583,6 +5585,7 @@ static void llm_load_vocab(
55835585
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
55845586
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
55855587
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
5588+
{ LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
55865589
};
55875590

55885591
for (const auto & it : special_token_types) {
@@ -5635,6 +5638,17 @@ static void llm_load_vocab(
56355638
}
56365639
}
56375640
}
5641+
5642+
// find EOM token: "<|eom_id|>"
5643+
//
5644+
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
5645+
// for now, we apply this workaround to find the EOM token based on its text
5646+
if (vocab.special_eom_id == -1) {
5647+
const auto & t = vocab.token_to_id.find("<|eom_id|>");
5648+
if (t != vocab.token_to_id.end()) {
5649+
vocab.special_eom_id = t->second;
5650+
}
5651+
}
56385652
}
56395653

56405654
// build special tokens cache

0 commit comments

Comments (0)