Merge pull request #3 from ciscorn/dict-comp

Small improvements to the dictionary compression
This commit is contained in:
Jeff Epler 2020-09-13 12:58:13 -05:00 committed by GitHub
commit 9abfc51ced
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 61 additions and 88 deletions

View File

@@ -12,6 +12,7 @@ from __future__ import print_function
import re import re
import sys import sys
from math import log
import collections import collections
import gettext import gettext
import os.path import os.path
@@ -111,9 +112,10 @@ class TextSplitter:
def iter_words(self, text): def iter_words(self, text):
s = [] s = []
words = self.words
for m in self.pat.finditer(text): for m in self.pat.finditer(text):
t = m.group(0) t = m.group(0)
if t in self.words: if t in words:
if s: if s:
yield (False, "".join(s)) yield (False, "".join(s))
s = [] s = []
@@ -124,33 +126,35 @@ class TextSplitter:
yield (False, "".join(s)) yield (False, "".join(s))
def iter(self, text): def iter(self, text):
s = []
for m in self.pat.finditer(text): for m in self.pat.finditer(text):
yield m.group(0) yield m.group(0)
def iter_substrings(s, minlen, maxlen): def iter_substrings(s, minlen, maxlen):
maxlen = min(len(s), maxlen) len_s = len(s)
maxlen = min(len_s, maxlen)
for n in range(minlen, maxlen + 1): for n in range(minlen, maxlen + 1):
for begin in range(0, len(s) - n + 1): for begin in range(0, len_s - n + 1):
yield s[begin : begin + n] yield s[begin : begin + n]
def compute_huffman_coding(translations, compression_filename): def compute_huffman_coding(translations, compression_filename):
texts = [t[1] for t in translations] texts = [t[1] for t in translations]
all_strings_concat = "".join(texts)
words = [] words = []
start_unused = 0x80
end_unused = 0xff
max_ord = 0 max_ord = 0
begin_unused = 128
end_unused = 256
for text in texts: for text in texts:
for c in text: for c in text:
ord_c = ord(c) ord_c = ord(c)
max_ord = max(max_ord, ord_c) max_ord = max(ord_c, max_ord)
if 128 <= ord_c < 256: if 0x80 <= ord_c < 0xff:
end_unused = min(ord_c, end_unused) end_unused = min(ord_c, end_unused)
max_words = end_unused - begin_unused max_words = end_unused - 0x80
char_size = 1 if max_ord < 256 else 2
sum_word_len = 0 values_type = "uint16_t" if max_ord > 255 else "uint8_t"
max_words_len = 160 if max_ord > 255 else 255
sum_len = 0
while True: while True:
extractor = TextSplitter(words) extractor = TextSplitter(words)
counter = collections.Counter() counter = collections.Counter()
@@ -162,30 +166,30 @@ def compute_huffman_coding(translations, compression_filename):
scores = sorted( scores = sorted(
( (
# I don't know why this works good. This could be better. (s, (len(s) - 1) ** log(max(occ - 2, 1)), occ)
(s, (len(s) - 1) ** ((max(occ - 2, 1) + 0.5) ** 0.8), occ)
for (s, occ) in counter.items() for (s, occ) in counter.items()
), ),
key=lambda x: x[1], key=lambda x: x[1],
reverse=True, reverse=True,
) )
w = None word = None
for (s, score, occ) in scores: for (s, score, occ) in scores:
if score < 0: if occ < 5:
continue
if score < 5:
break break
if len(s) > 1: word = s
w = s
break break
if not w: if not word:
break break
if len(w) + sum_word_len > 256: if sum_len + len(word) - 2 > max_words_len:
break break
if len(words) == max_words: if len(words) == max_words:
break break
words.append(w) words.append(word)
sum_word_len += len(w) sum_len += len(word) - 2
extractor = TextSplitter(words) extractor = TextSplitter(words)
counter = collections.Counter() counter = collections.Counter()
@@ -194,7 +198,7 @@ def compute_huffman_coding(translations, compression_filename):
counter[atom] += 1 counter[atom] += 1
cb = huffman.codebook(counter.items()) cb = huffman.codebook(counter.items())
word_start = begin_unused word_start = start_unused
word_end = word_start + len(words) - 1 word_end = word_start + len(words) - 1
print("// # words", len(words)) print("// # words", len(words))
print("// words", words) print("// words", words)
@@ -202,17 +206,17 @@ def compute_huffman_coding(translations, compression_filename):
values = [] values = []
length_count = {} length_count = {}
renumbered = 0 renumbered = 0
last_l = None last_length = None
canonical = {} canonical = {}
for atom, code in sorted(cb.items(), key=lambda x: (len(x[1]), x[0])): for atom, code in sorted(cb.items(), key=lambda x: (len(x[1]), x[0])):
values.append(atom) values.append(atom)
l = len(code) length = len(code)
if l not in length_count: if length not in length_count:
length_count[l] = 0 length_count[length] = 0
length_count[l] += 1 length_count[length] += 1
if last_l: if last_length:
renumbered <<= (l - last_l) renumbered <<= (length - last_length)
canonical[atom] = '{0:0{width}b}'.format(renumbered, width=l) canonical[atom] = '{0:0{width}b}'.format(renumbered, width=length)
# print(f"atom={repr(atom)} code={code}", file=sys.stderr) # print(f"atom={repr(atom)} code={code}", file=sys.stderr)
if len(atom) > 1: if len(atom) > 1:
o = words.index(atom) + 0x80 o = words.index(atom) + 0x80
@@ -222,34 +226,37 @@ def compute_huffman_coding(translations, compression_filename):
o = ord(atom) o = ord(atom)
print("//", o, s, counter[atom], canonical[atom], renumbered) print("//", o, s, counter[atom], canonical[atom], renumbered)
renumbered += 1 renumbered += 1
last_l = l last_length = length
lengths = bytearray() lengths = bytearray()
print("// length count", length_count) print("// length count", length_count)
for i in range(1, max(length_count) + 2): for i in range(1, max(length_count) + 2):
lengths.append(length_count.get(i, 0)) lengths.append(length_count.get(i, 0))
print("// values", values, "lengths", len(lengths), lengths) print("// values", values, "lengths", len(lengths), lengths)
maxord = max(ord(u) for u in values if len(u) == 1)
values_type = "uint16_t" if maxord > 255 else "uint8_t"
ch_size = 1 if maxord > 255 else 2
print("//", values, lengths) print("//", values, lengths)
values = [(atom if len(atom) == 1 else chr(0x80 + words.index(atom))) for atom in values] values = [(atom if len(atom) == 1 else chr(0x80 + words.index(atom))) for atom in values]
print("//", values, lengths) print("//", values, lengths)
max_translation_encoded_length = max(len(translation.encode("utf-8")) for original,translation in translations) max_translation_encoded_length = max(
len(translation.encode("utf-8")) for (original, translation) in translations)
wends = list(len(w) - 2 for w in words)
for i in range(1, len(wends)):
wends[i] += wends[i - 1]
with open(compression_filename, "w") as f: with open(compression_filename, "w") as f:
f.write("const uint8_t lengths[] = {{ {} }};\n".format(", ".join(map(str, lengths)))) f.write("const uint8_t lengths[] = {{ {} }};\n".format(", ".join(map(str, lengths))))
f.write("const {} values[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(u)) for u in values))) f.write("const {} values[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(u)) for u in values)))
f.write("#define compress_max_length_bits ({})\n".format(max_translation_encoded_length.bit_length())) f.write("#define compress_max_length_bits ({})\n".format(max_translation_encoded_length.bit_length()))
f.write("const {} words[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(c)) for w in words for c in w))) f.write("const {} words[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(c)) for w in words for c in w)))
f.write("const uint8_t wlen[] = {{ {} }};\n".format(", ".join(str(len(w)) for w in words))) f.write("const uint8_t wends[] = {{ {} }};\n".format(", ".join(str(p) for p in wends)))
f.write("#define word_start {}\n".format(word_start)) f.write("#define word_start {}\n".format(word_start))
f.write("#define word_end {}\n".format(word_end)) f.write("#define word_end {}\n".format(word_end))
extractor = TextSplitter(words) return (values, lengths, words, canonical, extractor)
return values, lengths, words, extractor
def decompress(encoding_table, encoded, encoded_length_bits): def decompress(encoding_table, encoded, encoded_length_bits):
values, lengths, words, extractor = encoding_table (values, lengths, words, _, _) = encoding_table
dec = [] dec = []
this_byte = 0 this_byte = 0
this_bit = 7 this_bit = 7
@@ -306,66 +313,32 @@ def decompress(encoding_table, encoded, encoded_length_bits):
def compress(encoding_table, decompressed, encoded_length_bits, len_translation_encoded): def compress(encoding_table, decompressed, encoded_length_bits, len_translation_encoded):
if not isinstance(decompressed, str): if not isinstance(decompressed, str):
raise TypeError() raise TypeError()
values, lengths, words, extractor = encoding_table (_, _, _, canonical, extractor) = encoding_table
enc = bytearray(len(decompressed) * 3) enc = bytearray(len(decompressed) * 3)
#print(decompressed)
#print(lengths)
current_bit = 7 current_bit = 7
current_byte = 0 current_byte = 0
code = len_translation_encoded
bits = encoded_length_bits + 1 bits = encoded_length_bits + 1
for i in range(bits - 1, 0, -1): for i in range(bits - 1, 0, -1):
if len_translation_encoded & (1 << (i - 1)): if len_translation_encoded & (1 << (i - 1)):
enc[current_byte] |= 1 << current_bit enc[current_byte] |= 1 << current_bit
if current_bit == 0: if current_bit == 0:
current_bit = 7 current_bit = 7
#print("packed {0:0{width}b}".format(enc[current_byte], width=8))
current_byte += 1 current_byte += 1
else: else:
current_bit -= 1 current_bit -= 1
#print("values = ", values, file=sys.stderr)
for atom in extractor.iter(decompressed): for atom in extractor.iter(decompressed):
#print("", file=sys.stderr) for b in canonical[atom]:
if len(atom) > 1: if b == "1":
c = chr(0x80 + words.index(atom))
else:
c = atom
assert c in values
start = 0
end = lengths[0]
bits = 1
compressed = None
code = 0
while compressed is None:
s = start
e = end
#print("{0:0{width}b}".format(code, width=bits))
# Linear search!
for i in range(s, e):
if values[i] == c:
compressed = code + (i - start)
#print("found {0:0{width}b}".format(compressed, width=bits), file=sys.stderr)
break
code += end - start
code <<= 1
start = end
end += lengths[bits]
bits += 1
#print("next bit", bits)
for i in range(bits - 1, 0, -1):
if compressed & (1 << (i - 1)):
enc[current_byte] |= 1 << current_bit enc[current_byte] |= 1 << current_bit
if current_bit == 0: if current_bit == 0:
current_bit = 7 current_bit = 7
#print("packed {0:0{width}b}".format(enc[current_byte], width=8))
current_byte += 1 current_byte += 1
else: else:
current_bit -= 1 current_bit -= 1
if current_bit != 7: if current_bit != 7:
current_byte += 1 current_byte += 1
return enc[:current_byte] return enc[:current_byte]

View File

@@ -48,17 +48,17 @@ STATIC int put_utf8(char *buf, int u) {
*buf = u; *buf = u;
return 1; return 1;
} else if(word_start <= u && u <= word_end) { } else if(word_start <= u && u <= word_end) {
int n = (u - 0x80); uint n = (u - word_start);
size_t off = 0; size_t pos = 0;
for(int i=0; i<n; i++) { if (n > 0) {
off += wlen[i]; pos = wends[n - 1] + (n * 2);
} }
int ret = 0; int ret = 0;
// note that at present, entries in the words table are // note that at present, entries in the words table are
// guaranteed not to represent words themselves, so this adds // guaranteed not to represent words themselves, so this adds
// at most 1 level of recursive call // at most 1 level of recursive call
for(int i=0; i<wlen[n]; i++) { for(; pos < wends[n] + (n + 1) * 2; pos++) {
int len = put_utf8(buf, words[off+i]); int len = put_utf8(buf, words[pos]);
buf += len; buf += len;
ret += len; ret += len;
} }