py/obj: Fix comparison of float/complex NaN with itself.

IEEE floating point is specified such that a comparison of NaN with itself
returns false, and Python respects these semantics.  This patch makes uPy
also have these semantics.  The fix has a minor impact on the speed of the
object-equality fast-path, but that seems to be unavoidable and it's much
more important to have correct behaviour (especially in this case where
the wrong answer for nan==nan is silently returned).
This commit is contained in:
Damien George 2017-09-04 14:16:27 +10:00
parent 9950865c39
commit d4b75f6b68
3 changed files with 20 additions and 1 deletions

View File

@ -162,7 +162,16 @@ bool mp_obj_is_callable(mp_obj_t o_in) {
// comparison returns NotImplemented, == and != are decided by comparing the object // comparison returns NotImplemented, == and != are decided by comparing the object
// pointer." // pointer."
bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) { bool mp_obj_equal(mp_obj_t o1, mp_obj_t o2) {
if (o1 == o2) { // Float (and complex) NaN is never equal to anything, not even itself,
// so we must have a special check here to cover those cases.
if (o1 == o2
#if MICROPY_PY_BUILTINS_FLOAT
&& !mp_obj_is_float(o1)
#endif
#if MICROPY_PY_BUILTINS_COMPLEX
&& !MP_OBJ_IS_TYPE(o1, &mp_type_complex)
#endif
) {
return true; return true;
} }
if (o1 == mp_const_none || o2 == mp_const_none) { if (o1 == mp_const_none || o2 == mp_const_none) {

View File

@ -37,6 +37,11 @@ ans = 1j ** 2.5j; print("%.5g %.5g" % (ans.real, ans.imag))
print(1j == 1) print(1j == 1)
print(1j == 1j) print(1j == 1j)
# comparison of nan is special
nan = float('nan') * 1j
print(nan == 1j)
print(nan == nan)
# builtin abs # builtin abs
print(abs(1j)) print(abs(1j))
print("%.5g" % abs(1j + 2)) print("%.5g" % abs(1j + 2))

View File

@ -60,6 +60,11 @@ print(1.2 <= -3.4)
print(1.2 >= 3.4) print(1.2 >= 3.4)
print(1.2 >= -3.4) print(1.2 >= -3.4)
# comparison of nan is special
nan = float('nan')
print(nan == 1.2)
print(nan == nan)
try: try:
1.0 / 0 1.0 / 0
except ZeroDivisionError: except ZeroDivisionError: