@@ -22,7 +22,6 @@
 import struct
 import numpy as np
 import torch
-
 from sentencepiece import SentencePieceProcessor
 
 if len(sys.argv) < 3:
@@ -101,12 +100,28 @@ def get_n_parts(dim):
 
 # Is this correct??
 for i in range(32000):
-    # TODO: this is probably wrong - not sure how this tokenizer works
-    text = tokenizer.decode([29889, i]).encode('utf-8')
-    # remove the first byte (it's always '.')
-    text = text[1:]
-    fout.write(struct.pack("i", len(text)))
-    fout.write(text)
+    if tokenizer.is_unknown(i):
+        # "<unk>" token (translated as ??)
+        text = " \u2047 ".encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+    elif tokenizer.is_control(i):
+        # "<s>"/"</s>" tokens
+        fout.write(struct.pack("i", 0))
+    elif tokenizer.is_byte(i):
+        # "<U+XX>" tokens (which may be invalid UTF-8)
+        piece = tokenizer.id_to_piece(i)
+        if len(piece) != 6:
+            print("Invalid token: " + piece)
+            sys.exit(1)
+        byte_value = int(piece[3:-1], 16)
+        fout.write(struct.pack("i", 1))
+        fout.write(struct.pack("B", byte_value))
+    else:
+        # normal token. Uses U+2581 (LOWER ONE EIGHTH BLOCK) to represent spaces.
+        text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8")
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
 
 for k, v in model.items():
     name = k
|