Commit 5dce921

py : make convert-pt-to-ggml.py backwards compatible with older vocab.json tokenizer files (ggml-org#1001)
* patch checkpoint convert script to keep compatibility with older hf_transformers whisper tokenizer
* typo fix
1 parent 258f16b commit 5dce921

File tree

1 file changed

models/convert-pt-to-ggml.py

Lines changed: 28 additions & 7 deletions
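
For context, the two on-disk tokenizer formats this patch distinguishes look roughly as follows (hand-made sample data, not copied from an actual Whisper asset). The newer *.tiktoken files store one base64-encoded token and its integer rank per line; the older hf_transformers vocab.json files map byte-level-unicode token strings to ids and may include special tokens such as <|endoftext|>:

# Illustrative-only sample contents for the two formats.

# new format (multilingual.tiktoken / gpt2.tiktoken):
# one "<base64 token bytes> <rank>" pair per line
tiktoken_sample = b"IQ== 0\nIg== 1\nIw== 2\n"   # tokens "!", '"', "#"

# old format (multilingual/vocab.json / gpt2/vocab.json):
# a JSON object of byte-level-unicode token strings to ids,
# including special tokens that the converter now strips
vocab_json_sample = '{"!": 0, "\\"": 1, "#": 2, "<|endoftext|>": 50256}'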
@@ -224,16 +224,39 @@ def bytes_to_unicode():
 
 #code.interact(local=locals())
 
+# load tokenizer
+# for backwards compatibility, also check for older hf_transformers format tokenizer files
+# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
+# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
 multilingual = hparams["n_vocab"] == 51865
 tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
+tokenizer_type = "tiktoken"
+if not tokenizer.is_file():
+    tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
+    tokenizer_type = "hf_transformers"
+    if not tokenizer.is_file():
+        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
+        sys.exit(1)
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+if tokenizer_type == "tiktoken":
+    with open(tokenizer, "rb") as f:
+        contents = f.read()
+        tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
+elif tokenizer_type == "hf_transformers":
+    with open(tokenizer, "r", encoding="utf8") as f:
+        _tokens_raw = json.load(f)
+        if '<|endoftext|>' in _tokens_raw:
+            # ensures exact same model as tokenizer_type == tiktoken
+            # details: https://github.com/ggerganov/whisper.cpp/pull/725
+            del _tokens_raw['<|endoftext|>']
+        tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
 
 # output in the same directory as the model
 fname_out = dir_out / "ggml-model.bin"
 
-with open(tokenizer, "rb") as f:
-    contents = f.read()
-    tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
-
 # use 16-bit or 32-bit floats
 use_f16 = True
 if len(sys.argv) > 4:
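
The subtle part of the new hf_transformers branch is the byte_decoder round-trip: vocab.json keys are stored in GPT-2's byte-level-unicode encoding (a space becomes "Ġ", for instance), so mapping each character back through byte_decoder recovers the raw token bytes and makes both branches yield the same bytes-to-id dict. A minimal sketch, reusing bytes_to_unicode() from this same script and a hypothetical vocab.json key:

byte_encoder = bytes_to_unicode()                # byte value (0..255) -> printable unicode char
byte_decoder = {v: k for k, v in byte_encoder.items()}

token = "Ġhello"                                 # hypothetical vocab.json key
raw = bytes([byte_decoder[c] for c in token])
print(raw)                                       # b' hello' -- same key the tiktoken branch decodes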
@@ -262,9 +285,7 @@ def bytes_to_unicode():
     for j in range(filters.shape[1]):
         fout.write(struct.pack("f", filters[i][j]))
 
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v:k for k, v in byte_encoder.items()}
-
+# write tokenizer
 fout.write(struct.pack("i", len(tokens)))
 
 for key in tokens:
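
The relocated "# write tokenizer" block serializes the tokens dict into ggml-model.bin. The body of the for key in tokens: loop sits outside this hunk, but judging from the surrounding struct.pack calls the layout appears to be an int32 token count followed by length-prefixed raw token bytes; a sketch under that assumption:

import struct

def write_tokenizer(fout, tokens):
    # Assumed layout (the loop body is not shown in this hunk):
    fout.write(struct.pack("i", len(tokens)))    # int32: number of tokens
    for key in tokens:                           # key is raw token bytes
        fout.write(struct.pack("i", len(key)))   # int32: byte length of this token
        fout.write(key)                          # the token bytes themselves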
