tools/mpy-tool.py: Rework .mpy merging feature.

Now that the native qstr link table is gone, merging a native .mpy file
with a bytecode .mpy file is not as simple as concatenating the .mpy data.
The qstr_table and obj_table tables from all merged .mpy files must now be
joined together, because they are global to the .mpy file (and hence global
to the merged .mpy file).  This means the bytecode needs to be decoded,
qstr_table and obj_table indices updated to point to the correct entries in
the new tables, and then the bytecode re-encoded.

This commit makes this change to the merging feature in mpy-tool.py.  This
can now merge an arbitrary number of bytecode .mpy files, and up to one
native .mpy file.

Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
Damien George 2022-05-31 00:17:38 +10:00
parent f506bf342a
commit 599a22e569
1 changed files with 252 additions and 67 deletions

View File

@ -182,7 +182,7 @@ mp_binary_op_method_name = (
)
class Opcodes:
class Opcode:
# fmt: off
# Load, Store, Delete, Import, Make, Build, Unpack, Call, Jump, Exception, For, sTack, Return, Yield, Op
MP_BC_BASE_RESERVED = (0x00) # ----------------
@ -318,6 +318,13 @@ class Opcodes:
for i in range(MP_BC_BINARY_OP_MULTI_NUM):
mapping[MP_BC_BINARY_OP_MULTI + i] = "BINARY_OP %d %s" % (i, mp_binary_op_method_name[i])
def __init__(self, offset, fmt, opcode_byte, arg, extra_arg):
self.offset = offset
self.fmt = fmt
self.opcode_byte = opcode_byte
self.arg = arg
self.extra_arg = extra_arg
# This definition of a small int covers all possible targets, in the sense that every
# target can encode as a small int, an integer that passes this test. The minimum is set
@ -326,15 +333,31 @@ def mp_small_int_fits(i):
return -0x2000 <= i <= 0x1FFF
def mp_encode_uint(val, signed=False):
    """Encode an integer as a variable-length sequence of 7-bit groups.

    The groups are emitted most-significant first.  Every byte except the
    last has bit 7 (0x80) set as a continuation flag; the last byte has it
    clear.

    Args:
        val: The integer to encode.
        signed: When true, bit 6 (0x40) of the first emitted group is treated
            as a sign bit, and an extra leading group is added whenever that
            bit would otherwise disagree with the sign of ``val``.

    Returns:
        A ``bytearray`` holding the encoded value.
    """
    # Least-significant group goes last and carries no continuation flag.
    encoded = bytearray([val & 0x7F])
    val >>= 7
    # An arithmetic right shift converges to 0 for non-negative values and to
    # -1 for negative ones; either means there are no more significant groups.
    while val != 0 and val != -1:
        encoded.insert(0, 0x80 | (val & 0x7F))
        val >>= 7
    if signed:
        # Make sure the leading group's bit 6 matches the sign of the value.
        if val == -1 and encoded[0] & 0x40 == 0:
            # Negative value whose leading group looks positive: prepend a
            # group of all ones (with the continuation flag).
            encoded.insert(0, 0xFF)
        elif val == 0 and encoded[0] & 0x40 != 0:
            # Non-negative value whose leading group looks negative: prepend
            # an all-zero group (with the continuation flag).
            encoded.insert(0, 0x80)
    return encoded
