diff --git a/py/obj.h b/py/obj.h index 88d9253ec6..eff021e70c 100644 --- a/py/obj.h +++ b/py/obj.h @@ -422,6 +422,7 @@ mp_obj_t mp_obj_complex_binary_op(int op, mp_float_t lhs_real, mp_float_t lhs_im void mp_obj_tuple_get(mp_obj_t self_in, uint *len, mp_obj_t **items); void mp_obj_tuple_del(mp_obj_t self_in); machine_int_t mp_obj_tuple_hash(mp_obj_t self_in); +mp_obj_t mp_obj_tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp_obj_t *args); // list mp_obj_t mp_obj_list_append(mp_obj_t self_in, mp_obj_t arg); @@ -485,7 +486,7 @@ typedef struct _mp_obj_static_class_method_t { void mp_seq_multiply(const void *items, uint item_sz, uint len, uint times, void *dest); bool m_seq_get_fast_slice_indexes(machine_uint_t len, mp_obj_t slice, machine_uint_t *begin, machine_uint_t *end); #define m_seq_copy(dest, src, len, item_t) memcpy(dest, src, len * sizeof(item_t)) -#define m_seq_cat(dest, src1, len1, src2, len2, item_t) { memcpy(dest, src1, len1 * sizeof(item_t)); memcpy(dest + len1, src2, len2 * sizeof(item_t)); } +#define m_seq_cat(dest, src1, len1, src2, len2, item_t) { memcpy(dest, src1, (len1) * sizeof(item_t)); memcpy(dest + (len1), src2, (len2) * sizeof(item_t)); } bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2); bool mp_seq_cmp_objs(int op, const mp_obj_t *items1, uint len1, const mp_obj_t *items2, uint len2); mp_obj_t mp_seq_index_obj(const mp_obj_t *items, uint len, uint n_args, const mp_obj_t *args); diff --git a/py/objtuple.c b/py/objtuple.c index 6186640604..7f14509029 100644 --- a/py/objtuple.c +++ b/py/objtuple.c @@ -30,7 +30,7 @@ void tuple_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_o print(env, ")"); } -STATIC mp_obj_t tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp_obj_t *args) { +mp_obj_t mp_obj_tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp_obj_t *args) { // TODO check n_kw == 0 switch (n_args) { @@ -175,7 +175,7 @@ const mp_obj_type_t mp_type_tuple = { { &mp_type_type }, .name = MP_QSTR_tuple, .print = tuple_print, - .make_new = tuple_make_new, + .make_new = mp_obj_tuple_make_new, .unary_op = tuple_unary_op, .binary_op = tuple_binary_op, .getiter = tuple_getiter, diff --git a/py/vm.c b/py/vm.c index 8e16c12ffe..d7e7227ac4 100644 --- a/py/vm.c +++ b/py/vm.c @@ -1,4 +1,5 @@ #include +#include #include #include "nlr.h" @@ -10,6 +11,7 @@ #include "bc0.h" #include "bc.h" #include "objgenerator.h" +#include "objtuple.h" // Value stack grows up (this makes it incompatible with native C stack, but // makes sure that arguments to functions are in natural order arg1..argN @@ -670,6 +672,32 @@ unwind_jump: SET_TOP(mp_call_function_n_kw(*sp, unum & 0xff, (unum >> 8) & 0xff, sp + 1)); break; + case MP_BC_CALL_FUNCTION_VAR: { + DECODE_UINT; + // unum & 0xff == n_positional + // (unum >> 8) & 0xff == n_keyword + // We have folowing stack layout here: + // arg0 arg1 ... kw0 val0 kw1 val1 ... seq <- TOS + // We need to splice seq after all positional args and before kwargs + // TODO: optimize one day to avoid constructing new arg array? Will be hard. + mp_obj_t seq = POP(); + int total_stack_args = (unum & 0xff) + ((unum >> 7) & 0x1fe); + sp -= total_stack_args; + + // Convert vararg sequence to tuple. Note that it can be arbitrary iterator. + // This is null call for tuple, and TODO: we actually could optimize case of list. + mp_obj_tuple_t *varargs = mp_obj_tuple_make_new(MP_OBJ_NULL, 1, 0, &seq); + + int pos_args_len = (unum & 0xff) + varargs->len; + mp_obj_t *args = m_new(mp_obj_t, total_stack_args + varargs->len); + m_seq_cat(args, sp + 1, unum & 0xff, varargs->items, varargs->len, mp_obj_t); + m_seq_copy(args + pos_args_len, sp + (unum & 0xff) + 1, ((unum >> 7) & 0x1fe), mp_obj_t); + + SET_TOP(mp_call_function_n_kw(*sp, pos_args_len, (unum >> 8) & 0xff, args)); + m_del(mp_obj_t, args, total_stack_args + varargs->len); + break; + } + case MP_BC_CALL_METHOD: DECODE_UINT; // unum & 0xff == n_positional diff --git a/tests/basics/fun-callstar.py b/tests/basics/fun-callstar.py new file mode 100644 index 0000000000..49b40d9594 --- /dev/null +++ b/tests/basics/fun-callstar.py @@ -0,0 +1,13 @@ +def foo(a, b, c): + print(a, b, c) + +foo(*(1, 2, 3)) +foo(1, *(2, 3)) +foo(1, 2, *(3,)) +foo(1, 2, 3, *()) + +# Another sequence type +foo(1, 2, *[100]) + +# Iterator +foo(*range(3))