py: Support non-boolean results for equality and inequality tests.

This commit implements a more complete replication of CPython's behaviour
for equality and inequality testing of objects.  This addresses the issues
discussed in #5382 and a few other inconsistencies.  Improvements over the
old code include:

- Support for returning non-boolean results from comparisons (as used by
  numpy and others).
- Support for non-reflexive equality tests.
- Preferential use of __ne__ methods and MP_BINARY_OP_NOT_EQUAL binary
  operators for inequality tests, when available.
- Fallback to op2 == op1 or op2 != op1 when op1 does not implement the
  (in)equality operators.

The scheme here makes use of a new flag, MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
in the flags word of mp_obj_type_t to indicate if various shortcuts can or
cannot be used when performing equality and inequality tests.  Currently
four built-in classes have the flag set: float and complex are
non-reflexive (since nan != nan) while bytearray and frozenset instances
can equal other builtin class instances (bytes and set respectively).  The
flag is also set for any new class defined by the user.

This commit also includes a more comprehensive set of tests for the
behaviour of (in)equality operators implemented in special methods.
This commit is contained in:
Nicko van Someren 2019-12-31 15:19:12 -07:00 committed by Damien George
parent c3450effd4
commit 3aab54bf43
10 changed files with 147 additions and 61 deletions

110
py/obj.c
View File

