@@ -224,16 +224,39 @@ def bytes_to_unicode():
 
 #code.interact(local=locals())
 
+# load tokenizer
+# for backwards compatibility, also check for older hf_transformers format tokenizer files
+# old format: dir_whisper/whisper/assets/[multilingual/gpt2]/vocab.json
+# new format: dir_whisper/whisper/assets/[multilingual/gpt2].tiktoken
 multilingual = hparams["n_vocab"] == 51865
 tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken")
+tokenizer_type = "tiktoken"
+if not tokenizer.is_file():
+    tokenizer = dir_whisper / "whisper" / "assets" / (multilingual and "multilingual" or "gpt2") / "vocab.json"
+    tokenizer_type = "hf_transformers"
+    if not tokenizer.is_file():
+        print("Error: failed to find either tiktoken or hf_transformers tokenizer file:", tokenizer)
+        sys.exit(1)
+
+byte_encoder = bytes_to_unicode()
+byte_decoder = {v:k for k, v in byte_encoder.items()}
+
+if tokenizer_type == "tiktoken":
+    with open(tokenizer, "rb") as f:
+        contents = f.read()
+        tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
+elif tokenizer_type == "hf_transformers":
+    with open(tokenizer, "r", encoding="utf8") as f:
+        _tokens_raw = json.load(f)
+        if '<|endoftext|>' in _tokens_raw:
+            # ensures exact same model as tokenizer_type == tiktoken
+            # details: https://github.com/ggerganov/whisper.cpp/pull/725
+            del _tokens_raw['<|endoftext|>']
+        tokens = {bytes([byte_decoder[c] for c in token]): int(idx) for token, idx in _tokens_raw.items()}
 
 # output in the same directory as the model
 fname_out = dir_out / "ggml-model.bin"
 
-with open(tokenizer, "rb") as f:
-    contents = f.read()
-    tokens = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in contents.splitlines() if line)}
-
 # use 16-bit or 32-bit floats
 use_f16 = True
 if len(sys.argv) > 4:
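For context on what the new tiktoken branch parses: each non-empty line of a `.tiktoken` file is a base64-encoded token followed by its integer rank, separated by a space, which is exactly what the dict comprehension above unpacks. A minimal, self-contained sketch of that parse, using a made-up two-token vocabulary (the contents here are illustrative, not from the repo):

```python
import base64

# fake .tiktoken payload: "base64(token_bytes) rank" per line (hypothetical vocab)
contents = b"\n".join([
    base64.b64encode(b"hello") + b" 0",
    base64.b64encode(b" world") + b" 1",
])

# same comprehension as in the diff above: split each non-empty line
# into (token, rank), decode the token back to raw bytes
tokens = {base64.b64decode(token): int(rank)
          for token, rank in (line.split() for line in contents.splitlines() if line)}

print(tokens)  # {b'hello': 0, b' world': 1}
```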
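The hf_transformers branch instead has to undo GPT-2's byte-to-unicode mapping: keys in vocab.json are strings in which each character stands for one raw byte, and `byte_decoder` inverts that mapping. A round-trip sketch (repeating the `bytes_to_unicode()` helper this script already defines, so the snippet runs standalone):

```python
def bytes_to_unicode():
    # GPT-2's reversible byte <-> printable-unicode mapping,
    # as defined earlier in this script
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

# a vocab.json key like "Ġhello" stands for the raw bytes b" hello"
token = "".join(byte_encoder[b] for b in b" hello")    # -> "Ġhello"
raw = bytes([byte_decoder[c] for c in token])
assert raw == b" hello"
```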
@@ -262,9 +285,7 @@ def bytes_to_unicode():
     for j in range(filters.shape[1]):
         fout.write(struct.pack("f", filters[i][j]))
 
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v:k for k, v in byte_encoder.items()}
-
+# write tokenizer
 fout.write(struct.pack("i", len(tokens)))
 
 for key in tokens:
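The hunk is truncated inside the final write loop, but the output follows ggml's usual length-prefixed layout: an int32 token count, then for each token an int32 byte length followed by the raw bytes. A sketch of that convention (the loop body itself is not shown in this diff, so treat the details as assumed rather than quoted):

```python
import struct
from typing import BinaryIO

def write_vocab(fout: BinaryIO, tokens: dict[bytes, int]) -> None:
    # int32 token count, then int32 length + raw bytes per token
    # (assumed layout, mirroring the fout.write calls visible above)
    fout.write(struct.pack("i", len(tokens)))
    for key in tokens:
        fout.write(struct.pack("i", len(key)))
        fout.write(key)
```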