diff --git a/py/objint.c b/py/objint.c index 9e11871f10..b78c9f25b1 100644 --- a/py/objint.c +++ b/py/objint.c @@ -490,15 +490,24 @@ STATIC mp_obj_t int_from_bytes(size_t n_args, const mp_obj_t *args) { STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(int_from_bytes_fun_obj, 3, 4, int_from_bytes); STATIC MP_DEFINE_CONST_CLASSMETHOD_OBJ(int_from_bytes_obj, MP_ROM_PTR(&int_from_bytes_fun_obj)); -STATIC mp_obj_t int_to_bytes(size_t n_args, const mp_obj_t *args) { - // TODO: Support signed param (assumes signed=False) - (void)n_args; +STATIC mp_obj_t int_to_bytes(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { + enum { ARG_length, ARG_byteorder, ARG_signed }; + static const mp_arg_t allowed_args[] = { + { MP_QSTR_length, MP_ARG_REQUIRED | MP_ARG_INT }, + { MP_QSTR_byteorder, MP_ARG_REQUIRED | MP_ARG_OBJ }, + { MP_QSTR_signed, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, + }; + mp_arg_val_t args[MP_ARRAY_SIZE(allowed_args)]; + mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, MP_ARRAY_SIZE(allowed_args), allowed_args, args); - mp_int_t len = mp_obj_get_int(args[1]); + mp_int_t len = args[ARG_length].u_int; if (len < 0) { mp_raise_ValueError(NULL); } - bool big_endian = args[2] != MP_OBJ_NEW_QSTR(MP_QSTR_little); + + mp_obj_t self = pos_args[0]; + bool big_endian = args[ARG_byteorder].u_obj != MP_OBJ_NEW_QSTR(MP_QSTR_little); + bool signed_ = args[ARG_signed].u_bool; vstr_t vstr; vstr_init_len(&vstr, len); @@ -506,22 +515,26 @@ STATIC mp_obj_t int_to_bytes(size_t n_args, const mp_obj_t *args) { memset(data, 0, len); #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE - if (!MP_OBJ_IS_SMALL_INT(args[0])) { - mp_obj_int_buffer_overflow_check(args[0], len, false); - mp_obj_int_to_bytes_impl(args[0], big_endian, len, data); + if (!MP_OBJ_IS_SMALL_INT(self)) { + mp_obj_int_buffer_overflow_check(self, len, signed_); + mp_obj_int_to_bytes_impl(self, big_endian, len, data); } else #endif { - mp_int_t val = MP_OBJ_SMALL_INT_VALUE(args[0]); + mp_int_t val = MP_OBJ_SMALL_INT_VALUE(self); // Small int checking is separate, to be fast. - mp_small_int_buffer_overflow_check(val, len, false); + mp_small_int_buffer_overflow_check(val, len, signed_); size_t l = MIN((size_t)len, sizeof(val)); + if (val < 0) { + // Sign extend negative numbers. + memset(data, -1, len); + } mp_binary_set_int(l, big_endian, data + (big_endian ? (len - l) : 0), val); } return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); } -STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(int_to_bytes_obj, 3, 4, int_to_bytes); +STATIC MP_DEFINE_CONST_FUN_OBJ_KW(int_to_bytes_obj, 3, int_to_bytes); STATIC const mp_rom_map_elem_t int_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_from_bytes), MP_ROM_PTR(&int_from_bytes_obj) }, diff --git a/tests/basics/int_bytes.py b/tests/basics/int_bytes.py index aaf4400815..d42afac1fd 100644 --- a/tests/basics/int_bytes.py +++ b/tests/basics/int_bytes.py @@ -1,8 +1,11 @@ print((10).to_bytes(1, "little")) +print((-10).to_bytes(1, "little", signed=True)) # Test fitting in length that's not a power of two. print((0x10000).to_bytes(3, 'little')) print((111111).to_bytes(4, "little")) +print((-111111).to_bytes(4, "little", signed=True)) print((100).to_bytes(10, "little")) +print((-100).to_bytes(10, "little", signed=True)) # check that extra zero bytes don't change the internal int value print(int.from_bytes(bytes(20), "little") == 0) @@ -10,7 +13,9 @@ print(int.from_bytes(b"\x01" + bytes(20), "little") == 1) # big-endian conversion print((10).to_bytes(1, "big")) +print((-10).to_bytes(1, "big", signed=True)) print((100).to_bytes(10, "big")) +print((-100).to_bytes(10, "big", signed=True)) print(int.from_bytes(b"\0\0\0\0\0\0\0\0\0\x01", "big")) print(int.from_bytes(b"\x01\0", "big")) @@ -26,8 +31,13 @@ try: except OverflowError: print("OverflowError") -# negative numbers should raise an error +# negative numbers should raise an error if signed=False try: (-256).to_bytes(2, "little") except OverflowError: print("OverflowError") + +try: + (-256).to_bytes(2, "little", signed=False) +except OverflowError: + print("OverflowError") diff --git a/tests/basics/int_bytes_intbig.py b/tests/basics/int_bytes_intbig.py index b19c8ae53e..a9f0699024 100644 --- a/tests/basics/int_bytes_intbig.py +++ b/tests/basics/int_bytes_intbig.py @@ -2,7 +2,9 @@ import skip_if skip_if.no_bigint() print((2**64).to_bytes(9, "little")) +print((-2**64).to_bytes(9, "little", signed=True)) print((2**64).to_bytes(9, "big")) +print((-2**64).to_bytes(9, "big", signed=True)) b = bytes(range(20)) @@ -22,8 +24,12 @@ try: except OverflowError: print("OverflowError") -# negative numbers should raise an error +# negative numbers should raise an error if signed=False try: (-2**64).to_bytes(9, "little") except OverflowError: print("OverflowError") +try: + (-2**64).to_bytes(9, "little", signed=False) +except OverflowError: + print("OverflowError") diff --git a/tests/basics/int_longint_bytes.py b/tests/basics/int_longint_bytes.py index 42e445f539..90cbff2f87 100644 --- a/tests/basics/int_longint_bytes.py +++ b/tests/basics/int_longint_bytes.py @@ -3,6 +3,7 @@ import skip_if skip_if.no_bigint() print((2**64).to_bytes(9, "little")) +print((-2**64).to_bytes(9, "little", signed=True)) print(int.from_bytes(b"\x00\x01\0\0\0\0\0\0", "little")) print(int.from_bytes(b"\x01\0\0\0\0\0\0\0", "little")) print(int.from_bytes(b"\x00\x01\0\0\0\0\0\0", "little"))