py/objarray: Detect bytearray(str) without an encoding.

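This prevents a very subtle bug caused by writing e.g. `bytearray('\xfd')`,
which gives you `(0xc3, 0xbd)` (the two-byte UTF-8 encoding of U+00FD)
rather than the single byte `0xfd` that was probably intended.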

This work was funded through GitHub Sponsors.

Signed-off-by: Jim Mussared <jim.mussared@gmail.com>
This commit is contained in:
Jim Mussared 2022-11-07 12:55:31 +11:00 committed by Damien George
parent f8b0ae32d3
commit 2c8dab7ab4
4 changed files with 18 additions and 1 deletions

View File

@ -192,6 +192,14 @@ STATIC mp_obj_t bytearray_make_new(const mp_obj_type_t *type_in, size_t n_args,
return MP_OBJ_FROM_PTR(o); return MP_OBJ_FROM_PTR(o);
} else { } else {
// 1 arg: construct the bytearray from that // 1 arg: construct the bytearray from that
if (mp_obj_is_str(args[0]) && n_args == 1) {
#if MICROPY_ERROR_REPORTING <= MICROPY_ERROR_REPORTING_TERSE
// Match bytes_make_new.
mp_raise_TypeError(MP_ERROR_TEXT("wrong number of arguments"));
#else
mp_raise_TypeError(MP_ERROR_TEXT("string argument without an encoding"));
#endif
}
return array_construct(BYTEARRAY_TYPECODE, args[0]); return array_construct(BYTEARRAY_TYPECODE, args[0]);
} }
} }

View File

@ -233,7 +233,11 @@ STATIC mp_obj_t bytes_make_new(const mp_obj_type_t *type_in, size_t n_args, size
if (mp_obj_is_str(args[0])) { if (mp_obj_is_str(args[0])) {
if (n_args < 2 || n_args > 3) { if (n_args < 2 || n_args > 3) {
#if MICROPY_ERROR_REPORTING <= MICROPY_ERROR_REPORTING_TERSE
goto wrong_args; goto wrong_args;
#else
mp_raise_TypeError(MP_ERROR_TEXT("string argument without an encoding"));
#endif
} }
GET_STR_DATA_LEN(args[0], str_data, str_len); GET_STR_DATA_LEN(args[0], str_data, str_len);
GET_STR_HASH(args[0], str_hash); GET_STR_HASH(args[0], str_hash);

View File

@ -5,3 +5,8 @@ print(bytearray('1234', 'utf-8'))
print(bytearray('12345', 'utf-8', 'strict')) print(bytearray('12345', 'utf-8', 'strict'))
print(bytearray((1, 2))) print(bytearray((1, 2)))
print(bytearray([1, 2])) print(bytearray([1, 2]))
try:
print(bytearray('1234'))
except TypeError:
print("TypeError")

View File

@ -21,7 +21,7 @@ def memsum(src: ptr8, n: int) -> int:
# create array and get its address # create array and get its address
ar = bytearray("0000") ar = bytearray(b"0000")
addr = get_addr(ar) addr = get_addr(ar)
print(type(ar)) print(type(ar))
print(type(addr)) print(type(addr))