def mp_opcode_decode(bytecode, ip):
opcode = bytecode[ip]
ip_start = ip
f = (0x000003A4 >> (2 * ((opcode) >> 4))) & 3
extra_byte = (opcode & MP_BC_MASK_EXTRA_BYTE) == 0
ip += 1
arg = 0
arg = None
extra_arg = None
if f in (MP_BC_FORMAT_QSTR, MP_BC_FORMAT_VAR_UINT):
arg = bytecode[ip] & 0x7F
if opcode == Opcode.MP_BC_LOAD_CONST_SMALL_INT and arg & 0x40 != 0:
arg |= -1 << 7
while bytecode[ip] & 0x80 != 0:
ip += 1
arg = arg << 7 | bytecode[ip] & 0x7F
@ -343,15 +366,50 @@ def mp_opcode_decode(bytecode, ip):
if bytecode[ip] & 0x80 == 0:
arg = bytecode[ip]
ip += 1
if opcode in Opcodes.ALL_OFFSET_SIGNED:
if opcode in Opcode.ALL_OFFSET_SIGNED:
arg -= 0x40
else:
arg = bytecode[ip] & 0x7F | bytecode[ip + 1] << 7
ip += 2
if opcode in Opcodes.ALL_OFFSET_SIGNED:
if opcode in Opcode.ALL_OFFSET_SIGNED:
arg -= 0x4000
ip += extra_byte
return f, ip - ip_start, arg
if opcode & MP_BC_MASK_EXTRA_BYTE == 0:
extra_arg = bytecode[ip]
ip += 1
return f, ip - ip_start, arg, extra_arg
def mp_opcode_encode(opcode):
    """Serialise a decoded opcode back into its bytecode bytes.

    Args:
        opcode: An object exposing ``opcode_byte``, ``fmt``, ``arg``,
            ``extra_arg`` and ``offset``.  Opcodes in the jump-offset format
            must additionally expose ``target``, the opcode object being
            jumped to (its ``offset`` is read to compute the displacement).

    Returns:
        A ``(overflow, encoded)`` tuple.  ``encoded`` is a ``bytearray`` with
        the opcode byte, its argument (if any) and its extra byte (if any).
        ``overflow`` is true when a jump displacement does not fit the range
        of the encoding that was selected for it.
    """
    overflow = False
    encoded = bytearray([opcode.opcode_byte])

    if opcode.fmt in (MP_BC_FORMAT_QSTR, MP_BC_FORMAT_VAR_UINT):
        # Only the small-int load carries a signed variable-length argument.
        signed = opcode.opcode_byte == Opcode.MP_BC_LOAD_CONST_SMALL_INT
        encoded.extend(mp_encode_uint(opcode.arg, signed))
    elif opcode.fmt == MP_BC_FORMAT_OFFSET:
        is_signed = opcode.opcode_byte in Opcode.ALL_OFFSET_SIGNED

        # The -2 accounts for this jump opcode taking 2 bytes (at least).
        bytecode_offset = opcode.target.offset - opcode.offset - 2

        # Check if the bytecode_offset is small enough to use a 1-byte encoding.
        if (is_signed and -64 <= bytecode_offset <= 63) or (
            not is_signed and bytecode_offset <= 127
        ):
            # Use a 1-byte jump offset.
            if is_signed:
                # Signed offsets are stored biased so the byte is non-negative.
                bytecode_offset += 0x40
            overflow = not (0 <= bytecode_offset <= 0x7F)
            encoded.append(bytecode_offset & 0x7F)
        else:
            # Use a 2-byte jump offset.  The opcode is then one byte longer,
            # so the displacement measured from its end shrinks by one.
            bytecode_offset -= 1
            if is_signed:
                bytecode_offset += 0x4000
            overflow = not (0 <= bytecode_offset <= 0x7FFF)
            # Low 7 bits first, flagged with 0x80; remaining bits follow.
            encoded.append(0x80 | (bytecode_offset & 0x7F))
            encoded.append((bytecode_offset >> 7) & 0xFF)

    # The extra byte, when present, trails the argument for every format.
    if opcode.extra_arg is not None:
        encoded.append(opcode.extra_arg)

    return overflow, encoded
def read_prelude_sig(read_byte):
@ -393,6 +451,21 @@ def read_prelude_size(read_byte):
return I, C
# See py/bc.h:MP_BC_PRELUDE_SIZE_ENCODE macro.
def encode_prelude_size(I, C):
    """Encode the two prelude-size fields into a variable-length byte string.

    Each output byte is laid out bit-wise as ``xIIIIIIC``: six bits of ``I``,
    one bit of ``C``, and a top continuation bit ``x`` that is set while any
    bits of either value remain to be emitted.  Low-order bits come first.

    Args:
        I: Non-negative integer (the ``n_info`` value of the prelude).
        C: Non-negative integer (the ``n_cell`` value of the prelude).

    Returns:
        A ``bytearray`` with the encoded sizes (always at least one byte).

    Raises:
        ValueError: If ``I`` or ``C`` is negative.  A negative value never
            shifts down to zero, so encoding it would never terminate.
    """
    if I < 0 or C < 0:
        raise ValueError("prelude sizes must be non-negative")

    # Encode bit-wise as: xIIIIIIC
    encoded = bytearray()
    while True:
        z = (I & 0x3F) << 1 | (C & 1)
        C >>= 1
        I >>= 6
        more = bool(C | I)
        if more:
            # More bits of I and/or C remain: flag a continuation byte.
            z |= 0x80
        encoded.append(z)
        if not more:
            return encoded
