diff --git a/py/mpz.c b/py/mpz.c index b1d6e2b322..a152b6f11b 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -711,16 +711,9 @@ typedef uint32_t mp_float_int_t; // value == 0 || value < 1 mpz_init_zero(z); } else if (u.p.exp == ((1 << EXP_SZ) - 1)) { - // inf or NaN -#if 0 - // TODO: this probably isn't the right place to throw an exception - if(u.p.frc == 0) - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_OverflowError, "cannot convert float infinity to integer")); - else - nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "cannot convert float NaN to integer")); -#else + // u.p.frc == 0 indicates inf, else NaN + // should be handled by caller mpz_init_zero(z); -#endif } else { const int adj_exp = (int)u.p.exp - ((1 << (EXP_SZ - 1)) - 1); if (adj_exp < 0) { diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 23e3000235..49a9a91e25 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -298,9 +298,16 @@ mp_obj_t mp_obj_new_int_from_uint(mp_uint_t value) { #if MICROPY_PY_BUILTINS_FLOAT mp_obj_t mp_obj_new_int_from_float(mp_float_t val) { - mp_obj_int_t *o = mp_obj_int_new_mpz(); - mpz_set_from_float(&o->mpz, val); - return o; + int cl = fpclassify(val); + if (cl == FP_INFINITE) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_OverflowError, "can't convert inf to int")); + } else if (cl == FP_NAN) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "can't convert NaN to int")); + } else { + mp_obj_int_t *o = mp_obj_int_new_mpz(); + mpz_set_from_float(&o->mpz, val); + return o; + } } #endif diff --git a/tests/float/float2int.py b/tests/float/float2int.py index b948755de5..42210b4413 100644 --- a/tests/float/float2int.py +++ b/tests/float/float2int.py @@ -22,3 +22,15 @@ for i in range(0,23): print('fail: 10**%u was %u digits long' % (i, digcnt)); testpass = False print("power of 10 test: %s" % (testpass and 'passed' or 'failed')) + +# test inf conversion +try: + int(float('inf')) +except OverflowError: + print("OverflowError") + +# test nan conversion +try: + int(float('nan')) +except ValueError: + print("ValueError")