py/objtuple: Properly implement comparison with incompatible types.
Should raise TypeError, unless it's (in)equality comparison.
This commit is contained in:
parent
e354b0a0cb
commit
1aaba5cabe
|
@ -101,7 +101,7 @@ STATIC mp_obj_t mp_obj_tuple_make_new(const mp_obj_type_t *type_in, size_t n_arg
|
|||
}
|
||||
|
||||
// Don't pass MP_BINARY_OP_NOT_EQUAL here
|
||||
STATIC bool tuple_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in) {
|
||||
STATIC mp_obj_t tuple_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in) {
|
||||
// type check is done on getiter method to allow tuple, namedtuple, attrtuple
|
||||
mp_check_self(mp_obj_get_type(self_in)->getiter == mp_obj_tuple_getiter);
|
||||
mp_obj_type_t *another_type = mp_obj_get_type(another_in);
|
||||
|
@ -110,12 +110,15 @@ STATIC bool tuple_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in
|
|||
// Slow path for user subclasses
|
||||
another_in = mp_instance_cast_to_native_base(another_in, MP_OBJ_FROM_PTR(&mp_type_tuple));
|
||||
if (another_in == MP_OBJ_NULL) {
|
||||
return false;
|
||||
if (op == MP_BINARY_OP_EQUAL) {
|
||||
return mp_const_false;
|
||||
}
|
||||
return MP_OBJ_NULL;
|
||||
}
|
||||
}
|
||||
mp_obj_tuple_t *another = MP_OBJ_TO_PTR(another_in);
|
||||
|
||||
return mp_seq_cmp_objs(op, self->items, self->len, another->items, another->len);
|
||||
return mp_obj_new_bool(mp_seq_cmp_objs(op, self->items, self->len, another->items, another->len));
|
||||
}
|
||||
|
||||
mp_obj_t mp_obj_tuple_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
|
||||
|
@ -166,7 +169,7 @@ mp_obj_t mp_obj_tuple_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
|
|||
case MP_BINARY_OP_LESS_EQUAL:
|
||||
case MP_BINARY_OP_MORE:
|
||||
case MP_BINARY_OP_MORE_EQUAL:
|
||||
return mp_obj_new_bool(tuple_cmp_helper(op, lhs, rhs));
|
||||
return tuple_cmp_helper(op, lhs, rhs);
|
||||
|
||||
default:
|
||||
return MP_OBJ_NULL; // op not supported
|
||||
|
|
|
@ -53,3 +53,13 @@ print((10, 0) > (1, 1))
|
|||
print((10, 0) < (1, 1))
|
||||
print((0, 0, 10, 0) > (0, 0, 1, 1))
|
||||
print((0, 0, 10, 0) < (0, 0, 1, 1))
|
||||
|
||||
|
||||
print(() == {})
|
||||
print(() != {})
|
||||
print((1,) == [1])
|
||||
|
||||
try:
|
||||
print(() < {})
|
||||
except TypeError:
|
||||
print("TypeError")
|
||||
|
|
Loading…
Reference in New Issue