diff --git a/py/objstr.c b/py/objstr.c index f2f8063d0b..7ca8afc6ba 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -251,11 +251,13 @@ STATIC const byte *find_subbytes(const byte *haystack, machine_uint_t hlen, cons STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len); + mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in); + mp_obj_type_t *rhs_type = mp_obj_get_type(rhs_in); switch (op) { case MP_BINARY_OP_ADD: case MP_BINARY_OP_INPLACE_ADD: - if (MP_OBJ_IS_STR(rhs_in)) { - // add 2 strings + if (lhs_type == rhs_type) { + // add 2 strings or bytes GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); int alloc_len = lhs_len + rhs_len; @@ -270,7 +272,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { // code for non-qstr byte *data; - mp_obj_t s = mp_obj_str_builder_start(mp_obj_get_type(lhs_in), alloc_len, &data); + mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data); memcpy(data, lhs_data, lhs_len); memcpy(data + lhs_len, rhs_data, rhs_len); return mp_obj_str_builder_end(s); @@ -279,7 +281,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { case MP_BINARY_OP_IN: /* NOTE `a in b` is `b.__contains__(a)` */ - if (MP_OBJ_IS_STR(rhs_in)) { + if (lhs_type == rhs_type) { GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL); } @@ -292,7 +294,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { } int n = MP_OBJ_SMALL_INT_VALUE(rhs_in); byte *data; - mp_obj_t s = mp_obj_str_builder_start(mp_obj_get_type(lhs_in), lhs_len * n, &data); + mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data); mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data); return mp_obj_str_builder_end(s); } @@ -310,14 +312,13 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { return str_modulo_format(lhs_in, n_args, args); } - // These 2 are never passed here, dealt with as a special case in mp_binary_op(). - //case MP_BINARY_OP_EQUAL: - //case MP_BINARY_OP_NOT_EQUAL: + //case MP_BINARY_OP_NOT_EQUAL: // This is never passed here + case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal() case MP_BINARY_OP_LESS: case MP_BINARY_OP_LESS_EQUAL: case MP_BINARY_OP_MORE: case MP_BINARY_OP_MORE_EQUAL: - if (MP_OBJ_IS_STR(rhs_in)) { + if (lhs_type == rhs_type) { GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len)); } diff --git a/py/sequence.c b/py/sequence.c index 63f6bd6944..3d2bbba4df 100644 --- a/py/sequence.c +++ b/py/sequence.c @@ -83,6 +83,10 @@ bool m_seq_get_fast_slice_indexes(machine_uint_t len, mp_obj_t slice, machine_ui // Special-case comparison function for sequences of bytes // Don't pass MP_BINARY_OP_NOT_EQUAL here bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2) { + if (op == MP_BINARY_OP_EQUAL && len1 != len2) { + return false; + } + // Let's deal only with > & >= if (op == MP_BINARY_OP_LESS || op == MP_BINARY_OP_LESS_EQUAL) { SWAP(const byte*, data1, data2); diff --git a/tests/basics/bytes_compare.py b/tests/basics/bytes_compare.py new file mode 100644 index 0000000000..3804844feb --- /dev/null +++ b/tests/basics/bytes_compare.py @@ -0,0 +1,51 @@ +print(b"" == b"") +print(b"" > b"") +print(b"" < b"") +print(b"" == b"1") +print(b"1" == b"") +print("==") +print(b"" > b"1") +print(b"1" > b"") +print(b"" < b"1") +print(b"1" < b"") +print(b"" >= b"1") +print(b"1" >= b"") +print(b"" <= b"1") +print(b"1" <= b"") + +print(b"1" == b"1") +print(b"1" != b"1") +print(b"1" == b"2") +print(b"1" == b"10") + +print(b"1" > b"1") +print(b"1" > b"2") +print(b"2" > b"1") +print(b"10" > b"1") +print(b"1/" > b"1") +print(b"1" > b"10") +print(b"1" > b"1/") + +print(b"1" < b"1") +print(b"2" < b"1") +print(b"1" < b"2") +print(b"1" < b"10") +print(b"1" < b"1/") +print(b"10" < b"1") +print(b"1/" < b"1") + +print(b"1" >= b"1") +print(b"1" >= b"2") +print(b"2" >= b"1") +print(b"10" >= b"1") +print(b"1/" >= b"1") +print(b"1" >= b"10") +print(b"1" >= b"1/") + +print(b"1" <= b"1") +print(b"2" <= b"1") +print(b"1" <= b"2") +print(b"1" <= b"10") +print(b"1" <= b"1/") +print(b"10" <= b"1") +print(b"1/" <= b"1")