# Driver for Mboot, the MicroPython boot loader
# MIT license; Copyright (c) 2018 Damien P. George

import struct, time, os, hashlib

# Command codes; each is sent as the first byte of an I2C write to the device.
I2C_CMD_ECHO = 1
I2C_CMD_GETID = 2
I2C_CMD_GETCAPS = 3
I2C_CMD_RESET = 4
I2C_CMD_CONFIG = 5
I2C_CMD_GETLAYOUT = 6
I2C_CMD_MASSERASE = 7
I2C_CMD_PAGEERASE = 8
I2C_CMD_SETRDADDR = 9
I2C_CMD_SETWRADDR = 10
I2C_CMD_READ = 11
I2C_CMD_WRITE = 12
I2C_CMD_COPY = 13
I2C_CMD_CALCHASH = 14
I2C_CMD_MARKVALID = 15


class Bootloader:
    """Host-side driver for an mboot device on an I2C bus.

    Every command is a single I2C write whose first byte is one of the
    ``I2C_CMD_*`` codes, followed by any little-endian packed arguments.
    Replies are collected with :meth:`wait_response`.
    """

    def __init__(self, i2c, addr):
        """Bind to the device at I2C address *addr* on bus *i2c*.

        An empty write is sent as a probe.

        Raises:
            Exception: if the probe write fails with OSError.
        """
        self.i2c = i2c
        self.addr = addr
        self.buf1 = bytearray(1)  # reusable buffer for the 1-byte length prefix
        try:
            self.i2c.writeto(addr, b'')
        except OSError:
            raise Exception('no I2C mboot device found')

    def wait_response(self):
        """Poll the device until it answers, then return the reply payload.

        A single byte is read repeatedly; an OSError is treated as "not ready
        yet" and retried every 500 us. The byte that is finally read is the
        length of the payload that follows.

        Returns:
            bytes: the payload, or ``b''`` when the length byte is 0.

        Raises:
            Exception('timeout'): if no read succeeds within about 5 seconds.
            Exception(n): if the length byte *n* is 129 or more, which this
                driver treats as an error status from the device.
        """
        start = time.ticks_ms()
        while 1:
            try:
                self.i2c.readfrom_into(self.addr, self.buf1)
                n = self.buf1[0]
                break
            except OSError:
                time.sleep_us(500)
                if time.ticks_diff(time.ticks_ms(), start) > 5000:
                    raise Exception('timeout')
        if n >= 129:
            raise Exception(n)
        if n == 0:
            return b''
        else:
            return self.i2c.readfrom(self.addr, n)

    def wait_empty_response(self):
        """Wait for a reply and require it to be empty.

        Raises:
            Exception: if the device returned a non-empty payload.
        """
        ret = self.wait_response()
        if ret:
            raise Exception('expected empty response got %r' % ret)
        else:
            return None

    def echo(self, data):
        """Send *data* with the ECHO command and return the device's reply."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_ECHO) + data)
        return self.wait_response()

    def getid(self):
        """Return ``(unique_id, mcu_name, board_name)``.

        The reply is split as follows:
        - The first 12 bytes are the unique id.
        - The remainder is split on a single NUL byte into the MCU name and
          the board name, each decoded as ASCII.
        """
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_GETID))
        ret = self.wait_response()
        unique_id = ret[:12]
        mcu_name, board_name = ret[12:].split(b'\x00')
        return unique_id, str(mcu_name, 'ascii'), str(board_name, 'ascii')

    def reset(self):
        """Send the RESET command (no reply is awaited)."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_RESET))
        # we don't expect any response

    def getlayout(self):
        """Query the flash layout and expand it into a list of pages.

        The reply has the form ``<id>/<start-hex>/<N>*<size>Kg,...``.
        - The id must equal ``b'@Internal Flash '``.
        - Each ``<N>*<size>Kg`` group describes N pages of <size> KiB.

        Returns:
            list of ``(address, size_in_bytes)`` tuples, one per page, laid
            out contiguously from the reported start address.
        """
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_GETLAYOUT))
        layout = self.wait_response()
        dev_id, flash_addr, layout = layout.split(b'/')
        assert dev_id == b'@Internal Flash '
        flash_addr = int(flash_addr, 16)
        pages = []
        for chunk in layout.split(b','):
            n, sz = chunk.split(b'*')
            n = int(n)
            assert sz.endswith(b'Kg')
            sz = int(sz[:-2]) * 1024
            for i in range(n):
                pages.append((flash_addr, sz))
                flash_addr += sz
        return pages

    def pageerase(self, addr):
        """Erase the flash page at *addr*."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_PAGEERASE, addr))
        self.wait_empty_response()

    def setrdaddr(self, addr):
        """Set the device-side read address used by READ/CALCHASH."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_SETRDADDR, addr))
        self.wait_empty_response()

    def setwraddr(self, addr):
        """Set the device-side write address used by WRITE."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_SETWRADDR, addr))
        self.wait_empty_response()

    def read(self, n):
        """Read *n* bytes (packed as one byte, so 0..255) from the read address."""
        self.i2c.writeto(self.addr, struct.pack('<BB', I2C_CMD_READ, n))
        return self.wait_response()

    def write(self, buf):
        """Write *buf* at the current device-side write address."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_WRITE) + buf)
        self.wait_empty_response()

    def calchash(self, n):
        """Ask the device to hash *n* bytes from the read address; return the digest."""
        self.i2c.writeto(self.addr, struct.pack('<BI', I2C_CMD_CALCHASH, n))
        return self.wait_response()

    def markvalid(self):
        """Send the MARKVALID command."""
        self.i2c.writeto(self.addr, struct.pack('<B', I2C_CMD_MARKVALID))
        self.wait_empty_response()

    def deployfile(self, filename, addr):
        """Stream *filename* into device flash at *addr*, verify it, then reset.

        Steps:
        1. The file is sent in 128-byte chunks.
        2. Each page from :meth:`getlayout` is erased the first time a chunk
           starts inside it.
        3. A SHA256 over the written length is requested from the device and
           compared with a locally computed one.
        4. On a match the firmware is marked valid.
        5. The device is reset in either case.

        Raises:
            Exception: if a chunk starts outside every reported page, or on
                any device error or timeout.
        """
        pages = self.getlayout()
        page_erased = [False] * len(pages)
        buf = bytearray(128)  # maximum payload supported by I2C protocol
        start_addr = addr
        self.setwraddr(addr)
        fsize = os.stat(filename)[6]
        local_sha = hashlib.sha256()
        ntotal = 0  # stays defined (0) when the file is empty
        print('Deploying %s to location 0x%08x' % (filename, addr))
        with open(filename, 'rb') as f:
            t0 = time.ticks_ms()
            while True:
                n = f.readinto(buf)
                if n == 0:
                    break

                # A short final read leaves bytes of the previous chunk in the
                # tail of buf. Overwrite them with 0xff so stale data is not
                # programmed past the end of the image. The full buffer is
                # still sent, so the write length seen by the device is unchanged.
                for j in range(n, len(buf)):
                    buf[j] = 0xff

                # check if we need to erase the page
                for i, p in enumerate(pages):
                    if p[0] <= addr < p[0] + p[1]:
                        # found page
                        if not page_erased[i]:
                            print('\r% 3u%% erase 0x%08x' % (100 * (addr - start_addr) // fsize, addr), end='')
                            self.pageerase(addr)
                            page_erased[i] = True
                        break
                else:
                    raise Exception('address 0x%08x not valid' % addr)

                # write the data
                self.write(buf)

                # update local SHA256, with validity bits set
                # NOTE(review): the local hash has the low 2 bits of the first
                # byte forced on. Presumably the device-side hash reflects the
                # same bits; that is not visible from this file.
                if addr == start_addr:
                    buf[0] |= 3
                if n == len(buf):
                    local_sha.update(buf)
                else:
                    local_sha.update(buf[:n])

                addr += n
                ntotal = addr - start_addr
                if ntotal % 2048 == 0 or ntotal == fsize:
                    print('\r% 3u%% % 7u bytes ' % (100 * ntotal // fsize, ntotal), end='')
            t1 = time.ticks_ms()
        print()
        # ticks_diff handles wrap-around of the tick counter. Clamp to 1 ms so
        # a very fast transfer cannot divide by zero.
        dt_ms = max(time.ticks_diff(t1, t0), 1)
        # KiB/sec = (bytes / 1024) / (ms / 1000)
        print('rate: %.2f KiB/sec' % (1000 * ntotal / dt_ms / 1024))

        local_sha = local_sha.digest()
        print('Local SHA256: ', ''.join('%02x' % x for x in local_sha))

        self.setrdaddr(start_addr)
        remote_sha = self.calchash(ntotal)
        print('Remote SHA256:', ''.join('%02x' % x for x in remote_sha))

        if local_sha == remote_sha:
            print('Marking app firmware as valid')
            self.markvalid()

        self.reset()