diff --git a/ports/stm32/mboot/dfu.h b/ports/stm32/mboot/dfu.h index 1f53c1f069..e826e217f0 100644 --- a/ports/stm32/mboot/dfu.h +++ b/ports/stm32/mboot/dfu.h @@ -32,6 +32,13 @@ #define DFU_XFER_SIZE (2048) +// The DFU standard allows for error messages to be sent (via str index) in GETSTATUS responses +#define MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX 0x10 +#define MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER "Can't overwrite mboot" + +#define MBOOT_ERROR_STR_INVALID_ADDRESS_IDX 0x11 +#define MBOOT_ERROR_STR_INVALID_ADDRESS "Address out of range" + // DFU class requests enum { DFU_DETACH = 0, diff --git a/ports/stm32/mboot/main.c b/ports/stm32/mboot/main.c index 134e40a839..b24fa7daa5 100644 --- a/ports/stm32/mboot/main.c +++ b/ports/stm32/mboot/main.c @@ -526,6 +526,13 @@ static const flash_layout_t flash_layout[] = { #endif +static inline bool flash_is_valid_addr(uint32_t addr) { + uint8_t last = MP_ARRAY_SIZE(flash_layout) - 1; + uint32_t end_of_flash = flash_layout[last].base_address + + flash_layout[last].sector_count * flash_layout[last].sector_size; + return flash_layout[0].base_address <= addr && addr < end_of_flash; +} + static uint32_t flash_get_sector_index(uint32_t addr, uint32_t *sector_size) { if (addr >= flash_layout[0].base_address) { uint32_t sector_index = 0; @@ -575,6 +582,8 @@ static int flash_page_erase(uint32_t addr, uint32_t *next_addr) { uint32_t sector = flash_get_sector_index(addr, §or_size); if (sector == 0) { // Don't allow to erase the sector with this bootloader in it + dfu_context.status = DFU_STATUS_ERROR_ADDRESS; + dfu_context.error = MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX; return -1; } @@ -619,6 +628,8 @@ static int flash_page_erase(uint32_t addr, uint32_t *next_addr) { static int flash_write(uint32_t addr, const uint8_t *src8, size_t len) { if (addr >= flash_layout[0].base_address && addr < flash_layout[0].base_address + flash_layout[0].sector_size) { // Don't allow to write the sector with this bootloader in it + dfu_context.status = DFU_STATUS_ERROR_ADDRESS; + dfu_context.error = MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX; return -1; } @@ -732,7 +743,13 @@ int do_write(uint32_t addr, const uint8_t *src8, size_t len) { } #endif - return flash_write(addr, src8, len); + if (flash_is_valid_addr(addr)) { + return flash_write(addr, src8, len); + } + + dfu_context.status = DFU_STATUS_ERROR_ADDRESS; + dfu_context.error = MBOOT_ERROR_STR_INVALID_ADDRESS_IDX; + return -1; } /******************************************************************************/ @@ -909,6 +926,8 @@ uint8_t i2c_slave_process_tx_byte(void) { static void dfu_init(void) { dfu_context.state = DFU_STATE_IDLE; dfu_context.cmd = DFU_CMD_NONE; + dfu_context.status = DFU_STATUS_OK; + dfu_context.error = 0; dfu_context.addr = 0x08000000; } @@ -1158,6 +1177,14 @@ static uint8_t *pyb_usbdd_StrDescriptor(USBD_HandleTypeDef *pdev, uint8_t idx, u USBD_GetString((uint8_t*)FLASH_LAYOUT_STR, str_desc, length); return str_desc; + case MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX: + USBD_GetString((uint8_t*)MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER, str_desc, length); + return str_desc; + + case MBOOT_ERROR_STR_INVALID_ADDRESS_IDX: + USBD_GetString((uint8_t*)MBOOT_ERROR_STR_INVALID_ADDRESS, str_desc, length); + return str_desc; + default: return NULL; }