diff --git a/py/mpz.c b/py/mpz.c index 4a8941e298..21b390996a 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -993,8 +993,11 @@ void mpz_pow_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { if (mpz_is_odd(n)) { mpz_mul_inpl(dest, dest, x); } - mpz_mul_inpl(x, x, x); n->len = mpn_shr(n->dig, n->dig, n->len, 1); + if (n->len == 0) { + break; + } + mpz_mul_inpl(x, x, x); } mpz_free(x); diff --git a/py/runtime.c b/py/runtime.c index e2e3495648..d9e2298c4c 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -367,18 +367,34 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { nlr_jump(mp_obj_new_exception_msg(&mp_type_ValueError, "negative power with no float support")); #endif } else { - // TODO check for overflow machine_int_t ans = 1; while (rhs_val > 0) { if (rhs_val & 1) { + machine_int_t old = ans; ans *= lhs_val; + if (ans < old) { + goto power_overflow; + } + } + if (rhs_val == 1) { + break; } - lhs_val *= lhs_val; rhs_val /= 2; + machine_int_t old = lhs_val; + lhs_val *= lhs_val; + if (lhs_val < old) { + goto power_overflow; + } } lhs_val = ans; } break; + + power_overflow: + // use higher precision + lhs = mp_obj_new_int_from_ll(MP_OBJ_SMALL_INT_VALUE(lhs)); + goto generic_binary_op; + case MP_BINARY_OP_LESS: return MP_BOOL(lhs_val < rhs_val); break; case MP_BINARY_OP_MORE: return MP_BOOL(lhs_val > rhs_val); break; case MP_BINARY_OP_LESS_EQUAL: return MP_BOOL(lhs_val <= rhs_val); break;