@ -189,7 +189,7 @@ bool mp_obj_is_callable(mp_obj_t o_in) {
return mp_obj_instance_is_callable(o_in); return mp_obj_instance_is_callable(o_in);
} }
// This function implements the '==' operator (and so the inverse of '!='). // This function implements the '==' and '!=' operators.
// //
// From the Python language reference: // From the Python language reference:
// (https://docs.python.org/3/reference/expressions.html#not-in) // (https://docs.python.org/3/reference/expressions.html#not-in)
@ -202,67 +202,89 @@ bool mp_obj_is_callable(mp_obj_t o_in) {
// Furthermore, from the v3.4.2 code for object.c: "Practical amendments: If rich // Furthermore, from the v3.4.2 code for object.c: "Practical amendments: If rich
// comparison returns NotImplemented, == and != are decided by comparing the object // comparison returns NotImplemented, == and != are decided by comparing the object
// pointer." // pointer."
bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2) {
// Float (and complex) NaN is never equal to anything, not even itself, mp_obj_t local_true = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_false : mp_const_true;
// so we must have a special check here to cover those cases. mp_obj_t local_false = (op == MP_BINARY_OP_NOT_EQUAL) ? mp_const_true : mp_const_false;
if (o1 == o2 int pass_number = 0;
#if MICROPY_PY_BUILTINS_FLOAT
&& !mp_obj_is_float(o1)
#endif
#if MICROPY_PY_BUILTINS_COMPLEX
&& !mp_obj_is_type(o1, &mp_type_complex)
#endif
) {
return true;
}
if (o1 == mp_const_none || o2 == mp_const_none) {
return false;
}
// fast path for small ints // Shortcut for very common cases
if (mp_obj_is_small_int(o1)) { if (o1 == o2 &&
if (mp_obj_is_small_int(o2)) { (mp_obj_is_small_int(o1) || !(mp_obj_get_type(o1)->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST))) {
// both SMALL_INT, and not equal if we get here return local_true;
return false;
} else {
mp_obj_t temp = o2; o2 = o1; o1 = temp;
// o2 is now the SMALL_INT, o1 is not
// fall through to generic op
}
} }
// fast path for strings // fast path for strings
if (mp_obj_is_str(o1)) { if (mp_obj_is_str(o1)) {
if (mp_obj_is_str(o2)) { if (mp_obj_is_str(o2)) {
// both strings, use special function // both strings, use special function
return mp_obj_str_equal(o1, o2); return mp_obj_str_equal(o1, o2) ? local_true : local_false;
} else {
// a string is never equal to anything else
goto str_cmp_err;
}
} else if (mp_obj_is_str(o2)) {
// o1 is not a string (else caught above), so the objects are not equal
str_cmp_err:
#if MICROPY_PY_STR_BYTES_CMP_WARN #if MICROPY_PY_STR_BYTES_CMP_WARN
if (mp_obj_is_type(o1, &mp_type_bytes) || mp_obj_is_type(o2, &mp_type_bytes)) { } else if (mp_obj_is_type(o2, &mp_type_bytes)) {
str_bytes_cmp:
mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str"); mp_warning(MP_WARN_CAT(BytesWarning), "Comparison between bytes and str");
} return local_false;
#endif #endif
return false; } else {
goto skip_one_pass;
}
#if MICROPY_PY_STR_BYTES_CMP_WARN
} else if (mp_obj_is_str(o2) && mp_obj_is_type(o1, &mp_type_bytes)) {
// o1 is not a string (else caught above), so the objects are not equal
goto str_bytes_cmp;
#endif
}
// fast path for small ints
if (mp_obj_is_small_int(o1)) {
if (mp_obj_is_small_int(o2)) {
// both SMALL_INT, and not equal if we get here
return local_false;
} else {
goto skip_one_pass;
}
} }
// generic type, call binary_op(MP_BINARY_OP_EQUAL) // generic type, call binary_op(MP_BINARY_OP_EQUAL)
while (pass_number < 2) {
const mp_obj_type_t *type = mp_obj_get_type(o1); const mp_obj_type_t *type = mp_obj_get_type(o1);
if (type->binary_op != NULL) { // If a full equality test is not needed and the other object is a different
mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2); // type then we don't need to bother trying the comparison.
if (type->binary_op != NULL &&
((type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST) || mp_obj_get_type(o2) == type)) {
// CPython is asymmetric: it will try __eq__ if there's no __ne__ but not the
// other way around. If the class doesn't need a full test we can skip __ne__.
if (op == MP_BINARY_OP_NOT_EQUAL && (type->flags & MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST)) {
mp_obj_t r = type->binary_op(MP_BINARY_OP_NOT_EQUAL, o1, o2);
if (r != MP_OBJ_NULL) { if (r != MP_OBJ_NULL) {
return r == mp_const_true ? true : false; return r;
} }
} }
// equality not implemented, and objects are not the same object, so // Try calling __eq__.
// they are defined as not equal mp_obj_t r = type->binary_op(MP_BINARY_OP_EQUAL, o1, o2);
return false; if (r != MP_OBJ_NULL) {
if (op == MP_BINARY_OP_EQUAL) {
return r;
} else {
return mp_obj_is_true(r) ? local_true : local_false;
}
}
}
skip_one_pass:
// Try the other way around if none of the above worked
++pass_number;
mp_obj_t temp = o1;
o1 = o2;
o2 = temp;
}
// equality not implemented, so fall back to pointer comparison
return (o1 == o2) ? local_true : local_false;
}
bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) {
return mp_obj_is_true(mp_obj_equal_not_equal(MP_BINARY_OP_EQUAL, o1, o2));
} }
mp_int_t mp_obj_get_int(mp_const_obj_t arg) { mp_int_t mp_obj_get_int(mp_const_obj_t arg) {

View File

@ -445,8 +445,14 @@ typedef mp_obj_t (*mp_fun_var_t)(size_t n, const mp_obj_t *);
typedef mp_obj_t (*mp_fun_kw_t)(size_t n, const mp_obj_t *, mp_map_t *); typedef mp_obj_t (*mp_fun_kw_t)(size_t n, const mp_obj_t *, mp_map_t *);
// Flags for type behaviour (mp_obj_type_t.flags) // Flags for type behaviour (mp_obj_type_t.flags)
// If MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST is clear then all the following hold:
// (a) the type only implements the __eq__ operator and not the __ne__ operator;
// (b) __eq__ returns a boolean result (False or True);
// (c) __eq__ is reflexive (A==A is True);
// (d) the type can't be equal to an instance of any different class that also clears this flag.
#define MP_TYPE_FLAG_IS_SUBCLASSED (0x0001) #define MP_TYPE_FLAG_IS_SUBCLASSED (0x0001)
#define MP_TYPE_FLAG_HAS_SPECIAL_ACCESSORS (0x0002) #define MP_TYPE_FLAG_HAS_SPECIAL_ACCESSORS (0x0002)
#define MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST (0x0004)
typedef enum { typedef enum {
PRINT_STR = 0, PRINT_STR = 0,
@ -729,6 +735,7 @@ void mp_obj_print_exception(const mp_print_t *print, mp_obj_t exc);
bool mp_obj_is_true(mp_obj_t arg); bool mp_obj_is_true(mp_obj_t arg);
bool mp_obj_is_callable(mp_obj_t o_in); bool mp_obj_is_callable(mp_obj_t o_in);
mp_obj_t mp_obj_equal_not_equal(mp_binary_op_t op, mp_obj_t o1, mp_obj_t o2);
bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2); bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2);
static inline bool mp_obj_is_integer(mp_const_obj_t o) { return mp_obj_is_int(o) || mp_obj_is_bool(o); } // returns true if o is bool, small int or long int static inline bool mp_obj_is_integer(mp_const_obj_t o) { return mp_obj_is_int(o) || mp_obj_is_bool(o); } // returns true if o is bool, small int or long int

View File

@ -558,6 +558,7 @@ const mp_obj_type_t mp_type_array = {
const mp_obj_type_t mp_type_bytearray = { const mp_obj_type_t mp_type_bytearray = {
{ &mp_type_type }, { &mp_type_type },
.name = MP_QSTR_bytearray, .name = MP_QSTR_bytearray,
.flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = array_print, .print = array_print,
.make_new = bytearray_make_new, .make_new = bytearray_make_new,
.getiter = array_iterator_new, .getiter = array_iterator_new,

View File

@ -148,6 +148,7 @@ STATIC void complex_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
const mp_obj_type_t mp_type_complex = { const mp_obj_type_t mp_type_complex = {
{ &mp_type_type }, { &mp_type_type },
.name = MP_QSTR_complex, .name = MP_QSTR_complex,
.flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = complex_print, .print = complex_print,
.make_new = complex_make_new, .make_new = complex_make_new,
.unary_op = complex_unary_op, .unary_op = complex_unary_op,

View File

@ -186,6 +186,7 @@ STATIC mp_obj_t float_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
const mp_obj_type_t mp_type_float = { const mp_obj_type_t mp_type_float = {
{ &mp_type_type }, { &mp_type_type },
.name = MP_QSTR_float, .name = MP_QSTR_float,
.flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = float_print, .print = float_print,
.make_new = float_make_new, .make_new = float_make_new,
.unary_op = float_unary_op, .unary_op = float_unary_op,

View File

@ -564,6 +564,7 @@ STATIC MP_DEFINE_CONST_DICT(frozenset_locals_dict, frozenset_locals_dict_table);
const mp_obj_type_t mp_type_frozenset = { const mp_obj_type_t mp_type_frozenset = {
{ &mp_type_type }, { &mp_type_type },
.name = MP_QSTR_frozenset, .name = MP_QSTR_frozenset,
.flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST,
.print = set_print, .print = set_print,
.make_new = set_make_new, .make_new = set_make_new,
.unary_op = set_unary_op, .unary_op = set_unary_op,

View File

@ -467,7 +467,7 @@ const byte mp_binary_op_method_name[MP_BINARY_OP_NUM_RUNTIME] = {
[MP_BINARY_OP_EQUAL] = MP_QSTR___eq__, [MP_BINARY_OP_EQUAL] = MP_QSTR___eq__,
[MP_BINARY_OP_LESS_EQUAL] = MP_QSTR___le__, [MP_BINARY_OP_LESS_EQUAL] = MP_QSTR___le__,
[MP_BINARY_OP_MORE_EQUAL] = MP_QSTR___ge__, [MP_BINARY_OP_MORE_EQUAL] = MP_QSTR___ge__,
// MP_BINARY_OP_NOT_EQUAL, // a != b calls a == b and inverts result [MP_BINARY_OP_NOT_EQUAL] = MP_QSTR___ne__,
[MP_BINARY_OP_CONTAINS] = MP_QSTR___contains__, [MP_BINARY_OP_CONTAINS] = MP_QSTR___contains__,
// If an inplace method is not found a normal method will be used as a fallback // If an inplace method is not found a normal method will be used as a fallback
@ -1100,7 +1100,7 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict)
// TODO might need to make a copy of locals_dict; at least that's how CPython does it // TODO might need to make a copy of locals_dict; at least that's how CPython does it
// Basic validation of base classes // Basic validation of base classes
uint16_t base_flags = 0; uint16_t base_flags = MP_TYPE_FLAG_NEEDS_FULL_EQ_TEST;
size_t bases_len; size_t bases_len;
mp_obj_t *bases_items; mp_obj_t *bases_items;
mp_obj_tuple_get(bases_tuple, &bases_len, &bases_items); mp_obj_tuple_get(bases_tuple, &bases_len, &bases_items);

View File

@ -323,19 +323,8 @@ mp_obj_t mp_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
// deal with == and != for all types // deal with == and != for all types
if (op == MP_BINARY_OP_EQUAL || op == MP_BINARY_OP_NOT_EQUAL) { if (op == MP_BINARY_OP_EQUAL || op == MP_BINARY_OP_NOT_EQUAL) {
if (mp_obj_equal(lhs, rhs)) { // mp_obj_equal_not_equal supports a bunch of shortcuts
if (op == MP_BINARY_OP_EQUAL) { return mp_obj_equal_not_equal(op, lhs, rhs);
return mp_const_true;
} else {
return mp_const_false;
}
} else {
if (op == MP_BINARY_OP_EQUAL) {
return mp_const_false;
} else {
return mp_const_true;
}
}
} }
// deal with exception_match for all types // deal with exception_match for all types

View File

@ -0,0 +1,33 @@
# Test helper: the only special method shown is __eq__, which prints a trace
# line and always returns True. No __ne__ is defined in this class.
class A:
def __eq__(self, other):
print("A __eq__ called")
return True
# Test helper: the only special method shown is __ne__, which prints a trace
# line and always returns True. No __eq__ is defined in this class.
class B:
def __ne__(self, other):
print("B __ne__ called")
return True
# Test helper: the only special method shown is __eq__, which prints a trace
# line and always returns False. No __ne__ is defined in this class.
class C:
def __eq__(self, other):
print("C __eq__ called")
return False
# Test helper: the only special method shown is __ne__, which prints a trace
# line and always returns False. No __eq__ is defined in this class.
class D:
def __ne__(self, other):
print("D __ne__ called")
return False
# One instance per helper class; the names a..d are looked up by eval() below.
a = A()
b = B()
c = C()
d = D()
# Print the expression text, then the value it evaluates to, so the output
# shows which special methods were called for each comparison.
def test(s):
print(s)
print(eval(s))
# Run both '==' and '!=' for every ordered pair drawn from a, b, c, d.
# NOTE(review): indentation is missing in this view; presumably both test()
# calls belong to the inner loop body -- confirm against the original file.
for x in 'abcd':
for y in 'abcd':
test('{} == {}'.format(x,y))
test('{} != {}'.format(x,y))

View File

@ -0,0 +1,31 @@
# Test helper whose __eq__ returns a non-boolean value (the int 123) after
# printing a trace line with the other operand. __repr__ returns "E" so that
# printed operands are stable. No __ne__ is defined in this class.
class E:
def __repr__(self):
return "E"
def __eq__(self, other):
print('E eq', other)
return 123
# Test helper whose __ne__ returns a non-boolean value (the int -456) after
# printing a trace line with the other operand. __repr__ returns "F" so that
# printed operands are stable. No __eq__ is defined in this class.
class F:
def __repr__(self):
return "F"
def __ne__(self, other):
print('F ne', other)
return -456
# Inequality between the two helper classes, in both operand orders.
print(E() != F())
print(F() != E())
# Values of other types to compare against: None, two ints and a str.
tests = (None, 0, 1, 'a')
# For each value, print the result of == and != with the helper instance on
# either side of the operator.
# NOTE(review): indentation is missing in this view; presumably all nine
# print() calls below form the loop body -- confirm against the original file.
for val in tests:
print('==== testing', val)
print(E() == val)
print(val == E())
print(E() != val)
print(val != E())
print(F() == val)
print(val == F())
print(F() != val)
print(val != F())