Skip to content

Commit e8834e9

Browse files
strutive07 authored and jordankanter committed
Add byte token type when tokenizer.model does not exist (ggml-org#4641)
* Add byte token type to hf format
* Remove unused variable
1 parent 5a5d5e2 commit e8834e9

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

convert.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -357,6 +357,7 @@ def __init__(self, params: Params, fname_tokenizer: Path) -> None:
357357
for tok in self.tokenizer.all_special_tokens
358358
}
359359
self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
360+
self.reverse_vocab = {id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
360361
self.vocab_size_base: int = self.tokenizer.vocab_size
361362
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
362363
self.fname_tokenizer: Path = fname_tokenizer
@@ -370,15 +371,13 @@ def __init__(self, params: Params, fname_tokenizer: Path) -> None:
370371
self.spm = None
371372

372373
def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
373-
tokenizer = self.tokenizer
374-
reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.get_vocab().items()}
375374
added_tokens_ids = set(self.added_tokens_dict.values())
376375

377376
for i in range(self.vocab_size_base):
378377
if i in added_tokens_ids:
379378
continue
380379

381-
text = reverse_vocab[i].encode("utf-8")
380+
text = self.reverse_vocab[i].encode("utf-8")
382381
yield text, self.get_token_score(i), self.get_token_type(i)
383382

384383
def get_token_type(self, token_id: int) -> gguf.TokenType:
@@ -394,10 +393,13 @@ def get_token_type(self, token_id: int) -> gguf.TokenType:
394393
if self.spm.is_byte(token_id):
395394
toktype = gguf.TokenType.BYTE
396395
else:
396+
token = self.reverse_vocab[token_id]
397397
if token_id == self.unk_token_id:
398398
toktype = gguf.TokenType.UNKNOWN
399-
if token_id in self.special_ids:
399+
elif token_id in self.special_ids:
400400
toktype = gguf.TokenType.CONTROL
401+
elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
402+
toktype = gguf.TokenType.BYTE
401403

402404
return toktype
403405

0 commit comments

Comments
 (0)