diff --git a/py/objcomplex.c b/py/objcomplex.c index 65957cbf60..2ba5226150 100644 --- a/py/objcomplex.c +++ b/py/objcomplex.c @@ -6,6 +6,7 @@ #include "mpconfig.h" #include "qstr.h" #include "obj.h" +#include "parsenum.h" #include "runtime0.h" #include "map.h" @@ -36,15 +37,20 @@ STATIC mp_obj_t complex_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const return mp_obj_new_complex(0, 0); case 1: - // TODO allow string as first arg and parse it - if (MP_OBJ_IS_TYPE(args[0], &mp_type_complex)) { + if (MP_OBJ_IS_STR(args[0])) { + // a string, parse it + uint l; + const char *s = mp_obj_str_get_data(args[0], &l); + return mp_parse_num_decimal(s, l, true, true); + } else if (MP_OBJ_IS_TYPE(args[0], &mp_type_complex)) { + // a complex, just return it return args[0]; } else { + // something else, try to cast it to a complex return mp_obj_new_complex(mp_obj_get_float(args[0]), 0); } - case 2: - { + case 2: { mp_float_t real, imag; if (MP_OBJ_IS_TYPE(args[0], &mp_type_complex)) { mp_obj_complex_get(args[0], &real, &imag); diff --git a/py/objfloat.c b/py/objfloat.c index 65dafa607e..c51e13e7a1 100644 --- a/py/objfloat.c +++ b/py/objfloat.c @@ -38,10 +38,12 @@ STATIC mp_obj_t float_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m // a string, parse it uint l; const char *s = mp_obj_str_get_data(args[0], &l); - return mp_parse_num_decimal(s, l, false); + return mp_parse_num_decimal(s, l, false, false); } else if (MP_OBJ_IS_TYPE(args[0], &mp_type_float)) { + // a float, just return it return args[0]; } else { + // something else, try to cast it to a float return mp_obj_new_float(mp_obj_get_float(args[0])); } diff --git a/py/parsenum.c b/py/parsenum.c index b1a70c352d..77f00957c6 100644 --- a/py/parsenum.c +++ b/py/parsenum.c @@ -88,7 +88,7 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) { #define PARSE_DEC_IN_FRAC (2) #define PARSE_DEC_IN_EXP (3) -mp_obj_t mp_parse_num_decimal(const char *str, uint len, bool allow_imag) { +mp_obj_t mp_parse_num_decimal(const char *str, uint len, bool allow_imag, bool force_complex) { #if MICROPY_ENABLE_FLOAT const char *top = str + len; mp_float_t dec_val = 0; @@ -129,7 +129,7 @@ mp_obj_t mp_parse_num_decimal(const char *str, uint len, bool allow_imag) { dec_val = MICROPY_FLOAT_C_FUN(nan)(""); } } else { - // parse the digits + // string should be a decimal number int in = PARSE_DEC_IN_INTG; bool exp_neg = false; int exp_val = 0; @@ -198,6 +198,8 @@ mp_obj_t mp_parse_num_decimal(const char *str, uint len, bool allow_imag) { // return the object if (imag) { return mp_obj_new_complex(0, dec_val); + } else if (force_complex) { + return mp_obj_new_complex(dec_val, 0); } else { return mp_obj_new_float(dec_val); } diff --git a/py/parsenum.h b/py/parsenum.h index f87fefbe77..97578423c7 100644 --- a/py/parsenum.h +++ b/py/parsenum.h @@ -1,2 +1,2 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base); -mp_obj_t mp_parse_num_decimal(const char *str, uint len, bool allow_imag); +mp_obj_t mp_parse_num_decimal(const char *str, uint len, bool allow_imag, bool force_complex); diff --git a/py/runtime.c b/py/runtime.c index 5604e1a945..c268fd5464 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -375,7 +375,7 @@ mp_obj_t rt_load_const_dec(qstr qstr) { DEBUG_OP_printf("load '%s'\n", qstr_str(qstr)); uint len; const byte* data = qstr_data(qstr, &len); - return mp_parse_num_decimal((const char*)data, len, true); + return mp_parse_num_decimal((const char*)data, len, true, false); } mp_obj_t rt_load_const_str(qstr qstr) {