diff --git a/py/modmath.c b/py/modmath.c index 4d627c67f4..6c248844df 100644 --- a/py/modmath.c +++ b/py/modmath.c @@ -3,7 +3,7 @@ * * The MIT License (MIT) * - * Copyright (c) 2013, 2014 Damien P. George + * Copyright (c) 2013-2017 Damien P. George * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,13 +39,30 @@ STATIC NORETURN void math_error(void) { mp_raise_ValueError("math domain error"); } -#define MATH_FUN_1(py_name, c_name) \ - STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \ - STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); +STATIC mp_obj_t math_generic_1(mp_obj_t x_obj, mp_float_t (*f)(mp_float_t)) { + mp_float_t x = mp_obj_get_float(x_obj); + mp_float_t ans = f(x); + if ((isnan(ans) && !isnan(x)) || (isinf(ans) && !isinf(x))) { + math_error(); + } + return mp_obj_new_float(ans); +} -#define MATH_FUN_2(py_name, c_name) \ - STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_float(y_obj))); } \ - STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); +STATIC mp_obj_t math_generic_2(mp_obj_t x_obj, mp_obj_t y_obj, mp_float_t (*f)(mp_float_t, mp_float_t)) { + mp_float_t x = mp_obj_get_float(x_obj); + mp_float_t y = mp_obj_get_float(y_obj); + mp_float_t ans = f(x, y); + if ((isnan(ans) && !isnan(x) && !isnan(y)) || (isinf(ans) && !isinf(x))) { + math_error(); + } + return mp_obj_new_float(ans); +} + +#define MATH_FUN_1(py_name, c_name) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \ + return math_generic_1(x_obj, MICROPY_FLOAT_C_FUN(c_name)); \ + } \ + STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); #define MATH_FUN_1_TO_BOOL(py_name, c_name) \ STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_bool(c_name(mp_obj_get_float(x_obj))); } \ @@ -55,15 +72,17 @@ STATIC NORETURN void math_error(void) { STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { return mp_obj_new_int_from_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj))); } \ STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); -#define MATH_FUN_1_ERRCOND(py_name, c_name, error_condition) \ - STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj) { \ - mp_float_t x = mp_obj_get_float(x_obj); \ - if (error_condition) { \ - math_error(); \ - } \ - return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(x)); \ +#define MATH_FUN_2(py_name, c_name) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { \ + return math_generic_2(x_obj, y_obj, MICROPY_FLOAT_C_FUN(c_name)); \ } \ - STATIC MP_DEFINE_CONST_FUN_OBJ_1(mp_math_## py_name ## _obj, mp_math_ ## py_name); + STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); + +#define MATH_FUN_2_FLT_INT(py_name, c_name) \ + STATIC mp_obj_t mp_math_ ## py_name(mp_obj_t x_obj, mp_obj_t y_obj) { \ + return mp_obj_new_float(MICROPY_FLOAT_C_FUN(c_name)(mp_obj_get_float(x_obj), mp_obj_get_int(y_obj))); \ + } \ + STATIC MP_DEFINE_CONST_FUN_OBJ_2(mp_math_## py_name ## _obj, mp_math_ ## py_name); #if MP_NEED_LOG2 // 1.442695040888963407354163704 is 1/_M_LN2 @@ -71,7 +90,7 @@ STATIC NORETURN void math_error(void) { #endif // sqrt(x): returns the square root of x -MATH_FUN_1_ERRCOND(sqrt, sqrt, (x < (mp_float_t)0.0)) +MATH_FUN_1(sqrt, sqrt) // pow(x, y): returns x to the power of y MATH_FUN_2(pow, pow) // exp(x) @@ -80,9 +99,9 @@ MATH_FUN_1(exp, exp) // expm1(x) MATH_FUN_1(expm1, expm1) // log2(x) -MATH_FUN_1_ERRCOND(log2, log2, (x <= (mp_float_t)0.0)) +MATH_FUN_1(log2, log2) // log10(x) -MATH_FUN_1_ERRCOND(log10, log10, (x <= (mp_float_t)0.0)) +MATH_FUN_1(log10, log10) // cosh(x) MATH_FUN_1(cosh, cosh) // sinh(x) @@ -113,9 +132,15 @@ MATH_FUN_2(atan2, atan2) // ceil(x) MATH_FUN_1_TO_INT(ceil, ceil) // copysign(x, y) -MATH_FUN_2(copysign, copysign) +STATIC mp_float_t MICROPY_FLOAT_C_FUN(copysign_func)(mp_float_t x, mp_float_t y) { + return MICROPY_FLOAT_C_FUN(copysign)(x, y); +} +MATH_FUN_2(copysign, copysign_func) // fabs(x) -MATH_FUN_1(fabs, fabs) +STATIC mp_float_t MICROPY_FLOAT_C_FUN(fabs_func)(mp_float_t x) { + return MICROPY_FLOAT_C_FUN(fabs)(x); +} +MATH_FUN_1(fabs, fabs_func) // floor(x) MATH_FUN_1_TO_INT(floor, floor) //TODO: delegate to x.__floor__() if x is not a float // fmod(x, y) @@ -129,7 +154,7 @@ MATH_FUN_1_TO_BOOL(isnan, isnan) // trunc(x) MATH_FUN_1_TO_INT(trunc, trunc) // ldexp(x, exp) -MATH_FUN_2(ldexp, ldexp) +MATH_FUN_2_FLT_INT(ldexp, ldexp) #if MICROPY_PY_MATH_SPECIAL_FUNCTIONS // erf(x): return the error function of x MATH_FUN_1(erf, erf) diff --git a/tests/float/math_domain.py b/tests/float/math_domain.py new file mode 100644 index 0000000000..0cf10fb2ad --- /dev/null +++ b/tests/float/math_domain.py @@ -0,0 +1,51 @@ +# Tests domain errors in math functions + +try: + import math +except ImportError: + print("SKIP") + raise SystemExit + +inf = float('inf') +nan = float('nan') + +# single argument functions +for name, f, args in ( + ('fabs', math.fabs, ()), + ('ceil', math.ceil, ()), + ('floor', math.floor, ()), + ('trunc', math.trunc, ()), + ('sqrt', math.sqrt, (-1, 0)), + ('exp', math.exp, ()), + ('sin', math.sin, ()), + ('cos', math.cos, ()), + ('tan', math.tan, ()), + ('asin', math.asin, (-1.1, 1, 1.1)), + ('acos', math.acos, (-1.1, 1, 1.1)), + ('atan', math.atan, ()), + ('ldexp', lambda x: math.ldexp(x, 0), ()), + ('radians', math.radians, ()), + ('degrees', math.degrees, ()), + ): + for x in args + (inf, nan): + try: + ans = f(x) + print('%.4f' % ans) + except ValueError: + print(name, 'ValueError') + except OverflowError: + print(name, 'OverflowError') + +# double argument functions +for name, f, args in ( + ('pow', math.pow, ((0, 2), (-1, 2), (0, -1), (-1, 2.3))), + ('fmod', math.fmod, ((1.2, inf), (1.2, 0), (inf, 1.2))), + ('atan2', math.atan2, ((0, 0),)), + ('copysign', math.copysign, ()), + ): + for x in args + ((0, inf), (inf, 0), (inf, inf), (inf, nan), (nan, inf), (nan, nan)): + try: + ans = f(*x) + print('%.4f' % ans) + except ValueError: + print(name, 'ValueError') diff --git a/tests/float/math_domain_special.py b/tests/float/math_domain_special.py new file mode 100644 index 0000000000..388920350e --- /dev/null +++ b/tests/float/math_domain_special.py @@ -0,0 +1,36 @@ +# Tests domain errors in special math functions + +try: + import math + math.erf +except (ImportError, AttributeError): + print("SKIP") + raise SystemExit + +inf = float('inf') +nan = float('nan') + +# single argument functions +for name, f, args in ( + ('expm1', math.exp, ()), + ('log2', math.log2, (-1, 0)), + ('log10', math.log10, (-1, 0)), + ('sinh', math.sinh, ()), + ('cosh', math.cosh, ()), + ('tanh', math.tanh, ()), + ('asinh', math.asinh, ()), + ('acosh', math.acosh, (-1, 0.9, 1)), + ('atanh', math.atanh, (-1, 1)), + ('erf', math.erf, ()), + ('erfc', math.erfc, ()), + ('gamma', math.gamma, (-2, -1, 0, 1)), + ('lgamma', math.lgamma, (-2, -1, 0, 1)), + ): + for x in args + (inf, nan): + try: + ans = f(x) + print('%.4f' % ans) + except ValueError: + print(name, 'ValueError') + except OverflowError: + print(name, 'OverflowError')