def extract_prelude(bytecode, ip):
def local_read_byte():
b = bytecode[ip_ref[0]]
@ -400,6 +473,8 @@ def extract_prelude(bytecode, ip):
return b
ip_ref = [ip] # to close over ip in Python 2 and 3
# Read prelude signature.
(
n_state,
n_exc_stack,
@ -409,13 +484,12 @@ def extract_prelude(bytecode, ip):
n_def_pos_args,
) = read_prelude_sig(local_read_byte)
n_info, n_cell = read_prelude_size(local_read_byte)
ip = ip_ref[0]
offset_prelude_size = ip_ref[0]
ip2 = ip
ip = ip2 + n_info + n_cell
# ip now points to first opcode
# ip2 points to simple_name qstr
# Read prelude size.
n_info, n_cell = read_prelude_size(local_read_byte)
offset_source_info = ip_ref[0]
# Extract simple_name and argument qstrs (var uints).
args = []
@ -428,11 +502,18 @@ def extract_prelude(bytecode, ip):
break
args.append(value)
offset_line_info = ip_ref[0]
offset_closure_info = offset_source_info + n_info
offset_opcodes = offset_source_info + n_info + n_cell
return (
ip2,
ip,
ip_ref[0],
offset_prelude_size,
offset_source_info,
offset_line_info,
offset_closure_info,
offset_opcodes,
(n_state, n_exc_stack, scope_flags, n_pos_args, n_kwonly_args, n_def_pos_args),
(n_info, n_cell),
args,
)
@ -480,6 +561,8 @@ class CompiledModule:
qstr_table,
obj_table,
raw_code,
qstr_table_file_offset,
obj_table_file_offset,
raw_code_file_offset,
escaped_name,
):
@ -489,8 +572,10 @@ class CompiledModule:
self.header = header
self.qstr_table = qstr_table
self.obj_table = obj_table
self.raw_code_file_offset = raw_code_file_offset
self.raw_code = raw_code
self.qstr_table_file_offset = qstr_table_file_offset
self.obj_table_file_offset = obj_table_file_offset
self.raw_code_file_offset = raw_code_file_offset
self.escaped_name = escaped_name
def hexdump(self):
@ -772,14 +857,17 @@ class RawCode(object):
if code_kind in (MP_CODE_BYTECODE, MP_CODE_NATIVE_PY):
(
self.offset_names,
self.offset_opcodes,
self.offset_prelude_size,
self.offset_source_info,
self.offset_line_info,
self.prelude,
self.offset_closure_info,
self.offset_opcodes,
self.prelude_signature,
self.prelude_size,
self.names,
) = extract_prelude(self.fun_data, prelude_offset)
self.scope_flags = self.prelude[2]
self.n_pos_args = self.prelude[3]
self.scope_flags = self.prelude_signature[2]
self.n_pos_args = self.prelude_signature[3]
self.simple_name = self.qstr_table[self.names[0]]
else:
self.simple_name = self.qstr_table[0]
@ -836,12 +924,12 @@ class RawCode(object):
if self.code_kind == MP_CODE_BYTECODE:
print(" #if MICROPY_PY_SYS_SETTRACE")
print(" .prelude = {")
print(" .n_state = %u," % self.prelude[0])
print(" .n_exc_stack = %u," % self.prelude[1])
print(" .scope_flags = %u," % self.prelude[2])
print(" .n_pos_args = %u," % self.prelude[3])
print(" .n_kwonly_args = %u," % self.prelude[4])
print(" .n_def_pos_args = %u," % self.prelude[5])
print(" .n_state = %u," % self.prelude_signature[0])
print(" .n_exc_stack = %u," % self.prelude_signature[1])
print(" .scope_flags = %u," % self.prelude_signature[2])
print(" .n_pos_args = %u," % self.prelude_signature[3])
print(" .n_kwonly_args = %u," % self.prelude_signature[4])
print(" .n_def_pos_args = %u," % self.prelude_signature[5])
print(" .qstr_block_name_idx = %u," % self.names[0])
print(
" .line_info = fun_data_%s + %u,"
@ -878,13 +966,13 @@ class RawCodeBytecode(RawCode):
bc = self.fun_data
print("simple_name:", self.simple_name.str)
print(" raw bytecode:", len(bc), hexlify_to_str(bc))
print(" prelude:", self.prelude)
print(" prelude:", self.prelude_signature)
print(" args:", [self.qstr_table[i].str for i in self.names[1:]])
print(" line info:", hexlify_to_str(bc[self.offset_line_info : self.offset_opcodes]))
ip = self.offset_opcodes
while ip < len(bc):
fmt, sz, arg = mp_opcode_decode(bc, ip)
if bc[ip] == Opcodes.MP_BC_LOAD_CONST_OBJ:
fmt, sz, arg, _ = mp_opcode_decode(bc, ip)
if bc[ip] == Opcode.MP_BC_LOAD_CONST_OBJ:
arg = repr(self.obj_table[arg])
if fmt == MP_BC_FORMAT_QSTR:
arg = self.qstr_table[arg].str
@ -893,7 +981,7 @@ class RawCodeBytecode(RawCode):
else:
arg = ""
print(
" %-11s %s %s" % (hexlify_to_str(bc[ip : ip + sz]), Opcodes.mapping[bc[ip]], arg)
" %-11s %s %s" % (hexlify_to_str(bc[ip : ip + sz]), Opcode.mapping[bc[ip]], arg)
)
ip += sz
self.disassemble_children()
@ -908,12 +996,12 @@ class RawCodeBytecode(RawCode):
print("static const byte fun_data_%s[%u] = {" % (self.escaped_name, len(bc)))
print(" ", end="")
for b in bc[: self.offset_names]:
for b in bc[: self.offset_source_info]:
print("0x%02x," % b, end="")
print(" // prelude")
print(" ", end="")
for b in bc[self.offset_names : self.offset_line_info]:
for b in bc[self.offset_source_info : self.offset_line_info]:
print("0x%02x," % b, end="")
print(" // names: %s" % ", ".join(self.qstr_table[i].str for i in self.names))
@ -924,8 +1012,8 @@ class RawCodeBytecode(RawCode):
ip = self.offset_opcodes
while ip < len(bc):
fmt, sz, arg = mp_opcode_decode(bc, ip)
opcode_name = Opcodes.mapping[bc[ip]]
fmt, sz, arg, _ = mp_opcode_decode(bc, ip)
opcode_name = Opcode.mapping[bc[ip]]
if fmt == MP_BC_FORMAT_QSTR:
opcode_name += " " + repr(self.qstr_table[arg].str)
elif fmt in (MP_BC_FORMAT_VAR_UINT, MP_BC_FORMAT_OFFSET):
@ -1000,7 +1088,7 @@ class RawCodeNative(RawCode):
)
if self.code_kind != MP_CODE_NATIVE_PY:
return
print(" prelude:", self.prelude)
print(" prelude:", self.prelude_signature)
print(" args:", [self.qstr_table[i].str for i in self.names[1:]])
print(" line info:", fun_data[self.offset_line_info : self.offset_opcodes])
ip = 0
@ -1255,11 +1343,13 @@ def read_mpy(filename):
n_obj = reader.read_uint()
# Read qstrs and construct qstr table.
qstr_table_file_offset = reader.tell()
qstr_table = []
for i in range(n_qstr):
qstr_table.append(read_qstr(reader, segments))
# Read objects and construct object table.
obj_table_file_offset = reader.tell()
obj_table = []
for i in range(n_obj):
obj_table.append(read_obj(reader, segments))
@ -1279,6 +1369,8 @@ def read_mpy(filename):
qstr_table,
obj_table,
raw_code,
qstr_table_file_offset,
obj_table_file_offset,
raw_code_file_offset,
cm_escaped_name,
)
@ -1477,25 +1569,100 @@ def freeze_mpy(base_qstrs, compiled_modules):
print("*/")
def merge_mpy(raw_codes, output_file):
assert len(raw_codes) <= 2 # so var-uints all fit in 1 byte
def adjust_bytecode_qstr_obj_indices(bytecode_in, qstr_table_base, obj_table_base):
# Re-encode a bytecode stream so that its qstr-table and object-table indices
# are shifted by qstr_table_base and obj_table_base respectively.  Returns the
# new bytecode; raises Exception("bytecode overflow") if a jump offset cannot
# be encoded.
# NOTE(review): indentation is not preserved in this view, so statement
# nesting is inferred from the logic; in particular, confirm whether the final
# overflow check sits inside or after the `while offset_changed` loop.
# Expand bytecode to a list of opcodes.
opcodes = []
# Maps the original byte offset of each opcode to its Opcode object.
labels = {}
ip = 0
while ip < len(bytecode_in):
fmt, sz, arg, extra_arg = mp_opcode_decode(bytecode_in, ip)
opcode = Opcode(ip, fmt, bytecode_in[ip], arg, extra_arg)
labels[ip] = opcode
opcodes.append(opcode)
ip += sz
if fmt == MP_BC_FORMAT_OFFSET:
# ip has already been advanced past this opcode, so adding it to the decoded
# arg yields an offset into bytecode_in that is later looked up in labels.
opcode.arg += ip
# Link jump opcodes to their destination.
for opcode in opcodes:
if opcode.fmt == MP_BC_FORMAT_OFFSET:
# A KeyError here would mean the jump does not land on an opcode boundary.
opcode.target = labels[opcode.arg]
# Adjust bytecode as required.
for opcode in opcodes:
if opcode.fmt == MP_BC_FORMAT_QSTR:
opcode.arg += qstr_table_base
elif opcode.opcode_byte == Opcode.MP_BC_LOAD_CONST_OBJ:
opcode.arg += obj_table_base
# Write out new bytecode.
# Re-encoding can change an opcode's size (and hence the offsets of the
# opcodes after it), so encoding is repeated until no opcode offset changes.
offset_changed = True
while offset_changed:
offset_changed = False
overflow = False
bytecode_out = b""
for opcode in opcodes:
ip = len(bytecode_out)
if opcode.offset != ip:
offset_changed = True
opcode.offset = ip
opcode_overflow, encoded_opcode = mp_opcode_encode(opcode)
if opcode_overflow:
overflow = True
bytecode_out += encoded_opcode
# overflow reflects jump offsets that did not fit their selected encoding.
if overflow:
raise Exception("bytecode overflow")
return bytecode_out
def rewrite_raw_code(rc, qstr_table_base, obj_table_base):
    """Serialise a bytecode raw-code (and its children) against shifted tables.

    Every qstr index in the prelude names and in the opcodes is offset by
    ``qstr_table_base``, and every constant-object index by
    ``obj_table_base``; the prelude and raw-code header are then rebuilt
    around the re-encoded bytecode.  Children are rewritten recursively with
    the same bases.

    Args:
        rc: The raw-code object to rewrite.  Its ``code_kind``, ``names``,
            ``fun_data``, ``children`` and ``offset_*`` attributes are read.
        qstr_table_base: Amount added to each qstr-table index.
        obj_table_base: Amount added to each object-table index.

    Returns:
        The encoded raw-code: a length/flags header, the function data and,
        if there are children, the child count followed by each encoded child.

    Raises:
        Exception: If ``rc`` is not plain bytecode.
    """
    if rc.code_kind != MP_CODE_BYTECODE:
        raise Exception("can only rewrite bytecode")

    # Rebuild the source info from the (re-based) simple name and arg qstrs.
    # NOTE(review): only rc.names is emitted here, so any bytes the original
    # had between offset_line_info and offset_closure_info are not carried
    # over -- confirm this is intended.
    source_info = bytearray()
    for arg in rc.names:
        source_info.extend(mp_encode_uint(qstr_table_base + arg))

    # Closure info is copied through unchanged.
    closure_info = rc.fun_data[rc.offset_closure_info : rc.offset_opcodes]

    bytecode_in = memoryview(rc.fun_data)[rc.offset_opcodes :]
    bytecode_out = adjust_bytecode_qstr_obj_indices(bytecode_in, qstr_table_base, obj_table_base)

    # The signature is reused as-is; the size field must be regenerated
    # because the source info length may have changed.
    prelude_signature = rc.fun_data[: rc.offset_prelude_size]
    prelude_size = encode_prelude_size(len(source_info), len(closure_info))

    fun_data = prelude_signature + prelude_size + source_info + closure_info + bytecode_out

    # Header: length in the bits above bit 2, has-children flag in bit 2.
    output = mp_encode_uint(len(fun_data) << 3 | bool(len(rc.children)) << 2)
    output += fun_data

    if rc.children:
        output += mp_encode_uint(len(rc.children))
        for child in rc.children:
            output += rewrite_raw_code(child, qstr_table_base, obj_table_base)

    return output
