Skip to content

Commit 40ab5c6

Browse files
committed
compression: Implement ciscorn's dictionary approach
Massive savings. Thanks so much @ciscorn for providing the initial code for choosing the dictionary. This adds a bit of time to the build, both to find the dictionary but also because (for reasons I don't fully understand), the binary search in the compress() function no longer worked and had to be replaced with a linear search. I think this is because the intended invariant is that for codebook entries that encode to the same number of bits, the entries are ordered in ascending value. However, I misplaced the transition from "words" to "byte/char values" so the codebook entries for words are in word-order rather than their code order. Because this price is only paid at build time, I didn't care to determine exactly where the correct fix was. I also commented out a line to produce the "estimated total memory size" -- at least on the unix build with TRANSLATION=ja, this led to a build-time KeyError trying to compute the codebook size for all the strings. I think this occurs because some single unicode code point ('ァ') is no longer present as itself in the compressed strings, due to always being replaced by a word. As promised, this seems to save hundreds of bytes in the German translation on the trinket m0. Testing performed: - built trinket_m0 in several languages - built and ran unix port in several languages (en, de_DE, ja) and ran simple error-producing code like ./micropython -c '1/0'
1 parent 7611e71 commit 40ab5c6

File tree

2 files changed

+154
-68
lines changed

2 files changed

+154
-68
lines changed

py/makeqstrdata.py

Lines changed: 139 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -100,77 +100,153 @@ def translate(translation_file, i18ns):
100100
translations.append((original, translation))
101101
return translations
102102

103-
def frequent_ngrams(corpus, sz, n):
104-
return collections.Counter(corpus[i:i+sz] for i in range(len(corpus)-sz)).most_common(n)
103+
class TextSplitter:
    """Split text into dictionary words and runs of leftover characters.

    Builds one alternation regex over all dictionary words plus a final
    "." branch that matches any single character not starting a word.
    """

    def __init__(self, words):
        # NOTE: sorts the caller's list in place, longest first, so that
        # longer words take precedence over their prefixes in the regex
        # alternation; later code in this file relies on the resulting
        # order (e.g. words.index(...) lookups).
        words.sort(key=lambda x: len(x), reverse=True)
        self.words = set(words)
        # DOTALL so the "." fallback also matches newline characters.
        self.pat = re.compile(
            "|".join(re.escape(w) for w in words) + "|.", flags=re.DOTALL
        )

    def iter_words(self, text):
        """Yield (is_word, fragment) pairs covering all of *text*.

        Consecutive non-word characters are coalesced into a single
        (False, run) fragment; each dictionary word is yielded on its
        own as (True, word).
        """
        s = []
        for m in self.pat.finditer(text):
            t = m.group(0)
            if t in self.words:
                if s:
                    yield (False, "".join(s))
                    s = []
                yield (True, t)
            else:
                s.append(t)
        if s:
            yield (False, "".join(s))

    def iter(self, text):
        """Yield every atom (dictionary word or single character) of *text*."""
        # Fix: removed an unused local ("s = []") that the original
        # declared here but never read or wrote afterwards.
        for m in self.pat.finditer(text):
            yield m.group(0)
127+
128+
def iter_substrings(s, minlen, maxlen):
    """Yield every substring of *s* whose length lies in [minlen, maxlen].

    Substrings are produced shortest length first, and within one length
    in left-to-right order of their starting position.
    """
    upper = min(maxlen, len(s))
    return (
        s[i : i + length]
        for length in range(minlen, upper + 1)
        for i in range(len(s) - length + 1)
    )
133+
134+
def compute_huffman_coding(translations, compression_filename):
135+
texts = [t[1] for t in translations]
136+
all_strings_concat = "".join(texts)
137+
words = []
138+
max_ord = 0
139+
begin_unused = 128
140+
end_unused = 256
141+
for text in texts:
142+
for c in text:
143+
ord_c = ord(c)
144+
max_ord = max(max_ord, ord_c)
145+
if 128 <= ord_c < 256:
146+
end_unused = min(ord_c, end_unused)
147+
max_words = end_unused - begin_unused
148+
char_size = 1 if max_ord < 256 else 2
149+
150+
sum_word_len = 0
151+
while True:
152+
extractor = TextSplitter(words)
153+
counter = collections.Counter()
154+
for t in texts:
155+
for (found, word) in extractor.iter_words(t):
156+
if not found:
157+
for substr in iter_substrings(word, minlen=2, maxlen=9):
158+
counter[substr] += 1
159+
160+
scores = sorted(
161+
(
162+
# I don't know why this works good. This could be better.
163+
(s, (len(s) - 1) ** ((max(occ - 2, 1) + 0.5) ** 0.8), occ)
164+
for (s, occ) in counter.items()
165+
),
166+
key=lambda x: x[1],
167+
reverse=True,
168+
)
169+
170+
w = None
171+
for (s, score, occ) in scores:
172+
if score < 0:
173+
break
174+
if len(s) > 1:
175+
w = s
176+
break
177+
178+
if not w:
179+
break
180+
if len(w) + sum_word_len > 256:
181+
break
182+
if len(words) == max_words:
183+
break
184+
words.append(w)
185+
sum_word_len += len(w)
186+
187+
extractor = TextSplitter(words)
188+
counter = collections.Counter()
189+
for t in texts:
190+
for atom in extractor.iter(t):
191+
counter[atom] += 1
192+
cb = huffman.codebook(counter.items())
193+
194+
word_start = begin_unused
195+
word_end = word_start + len(words) - 1
196+
print("// # words", len(words))
197+
print("// words", words)
105198

106-
def encode_ngrams(translation, ngrams):
107-
if len(ngrams) > 32:
108-
start = 0xe000
109-
else:
110-
start = 0x80
111-
for i, g in enumerate(ngrams):
112-
translation = translation.replace(g, chr(start + i))
113-
return translation
114-
115-
def decode_ngrams(compressed, ngrams):
116-
if len(ngrams) > 32:
117-
start, end = 0xe000, 0xf8ff
118-
else:
119-
start, end = 0x80, 0x9f
120-
return "".join(ngrams[ord(c) - start] if (start <= ord(c) <= end) else c for c in compressed)
121-
122-
def compute_huffman_coding(translations, qstrs, compression_filename):
123-
all_strings = [x[1] for x in translations]
124-
all_strings_concat = "".join(all_strings)
125-
ngrams = [i[0] for i in frequent_ngrams(all_strings_concat, 2, 32)]
126-
all_strings_concat = encode_ngrams(all_strings_concat, ngrams)
127-
counts = collections.Counter(all_strings_concat)
128-
cb = huffman.codebook(counts.items())
129199
values = []
130200
length_count = {}
131201
renumbered = 0
132202
last_l = None
133203
canonical = {}
134-
for ch, code in sorted(cb.items(), key=lambda x: (len(x[1]), x[0])):
135-
values.append(ch)
204+
for atom, code in sorted(cb.items(), key=lambda x: (len(x[1]), x[0])):
205+
values.append(atom)
136206
l = len(code)
137207
if l not in length_count:
138208
length_count[l] = 0
139209
length_count[l] += 1
140210
if last_l:
141211
renumbered <<= (l - last_l)
142-
canonical[ch] = '{0:0{width}b}'.format(renumbered, width=l)
143-
s = C_ESCAPES.get(ch, ch)
144-
print("//", ord(ch), s, counts[ch], canonical[ch], renumbered)
212+
canonical[atom] = '{0:0{width}b}'.format(renumbered, width=l)
213+
#print(f"atom={repr(atom)} code={code}", file=sys.stderr)
214+
if len(atom) > 1:
215+
o = words.index(atom) + 0x80
216+
s = "".join(C_ESCAPES.get(ch1, ch1) for ch1 in atom)
217+
else:
218+
s = C_ESCAPES.get(atom, atom)
219+
o = ord(atom)
220+
print("//", o, s, counter[atom], canonical[atom], renumbered)
145221
renumbered += 1
146222
last_l = l
147223
lengths = bytearray()
148224
print("// length count", length_count)
149-
print("// bigrams", ngrams)
225+
150226
for i in range(1, max(length_count) + 2):
151227
lengths.append(length_count.get(i, 0))
152228
print("// values", values, "lengths", len(lengths), lengths)
153-
ngramdata = [ord(ni) for i in ngrams for ni in i]
154-
print("// estimated total memory size", len(lengths) + 2*len(values) + 2 * len(ngramdata) + sum((len(cb[u]) + 7)//8 for u in all_strings_concat))
229+
maxord = max(ord(u) for u in values if len(u) == 1)
230+
values_type = "uint16_t" if maxord > 255 else "uint8_t"
231+
ch_size = 1 if maxord > 255 else 2
232+
print("//", values, lengths)
233+
values = [(atom if len(atom) == 1 else chr(0x80 + words.index(atom))) for atom in values]
155234
print("//", values, lengths)
156-
values_type = "uint16_t" if max(ord(u) for u in values) > 255 else "uint8_t"
157235
max_translation_encoded_length = max(len(translation.encode("utf-8")) for original,translation in translations)
158236
with open(compression_filename, "w") as f:
159237
f.write("const uint8_t lengths[] = {{ {} }};\n".format(", ".join(map(str, lengths))))
160238
f.write("const {} values[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(u)) for u in values)))
161239
f.write("#define compress_max_length_bits ({})\n".format(max_translation_encoded_length.bit_length()))
162-
f.write("const {} bigrams[] = {{ {} }};\n".format(values_type, ", ".join(str(u) for u in ngramdata)))
163-
if len(ngrams) > 32:
164-
bigram_start = 0xe000
165-
else:
166-
bigram_start = 0x80
167-
bigram_end = bigram_start + len(ngrams) - 1 # End is inclusive
168-
f.write("#define bigram_start {}\n".format(bigram_start))
169-
f.write("#define bigram_end {}\n".format(bigram_end))
170-
return values, lengths, ngrams
240+
f.write("const {} words[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(c)) for w in words for c in w)))
241+
f.write("const uint8_t wlen[] = {{ {} }};\n".format(", ".join(str(len(w)) for w in words)))
242+
f.write("#define word_start {}\n".format(word_start))
243+
f.write("#define word_end {}\n".format(word_end))
244+
245+
extractor = TextSplitter(words)
246+
return values, lengths, words, extractor
171247

172248
def decompress(encoding_table, encoded, encoded_length_bits):
173-
values, lengths, ngrams = encoding_table
249+
values, lengths, words, extractor = encoding_table
174250
dec = []
175251
this_byte = 0
176252
this_bit = 7
@@ -218,16 +294,17 @@ def decompress(encoding_table, encoded, encoded_length_bits):
218294
searched_length += lengths[bit_length]
219295

220296
v = values[searched_length + bits - max_code]
221-
v = decode_ngrams(v, ngrams)
297+
if v >= chr(0x80) and v < chr(0x80 + len(words)):
298+
v = words[ord(v) - 0x80]
222299
i += len(v.encode('utf-8'))
223300
dec.append(v)
224301
return ''.join(dec)
225302

226303
def compress(encoding_table, decompressed, encoded_length_bits, len_translation_encoded):
227304
if not isinstance(decompressed, str):
228305
raise TypeError()
229-
values, lengths, ngrams = encoding_table
230-
decompressed = encode_ngrams(decompressed, ngrams)
306+
values, lengths, words, extractor = encoding_table
307+
231308
enc = bytearray(len(decompressed) * 3)
232309
#print(decompressed)
233310
#print(lengths)
@@ -246,9 +323,15 @@ def compress(encoding_table, decompressed, encoded_length_bits, len_translation_
246323
else:
247324
current_bit -= 1
248325

249-
for c in decompressed:
250-
#print()
251-
#print("char", c, values.index(c))
326+
#print("values = ", values, file=sys.stderr)
327+
for atom in extractor.iter(decompressed):
328+
#print("", file=sys.stderr)
329+
if len(atom) > 1:
330+
c = chr(0x80 + words.index(atom))
331+
else:
332+
c = atom
333+
assert c in values
334+
252335
start = 0
253336
end = lengths[0]
254337
bits = 1
@@ -258,18 +341,12 @@ def compress(encoding_table, decompressed, encoded_length_bits, len_translation_
258341
s = start
259342
e = end
260343
#print("{0:0{width}b}".format(code, width=bits))
261-
# Binary search!
262-
while e > s:
263-
midpoint = (s + e) // 2
264-
#print(s, e, midpoint)
265-
if values[midpoint] == c:
266-
compressed = code + (midpoint - start)
267-
#print("found {0:0{width}b}".format(compressed, width=bits))
344+
# Linear search!
345+
for i in range(s, e):
346+
if values[i] == c:
347+
compressed = code + (i - start)
348+
#print("found {0:0{width}b}".format(compressed, width=bits), file=sys.stderr)
268349
break
269-
elif c < values[midpoint]:
270-
e = midpoint
271-
else:
272-
s = midpoint + 1
273350
code += end - start
274351
code <<= 1
275352
start = end
@@ -452,7 +529,7 @@ def print_qstr_enums(qstrs):
452529
if args.translation:
453530
i18ns = sorted(i18ns)
454531
translations = translate(args.translation, i18ns)
455-
encoding_table = compute_huffman_coding(translations, qstrs, args.compression_filename)
532+
encoding_table = compute_huffman_coding(translations, args.compression_filename)
456533
print_qstr_data(encoding_table, qcfgs, qstrs, translations)
457534
else:
458535
print_qstr_enums(qstrs)

supervisor/shared/translate.c

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,22 @@ STATIC int put_utf8(char *buf, int u) {
4747
if(u <= 0x7f) {
4848
*buf = u;
4949
return 1;
50-
} else if(bigram_start <= u && u <= bigram_end) {
51-
int n = (u - 0x80) * 2;
52-
// (note that at present, entries in the bigrams table are
53-
// guaranteed not to represent bigrams themselves, so this adds
50+
} else if(word_start <= u && u <= word_end) {
51+
int n = (u - 0x80);
52+
size_t off = 0;
53+
for(int i=0; i<n; i++) {
54+
off += wlen[i];
55+
}
56+
int ret = 0;
57+
// note that at present, entries in the words table are
58+
// guaranteed not to represent words themselves, so this adds
5459
// at most 1 level of recursive call
55-
int ret = put_utf8(buf, bigrams[n]);
56-
return ret + put_utf8(buf + ret, bigrams[n+1]);
60+
for(int i=0; i<wlen[n]; i++) {
61+
int len = put_utf8(buf, words[off+i]);
62+
buf += len;
63+
ret += len;
64+
}
65+
return ret;
5766
} else if(u <= 0x07ff) {
5867
*buf++ = 0b11000000 | (u >> 6);
5968
*buf = 0b10000000 | (u & 0b00111111);

0 commit comments

Comments
 (0)