diff --git a/py/obj.c b/py/obj.c index 0dab5f2ac7..90ce47e8fb 100644 --- a/py/obj.c +++ b/py/obj.c @@ -264,20 +264,33 @@ bool mp_obj_get_int_maybe(mp_const_obj_t arg, mp_int_t *value) { } #if MICROPY_PY_BUILTINS_FLOAT -mp_float_t mp_obj_get_float(mp_obj_t arg) { +bool mp_obj_get_float_maybe(mp_obj_t arg, mp_float_t *value) { + mp_float_t val; + if (arg == mp_const_false) { - return 0; + val = 0; } else if (arg == mp_const_true) { - return 1; + val = 1; } else if (MP_OBJ_IS_SMALL_INT(arg)) { - return MP_OBJ_SMALL_INT_VALUE(arg); + val = MP_OBJ_SMALL_INT_VALUE(arg); #if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE } else if (MP_OBJ_IS_TYPE(arg, &mp_type_int)) { - return mp_obj_int_as_float_impl(arg); + val = mp_obj_int_as_float_impl(arg); #endif } else if (mp_obj_is_float(arg)) { - return mp_obj_float_get(arg); + val = mp_obj_float_get(arg); } else { + return false; + } + + *value = val; + return true; +} + +mp_float_t mp_obj_get_float(mp_obj_t arg) { + mp_float_t val; + + if (!mp_obj_get_float_maybe(arg, &val)) { if (MICROPY_ERROR_REPORTING == MICROPY_ERROR_REPORTING_TERSE) { mp_raise_TypeError("can't convert to float"); } else { @@ -285,6 +298,8 @@ mp_float_t mp_obj_get_float(mp_obj_t arg) { "can't convert %s to float", mp_obj_get_type_str(arg))); } } + + return val; } #if MICROPY_PY_BUILTINS_COMPLEX diff --git a/py/obj.h b/py/obj.h index 323423b3ed..70cdd15fa8 100644 --- a/py/obj.h +++ b/py/obj.h @@ -686,6 +686,7 @@ mp_int_t mp_obj_get_int_truncated(mp_const_obj_t arg); bool mp_obj_get_int_maybe(mp_const_obj_t arg, mp_int_t *value); #if MICROPY_PY_BUILTINS_FLOAT mp_float_t mp_obj_get_float(mp_obj_t self_in); +bool mp_obj_get_float_maybe(mp_obj_t arg, mp_float_t *value); void mp_obj_get_complex(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag); #endif //qstr mp_obj_get_qstr(mp_obj_t arg); diff --git a/py/objfloat.c b/py/objfloat.c index b1900b236c..fadbbcb795 100644 --- a/py/objfloat.c +++ b/py/objfloat.c @@ -240,7 +240,11 @@ STATIC void mp_obj_float_divmod(mp_float_t *x, mp_float_t *y) { } mp_obj_t mp_obj_float_binary_op(mp_binary_op_t op, mp_float_t lhs_val, mp_obj_t rhs_in) { - mp_float_t rhs_val = mp_obj_get_float(rhs_in); // can be any type, this function will convert to float (if possible) + mp_float_t rhs_val; + if (!mp_obj_get_float_maybe(rhs_in, &rhs_val)) { + return MP_OBJ_NULL; // op not supported + } + switch (op) { case MP_BINARY_OP_ADD: case MP_BINARY_OP_INPLACE_ADD: lhs_val += rhs_val; break; diff --git a/tests/float/float_compare.py b/tests/float/float_compare.py new file mode 100644 index 0000000000..105923ac73 --- /dev/null +++ b/tests/float/float_compare.py @@ -0,0 +1,22 @@ +# Extended float comparisons + +class Foo: + pass + +foo = Foo() + +print(foo == 1.0) +print(1.0 == foo) +print(1.0 == Foo) +print(1.0 == []) +print(1.0 == {}) + +try: + print(foo < 1.0) +except TypeError: + print("TypeError") + +try: + print(1.0 < foo) +except TypeError: + print("TypeError")