def merge_mpy(compiled_modules, output_file):
merged_mpy = bytearray()
if len(raw_codes) == 1:
with open(raw_codes[0].mpy_source_file, "rb") as f:
if len(compiled_modules) == 1:
with open(compiled_modules[0].mpy_source_file, "rb") as f:
merged_mpy.extend(f.read())
else:
main_rc = None
for rc in raw_codes:
if len(rc.qstr_table) > 1 or len(rc.obj_table) > 0:
main_cm_idx = None
for idx, cm in enumerate(compiled_modules):
if cm.header[2]:
# Must use qstr_table and obj_table from this raw_code
if main_rc is not None:
raise Exception(
"can't merge files when more than one has a populated qstr or obj table"
)
main_rc = rc
if main_rc is None:
main_rc = raw_codes[0]
if main_cm_idx is not None:
raise Exception("can't merge files when more than one contains native code")
main_cm_idx = idx
if main_cm_idx is not None:
# Shift main_cm to front of list.
compiled_modules.insert(0, compiled_modules.pop(main_cm_idx))
header = bytearray(4)
header[0] = ord("M")
@ -1504,32 +1671,50 @@ def merge_mpy(raw_codes, output_file):
header[3] = config.mp_small_int_bits
merged_mpy.extend(header)
# Copy n_qstr, n_obj, qstr_table, obj_table from main_rc.
with open(main_rc.mpy_source_file, "rb") as f:
data = f.read(main_rc.raw_code_file_offset)
merged_mpy.extend(data[4:])
n_qstr = 0
n_obj = 0
for cm in compiled_modules:
n_qstr += len(cm.qstr_table)
n_obj += len(cm.obj_table)
merged_mpy.extend(mp_encode_uint(n_qstr))
merged_mpy.extend(mp_encode_uint(n_obj))
# Copy verbatim the qstr and object tables from all compiled modules.
def copy_section(file, offset, offset2):
with open(file, "rb") as f:
f.seek(offset)
merged_mpy.extend(f.read(offset2 - offset))
for cm in compiled_modules:
copy_section(cm.mpy_source_file, cm.qstr_table_file_offset, cm.obj_table_file_offset)
for cm in compiled_modules:
copy_section(cm.mpy_source_file, cm.obj_table_file_offset, cm.raw_code_file_offset)
bytecode = bytearray()
bytecode_len = 3 + len(raw_codes) * 5 + 2
bytecode.append(bytecode_len << 3 | 1 << 2) # kind, has_children and length
bytecode.append(0b00000000) # signature prelude
bytecode.append(0b00000010) # size prelude; n_info=1
bytecode.append(0b00000000) # prelude signature
bytecode.append(0b00000010) # prelude size (n_info=1, n_cell=0)
bytecode.extend(b"\x00") # simple_name: qstr index 0 (will use source filename)
for idx in range(len(raw_codes)):
for idx in range(len(compiled_modules)):
bytecode.append(0x32) # MP_BC_MAKE_FUNCTION
bytecode.append(idx) # index raw code
bytecode.extend(b"\x34\x00\x59") # MP_BC_CALL_FUNCTION, 0 args, MP_BC_POP_TOP
bytecode.extend(b"\x51\x63") # MP_BC_LOAD_NONE, MP_BC_RETURN_VALUE
merged_mpy.extend(mp_encode_uint(len(bytecode) << 3 | 1 << 2)) # length, has_children
merged_mpy.extend(bytecode)
merged_mpy.extend(mp_encode_uint(len(compiled_modules))) # n_children
merged_mpy.append(len(raw_codes)) # n_children
for rc in raw_codes:
with open(rc.mpy_source_file, "rb") as f:
f.seek(rc.raw_code_file_offset)
data = f.read() # read rest of mpy file
merged_mpy.extend(data)
qstr_table_base = 0
obj_table_base = 0
for cm in compiled_modules:
if qstr_table_base == 0 and obj_table_base == 0:
with open(cm.mpy_source_file, "rb") as f:
f.seek(cm.raw_code_file_offset)
merged_mpy.extend(f.read())
else:
merged_mpy.extend(rewrite_raw_code(cm.raw_code, qstr_table_base, obj_table_base))
qstr_table_base += len(cm.qstr_table)
obj_table_base += len(cm.obj_table)
if output_file is None:
sys.stdout.buffer.write(merged_mpy)