diff --git a/py/modstruct.c b/py/modstruct.c index 0d4a45f6b6..61dd0f81b4 100644 --- a/py/modstruct.c +++ b/py/modstruct.c @@ -82,26 +82,10 @@ STATIC mp_uint_t get_fmt_num(const char **p) { return val; } -STATIC uint calcsize_items(const char *fmt) { - uint cnt = 0; - while (*fmt) { - int num = 1; - if (unichar_isdigit(*fmt)) { - num = get_fmt_num(&fmt); - if (*fmt == 's') { - num = 1; - } - } - cnt += num; - fmt++; - } - return cnt; -} - -STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { - const char *fmt = mp_obj_str_get_str(fmt_in); +STATIC size_t calc_size_items(const char *fmt, size_t *total_sz) { char fmt_type = get_fmt_type(&fmt); - mp_uint_t size; + size_t total_cnt = 0; + size_t size; for (size = 0; *fmt; fmt++) { mp_uint_t cnt = 1; if (unichar_isdigit(*fmt)) { @@ -109,8 +93,10 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { } if (*fmt == 's') { + total_cnt += 1; size += cnt; } else { + total_cnt += cnt; mp_uint_t align; size_t sz = mp_binary_get_size(fmt_type, *fmt, &align); while (cnt--) { @@ -120,6 +106,14 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { } } } + *total_sz = size; + return total_cnt; +} + +STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { + const char *fmt = mp_obj_str_get_str(fmt_in); + size_t size; + calc_size_items(fmt, &size); return MP_OBJ_NEW_SMALL_INT(size); } MP_DEFINE_CONST_FUN_OBJ_1(struct_calcsize_obj, struct_calcsize); @@ -130,8 +124,9 @@ STATIC mp_obj_t struct_unpack_from(size_t n_args, const mp_obj_t *args) { // Since we implement unpack and unpack_from using the same function // we relax the "exact" requirement, and only implement "big enough". const char *fmt = mp_obj_str_get_str(args[0]); + size_t total_sz; + size_t num_items = calc_size_items(fmt, &total_sz); char fmt_type = get_fmt_type(&fmt); - uint num_items = calcsize_items(fmt); mp_obj_tuple_t *res = MP_OBJ_TO_PTR(mp_obj_new_tuple(num_items, NULL)); mp_buffer_info_t bufinfo; mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_READ); @@ -152,21 +147,23 @@ STATIC mp_obj_t struct_unpack_from(size_t n_args, const mp_obj_t *args) { p += offset; } - for (uint i = 0; i < num_items;) { - mp_uint_t sz = 1; + // Check that the input buffer is big enough to unpack all the values + if (p + total_sz > end_p) { + mp_raise_ValueError("buffer too small"); + } + + for (size_t i = 0; i < num_items;) { + mp_uint_t cnt = 1; if (unichar_isdigit(*fmt)) { - sz = get_fmt_num(&fmt); - } - if (p + sz > end_p) { - mp_raise_ValueError("buffer too small"); + cnt = get_fmt_num(&fmt); } mp_obj_t item; if (*fmt == 's') { - item = mp_obj_new_bytes(p, sz); - p += sz; + item = mp_obj_new_bytes(p, cnt); + p += cnt; res->items[i++] = item; } else { - while (sz--) { + while (cnt--) { item = mp_binary_get_val(fmt_type, *fmt, &p); res->items[i++] = item; } diff --git a/tests/basics/struct1.py b/tests/basics/struct1.py index a442beb1e5..2cf75137b8 100644 --- a/tests/basics/struct1.py +++ b/tests/basics/struct1.py @@ -39,6 +39,12 @@ print(v == (10, 100, 200, 300)) # network byte order print(struct.pack('!i', 123)) +# check that we get an error if the buffer is too small +try: + struct.unpack('I', b'\x00\x00\x00') +except: + print('struct.error') + # first arg must be a string try: struct.pack(1, 2) diff --git a/tests/basics/struct2.py b/tests/basics/struct2.py index 3b9dd5c1f6..e3336c0c78 100644 --- a/tests/basics/struct2.py +++ b/tests/basics/struct2.py @@ -25,6 +25,12 @@ print(struct.calcsize('0s1s0H2H')) print(struct.unpack('<0s1s0H2H', b'01234')) print(struct.pack('<0s1s0H2H', b'abc', b'abc', 258, 515)) +# check that we get an error if the buffer is too small +try: + struct.unpack('2H', b'\x00\x00') +except: + print('Exception') + # check that unknown types raise an exception try: struct.unpack('z', b'1')