stm32/mboot: Protect against invalid address flash writes.

And provide a DFU error message for invalid erases and writes.
This commit is contained in:
Andrew Leech 2020-02-24 17:57:19 +11:00 committed by Damien George
parent b41d08cf15
commit f7130a99b6
2 changed files with 35 additions and 1 deletions

View File

@ -32,6 +32,13 @@
#define DFU_XFER_SIZE (2048) #define DFU_XFER_SIZE (2048)
// The DFU standard allows for error messages to be sent (via str index) in GETSTATUS responses
#define MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX 0x10
#define MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER "Can't overwrite mboot"
#define MBOOT_ERROR_STR_INVALID_ADDRESS_IDX 0x11
#define MBOOT_ERROR_STR_INVALID_ADDRESS "Address out of range"
// DFU class requests // DFU class requests
enum { enum {
DFU_DETACH = 0, DFU_DETACH = 0,

View File

@ -526,6 +526,13 @@ static const flash_layout_t flash_layout[] = {
#endif #endif
static inline bool flash_is_valid_addr(uint32_t addr) {
uint8_t last = MP_ARRAY_SIZE(flash_layout) - 1;
uint32_t end_of_flash = flash_layout[last].base_address +
flash_layout[last].sector_count * flash_layout[last].sector_size;
return flash_layout[0].base_address <= addr && addr < end_of_flash;
}
static uint32_t flash_get_sector_index(uint32_t addr, uint32_t *sector_size) { static uint32_t flash_get_sector_index(uint32_t addr, uint32_t *sector_size) {
if (addr >= flash_layout[0].base_address) { if (addr >= flash_layout[0].base_address) {
uint32_t sector_index = 0; uint32_t sector_index = 0;
@ -575,6 +582,8 @@ static int flash_page_erase(uint32_t addr, uint32_t *next_addr) {
uint32_t sector = flash_get_sector_index(addr, &sector_size); uint32_t sector = flash_get_sector_index(addr, &sector_size);
if (sector == 0) { if (sector == 0) {
// Don't allow to erase the sector with this bootloader in it // Don't allow to erase the sector with this bootloader in it
dfu_context.status = DFU_STATUS_ERROR_ADDRESS;
dfu_context.error = MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX;
return -1; return -1;
} }
@ -619,6 +628,8 @@ static int flash_page_erase(uint32_t addr, uint32_t *next_addr) {
static int flash_write(uint32_t addr, const uint8_t *src8, size_t len) { static int flash_write(uint32_t addr, const uint8_t *src8, size_t len) {
if (addr >= flash_layout[0].base_address && addr < flash_layout[0].base_address + flash_layout[0].sector_size) { if (addr >= flash_layout[0].base_address && addr < flash_layout[0].base_address + flash_layout[0].sector_size) {
// Don't allow to write the sector with this bootloader in it // Don't allow to write the sector with this bootloader in it
dfu_context.status = DFU_STATUS_ERROR_ADDRESS;
dfu_context.error = MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX;
return -1; return -1;
} }
@ -732,7 +743,13 @@ int do_write(uint32_t addr, const uint8_t *src8, size_t len) {
} }
#endif #endif
return flash_write(addr, src8, len); if (flash_is_valid_addr(addr)) {
return flash_write(addr, src8, len);
}
dfu_context.status = DFU_STATUS_ERROR_ADDRESS;
dfu_context.error = MBOOT_ERROR_STR_INVALID_ADDRESS_IDX;
return -1;
} }
/******************************************************************************/ /******************************************************************************/
@ -909,6 +926,8 @@ uint8_t i2c_slave_process_tx_byte(void) {
static void dfu_init(void) { static void dfu_init(void) {
dfu_context.state = DFU_STATE_IDLE; dfu_context.state = DFU_STATE_IDLE;
dfu_context.cmd = DFU_CMD_NONE; dfu_context.cmd = DFU_CMD_NONE;
dfu_context.status = DFU_STATUS_OK;
dfu_context.error = 0;
dfu_context.addr = 0x08000000; dfu_context.addr = 0x08000000;
} }
@ -1158,6 +1177,14 @@ static uint8_t *pyb_usbdd_StrDescriptor(USBD_HandleTypeDef *pdev, uint8_t idx, u
USBD_GetString((uint8_t*)FLASH_LAYOUT_STR, str_desc, length); USBD_GetString((uint8_t*)FLASH_LAYOUT_STR, str_desc, length);
return str_desc; return str_desc;
case MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER_IDX:
USBD_GetString((uint8_t*)MBOOT_ERROR_STR_OVERWRITE_BOOTLOADER, str_desc, length);
return str_desc;
case MBOOT_ERROR_STR_INVALID_ADDRESS_IDX:
USBD_GetString((uint8_t*)MBOOT_ERROR_STR_INVALID_ADDRESS, str_desc, length);
return str_desc;
default: default:
return NULL; return NULL;
} }