diff --git a/py/bc0.h b/py/bc0.h index e6dc84277a..773e23d2ea 100644 --- a/py/bc0.h +++ b/py/bc0.h @@ -13,12 +13,11 @@ #define MP_BC_LOAD_FAST_2 (0x22) #define MP_BC_LOAD_FAST_N (0x23) // uint #define MP_BC_LOAD_DEREF (0x24) // uint -#define MP_BC_LOAD_CLOSURE (0x25) // uint -#define MP_BC_LOAD_NAME (0x26) // qstr -#define MP_BC_LOAD_GLOBAL (0x27) // qstr -#define MP_BC_LOAD_ATTR (0x28) // qstr -#define MP_BC_LOAD_METHOD (0x29) // qstr -#define MP_BC_LOAD_BUILD_CLASS (0x2a) +#define MP_BC_LOAD_NAME (0x25) // qstr +#define MP_BC_LOAD_GLOBAL (0x26) // qstr +#define MP_BC_LOAD_ATTR (0x27) // qstr +#define MP_BC_LOAD_METHOD (0x28) // qstr +#define MP_BC_LOAD_BUILD_CLASS (0x29) #define MP_BC_STORE_FAST_0 (0x30) #define MP_BC_STORE_FAST_1 (0x31) diff --git a/py/compile.c b/py/compile.c index 604d5a2437..fe22a90174 100644 --- a/py/compile.c +++ b/py/compile.c @@ -767,7 +767,12 @@ void close_over_variables_etc(compiler_t *comp, scope_t *this_scope, int n_dict_ for (int j = 0; j < this_scope->id_info_len; j++) { id_info_t *id2 = &this_scope->id_info[j]; if (id2->kind == ID_INFO_KIND_FREE && id->qstr == id2->qstr) { +#if MICROPY_EMIT_CPYTHON EMIT(load_closure, id->qstr, id->local_num); +#else + // in Micro Python we load closures using LOAD_FAST + EMIT(load_fast, id->qstr, id->local_num); +#endif nfree += 1; } } @@ -2806,7 +2811,11 @@ void compile_scope(compiler_t *comp, scope_t *scope, pass_kind_t pass) { if (id->kind == ID_INFO_KIND_LOCAL) { EMIT(load_const_tok, MP_TOKEN_KW_NONE); } else { +#if MICROPY_EMIT_CPYTHON EMIT(load_closure, comp->qstr___class__, 0); // XXX check this is the correct local num +#else + EMIT(load_fast, comp->qstr___class__, 0); // XXX check this is the correct local num +#endif } EMIT(return_value); } @@ -2894,7 +2903,7 @@ void compile_scope_inline_asm(compiler_t *comp, scope_t *scope, pass_kind_t pass void compile_scope_compute_things(compiler_t *comp, scope_t *scope) { // in functions, turn implicit globals into explicit globals - // compute num_locals, and the index of each local + // compute the index of each local scope->num_locals = 0; for (int i = 0; i < scope->id_info_len; i++) { id_info_t *id = &scope->id_info[i]; @@ -2913,19 +2922,27 @@ void compile_scope_compute_things(compiler_t *comp, scope_t *scope) { } // compute the index of cell vars (freevars[idx] in CPython) - int num_closed = 0; +#if MICROPY_EMIT_CPYTHON + int num_cell = 0; +#endif for (int i = 0; i < scope->id_info_len; i++) { id_info_t *id = &scope->id_info[i]; +#if MICROPY_EMIT_CPYTHON + // in CPython the cells are numbered starting from 0 if (id->kind == ID_INFO_KIND_CELL) { - id->local_num = num_closed; -#if !MICROPY_EMIT_CPYTHON - // the cells come right after the fast locals (CPython doesn't add this offset) - id->local_num += scope->num_locals; -#endif - num_closed += 1; + id->local_num = num_cell; + num_cell += 1; } +#else + // in Micro Python the cells come right after the fast locals + // parameters are not counted here, since they remain at the start + // of the locals, even if they are cell vars + if (!id->param && id->kind == ID_INFO_KIND_CELL) { + id->local_num = scope->num_locals; + scope->num_locals += 1; + } +#endif } - scope->num_cells = num_closed; // compute the index of free vars (freevars[idx] in CPython) // make sure they are in the order of the parent scope @@ -2937,16 +2954,32 @@ void compile_scope_compute_things(compiler_t *comp, scope_t *scope) { for (int j = 0; j < scope->id_info_len; j++) { id_info_t *id2 = &scope->id_info[j]; if (id2->kind == ID_INFO_KIND_FREE && id->qstr == id2->qstr) { - id2->local_num = num_closed + num_free; -#if !MICROPY_EMIT_CPYTHON - // the frees come right after the cells (CPython doesn't add this offset) - id2->local_num += scope->num_locals; + assert(!id2->param); // free vars should not be params +#if MICROPY_EMIT_CPYTHON + // in CPython the frees are numbered after the cells + id2->local_num = num_cell + num_free; +#else + // in Micro Python the frees come first, before the params + id2->local_num = num_free; #endif num_free += 1; } } } } +#if !MICROPY_EMIT_CPYTHON + // in Micro Python shift all other locals after the free locals + if (num_free > 0) { + for (int i = 0; i < scope->id_info_len; i++) { + id_info_t *id = &scope->id_info[i]; + if (id->param || id->kind != ID_INFO_KIND_FREE) { + id->local_num += num_free; + } + } + scope->num_params += num_free; // free vars are counted as params for passing them into the function + scope->num_locals += num_free; + } +#endif } // compute flags diff --git a/py/emitbc.c b/py/emitbc.c index a10a3b96eb..790fe3e4e5 100644 --- a/py/emitbc.c +++ b/py/emitbc.c @@ -49,36 +49,6 @@ void* emit_bc_get_code(emit_t* emit) { return emit->code_base; } -static void emit_bc_set_native_types(emit_t *emit, bool do_native_types) { -} - -static void emit_bc_start_pass(emit_t *emit, pass_kind_t pass, scope_t *scope) { - emit->pass = pass; - emit->stack_size = 0; - emit->last_emit_was_return_value = false; - emit->scope = scope; - if (pass == PASS_2) { - memset(emit->label_offsets, -1, emit->max_num_labels * sizeof(uint)); - } - emit->code_offset = 0; -} - -static void emit_bc_end_pass(emit_t *emit) { - // check stack is back to zero size - if (emit->stack_size != 0) { - printf("ERROR: stack size not back to zero; got %d\n", emit->stack_size); - } - - if (emit->pass == PASS_2) { - // calculate size of code in bytes - emit->code_size = emit->code_offset; - emit->code_base = m_new(byte, emit->code_size); - - } else if (emit->pass == PASS_3) { - rt_assign_byte_code(emit->scope->unique_code_id, emit->code_base, emit->code_size, emit->scope->num_params, emit->scope->num_locals, emit->scope->num_cells, emit->scope->stack_size, (emit->scope->flags & SCOPE_FLAG_GENERATOR) != 0); - } -} - // all functions must go through this one to emit bytes static byte* emit_get_cur_to_write_bytes(emit_t* emit, int num_bytes_to_write) { //printf("emit %d\n", num_bytes_to_write); @@ -166,6 +136,53 @@ static void emit_write_byte_1_signed_label(emit_t* emit, byte b1, int label) { c[2] = code_offset >> 8; } +static void emit_bc_set_native_types(emit_t *emit, bool do_native_types) { +} + +static void emit_bc_start_pass(emit_t *emit, pass_kind_t pass, scope_t *scope) { + emit->pass = pass; + emit->stack_size = 0; + emit->last_emit_was_return_value = false; + emit->scope = scope; + if (pass == PASS_2) { + memset(emit->label_offsets, -1, emit->max_num_labels * sizeof(uint)); + } + emit->code_offset = 0; + + // prelude for initialising closed over variables + int num_cell = 0; + for (int i = 0; i < scope->id_info_len; i++) { + id_info_t *id = &scope->id_info[i]; + if (id->kind == ID_INFO_KIND_CELL) { + num_cell += 1; + } + } + assert(num_cell <= 255); + emit_write_byte_1(emit, num_cell); // write number of locals that are cells + for (int i = 0; i < scope->id_info_len; i++) { + id_info_t *id = &scope->id_info[i]; + if (id->kind == ID_INFO_KIND_CELL) { + emit_write_byte_1(emit, id->local_num); // write the local which should be converted to a cell + } + } +} + +static void emit_bc_end_pass(emit_t *emit) { + // check stack is back to zero size + if (emit->stack_size != 0) { + printf("ERROR: stack size not back to zero; got %d\n", emit->stack_size); + } + + if (emit->pass == PASS_2) { + // calculate size of code in bytes + emit->code_size = emit->code_offset; + emit->code_base = m_new(byte, emit->code_size); + + } else if (emit->pass == PASS_3) { + rt_assign_byte_code(emit->scope->unique_code_id, emit->code_base, emit->code_size, emit->scope->num_params, emit->scope->num_locals, emit->scope->stack_size, (emit->scope->flags & SCOPE_FLAG_GENERATOR) != 0); + } +} + bool emit_bc_last_emit_was_return_value(emit_t *emit) { return emit->last_emit_was_return_value; } @@ -288,8 +305,8 @@ static void emit_bc_load_deref(emit_t *emit, qstr qstr, int local_num) { } static void emit_bc_load_closure(emit_t *emit, qstr qstr, int local_num) { - emit_pre(emit, 1); - emit_write_byte_1_uint(emit, MP_BC_LOAD_CLOSURE, local_num); + // not needed/supported for BC + assert(0); } static void emit_bc_load_name(emit_t *emit, qstr qstr) { diff --git a/py/obj.h b/py/obj.h index 1a7b91aaa4..7f2f2da202 100644 --- a/py/obj.h +++ b/py/obj.h @@ -120,6 +120,7 @@ typedef struct _mp_map_t mp_map_t; mp_obj_t mp_obj_new_none(void); mp_obj_t mp_obj_new_bool(bool value); +mp_obj_t mp_obj_new_cell(mp_obj_t obj); mp_obj_t mp_obj_new_int(machine_int_t value); mp_obj_t mp_obj_new_str(qstr qstr); #if MICROPY_ENABLE_FLOAT @@ -134,7 +135,7 @@ mp_obj_t mp_obj_new_range(int start, int stop, int step); mp_obj_t mp_obj_new_range_iterator(int cur, int stop, int step); mp_obj_t mp_obj_new_fun_bc(int n_args, uint n_state, const byte *code); mp_obj_t mp_obj_new_fun_asm(uint n_args, void *fun); -mp_obj_t mp_obj_new_gen_wrap(uint n_locals, uint n_cells, uint n_stack, mp_obj_t fun); +mp_obj_t mp_obj_new_gen_wrap(uint n_locals, uint n_stack, mp_obj_t fun); mp_obj_t mp_obj_new_gen_instance(mp_obj_t state, const byte *ip, mp_obj_t *sp); mp_obj_t mp_obj_new_closure(mp_obj_t fun, mp_obj_t closure_tuple); mp_obj_t mp_obj_new_tuple(uint n, mp_obj_t *items); diff --git a/py/objclosure.c b/py/objclosure.c index e699c5daaa..e3354d42d9 100644 --- a/py/objclosure.c +++ b/py/objclosure.c @@ -1,5 +1,6 @@ #include #include +#include #include #include "nlr.h" @@ -11,14 +12,31 @@ typedef struct _mp_obj_closure_t { mp_obj_base_t base; mp_obj_t fun; - mp_obj_t vars; + uint n_closed; + mp_obj_t *closed; } mp_obj_closure_t; +// args are in reverse order in the array +mp_obj_t closure_call_n(mp_obj_t self_in, int n_args, const mp_obj_t *args) { + mp_obj_closure_t *self = self_in; + + // concatenate args and closed-over-vars, in reverse order + // TODO perhaps cache this array so we don't need to create it each time we are called + mp_obj_t *args2 = m_new(mp_obj_t, self->n_closed + n_args); + memcpy(args2, args, n_args * sizeof(mp_obj_t)); + for (int i = 0; i < self->n_closed; i++) { + args2[n_args + i] = self->closed[self->n_closed - 1 - i]; + } + + // call the function with the new vars array + return rt_call_function_n(self->fun, n_args + self->n_closed, args2); +} + const mp_obj_type_t closure_type = { { &mp_const_type }, "closure", NULL, // print - NULL, // call_n + closure_call_n, // call_n NULL, // unary_op NULL, // binary_op NULL, // getiter @@ -30,6 +48,6 @@ mp_obj_t mp_obj_new_closure(mp_obj_t fun, mp_obj_t closure_tuple) { mp_obj_closure_t *o = m_new_obj(mp_obj_closure_t); o->base.type = &closure_type; o->fun = fun; - o->vars = closure_tuple; + mp_obj_tuple_get(closure_tuple, &o->n_closed, &o->closed); return o; } diff --git a/py/objgenerator.c b/py/objgenerator.c index 2a43b95348..1247662f65 100644 --- a/py/objgenerator.c +++ b/py/objgenerator.c @@ -38,6 +38,14 @@ mp_obj_t gen_wrap_call_n(mp_obj_t self_in, int n_args, const mp_obj_t *args) { for (int i = 0; i < n_args; i++) { state[1 + i] = args[n_args - 1 - i]; } + + // TODO + // prelude for making cells (closed over variables) + // for now we just make sure there are no cells variables + // need to work out how to implement closed over variables in generators + assert(bc_code[0] == 0); + bc_code += 1; + return mp_obj_new_gen_instance(state, bc_code, state + self->n_state); } @@ -53,11 +61,11 @@ const mp_obj_type_t gen_wrap_type = { {{NULL, NULL},}, // method list }; -mp_obj_t mp_obj_new_gen_wrap(uint n_locals, uint n_cells, uint n_stack, mp_obj_t fun) { +mp_obj_t mp_obj_new_gen_wrap(uint n_locals, uint n_stack, mp_obj_t fun) { mp_obj_gen_wrap_t *o = m_new_obj(mp_obj_gen_wrap_t); o->base.type = &gen_wrap_type; // we have at least 3 locals so the bc can write back fast[0,1,2] safely; should improve how this is done - o->n_state = ((n_locals + n_cells) < 3 ? 3 : (n_locals + n_cells)) + n_stack; + o->n_state = (n_locals < 3 ? 3 : n_locals) + n_stack; o->fun = fun; return o; } diff --git a/py/runtime.c b/py/runtime.c index 748294c350..c3e1f5d9cf 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -59,7 +59,6 @@ typedef struct _mp_code_t { mp_code_kind_t kind; int n_args; int n_locals; - int n_cells; int n_stack; bool is_generator; union { @@ -178,14 +177,13 @@ static void alloc_unique_codes(void) { } } -void rt_assign_byte_code(int unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_cells, int n_stack, bool is_generator) { +void rt_assign_byte_code(int unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_stack, bool is_generator) { alloc_unique_codes(); assert(unique_code_id < next_unique_code_id); unique_codes[unique_code_id].kind = MP_CODE_BYTE; unique_codes[unique_code_id].n_args = n_args; unique_codes[unique_code_id].n_locals = n_locals; - unique_codes[unique_code_id].n_cells = n_cells; unique_codes[unique_code_id].n_stack = n_stack; unique_codes[unique_code_id].is_generator = is_generator; unique_codes[unique_code_id].u_byte.code = code; @@ -221,7 +219,6 @@ void rt_assign_native_code(int unique_code_id, void *fun, uint len, int n_args) unique_codes[unique_code_id].kind = MP_CODE_NATIVE; unique_codes[unique_code_id].n_args = n_args; unique_codes[unique_code_id].n_locals = 0; - unique_codes[unique_code_id].n_cells = 0; unique_codes[unique_code_id].n_stack = 0; unique_codes[unique_code_id].is_generator = false; unique_codes[unique_code_id].u_native.fun = fun; @@ -255,7 +252,6 @@ void rt_assign_inline_asm_code(int unique_code_id, void *fun, uint len, int n_ar unique_codes[unique_code_id].kind = MP_CODE_INLINE_ASM; unique_codes[unique_code_id].n_args = n_args; unique_codes[unique_code_id].n_locals = 0; - unique_codes[unique_code_id].n_cells = 0; unique_codes[unique_code_id].n_stack = 0; unique_codes[unique_code_id].is_generator = false; unique_codes[unique_code_id].u_inline_asm.fun = fun; @@ -632,7 +628,7 @@ mp_obj_t rt_make_function_from_id(int unique_code_id) { mp_obj_t fun; switch (c->kind) { case MP_CODE_BYTE: - fun = mp_obj_new_fun_bc(c->n_args, c->n_locals + c->n_cells + c->n_stack, c->u_byte.code); + fun = mp_obj_new_fun_bc(c->n_args, c->n_locals + c->n_stack, c->u_byte.code); break; case MP_CODE_NATIVE: switch (c->n_args) { @@ -652,13 +648,14 @@ mp_obj_t rt_make_function_from_id(int unique_code_id) { // check for generator functions and if so wrap in generator object if (c->is_generator) { - fun = mp_obj_new_gen_wrap(c->n_locals, c->n_cells, c->n_stack, fun); + fun = mp_obj_new_gen_wrap(c->n_locals, c->n_stack, fun); } return fun; } mp_obj_t rt_make_closure_from_id(int unique_code_id, mp_obj_t closure_tuple) { + DEBUG_OP_printf("make_closure_from_id %d\n", unique_code_id); // make function object mp_obj_t ffun = rt_make_function_from_id(unique_code_id); // wrap function in closure object diff --git a/py/runtime0.h b/py/runtime0.h index f68b7b961b..8ec2c058f0 100644 --- a/py/runtime0.h +++ b/py/runtime0.h @@ -82,6 +82,6 @@ extern void *const rt_fun_table[RT_F_NUMBER_OF]; void rt_init(void); void rt_deinit(void); int rt_get_unique_code_id(bool is_main_module); -void rt_assign_byte_code(int unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_cells, int n_stack, bool is_generator); +void rt_assign_byte_code(int unique_code_id, byte *code, uint len, int n_args, int n_locals, int n_stack, bool is_generator); void rt_assign_native_code(int unique_code_id, void *f, uint len, int n_args); void rt_assign_inline_asm_code(int unique_code_id, void *f, uint len, int n_args); diff --git a/py/scope.c b/py/scope.c index 38ea5a9e2f..5d97393ae3 100644 --- a/py/scope.c +++ b/py/scope.c @@ -52,7 +52,6 @@ scope_t *scope_new(scope_kind_t kind, mp_parse_node_t pn, uint unique_code_id, u scope->num_dict_params = 0; */ scope->num_locals = 0; - scope->num_cells = 0; scope->unique_code_id = unique_code_id; scope->emit_options = emit_options; diff --git a/py/scope.h b/py/scope.h index 1231b3cc5e..761a4d7119 100644 --- a/py/scope.h +++ b/py/scope.h @@ -49,7 +49,6 @@ typedef struct _scope_t { int num_dict_params; */ int num_locals; - int num_cells; int stack_size; uint unique_code_id; uint emit_options; diff --git a/py/showbc.c b/py/showbc.c index b063c846ac..d5ea704313 100644 --- a/py/showbc.c +++ b/py/showbc.c @@ -15,6 +15,19 @@ void mp_show_byte_code(const byte *ip, int len) { const byte *ip_start = ip; + + // decode prelude + { + uint n_local = *ip++; + printf("(NUM_LOCAL %u)\n", n_local); + for (; n_local > 0; n_local--) { + uint local_num = *ip++; + printf("(INIT_CELL %u)\n", local_num); + } + len -= ip - ip_start; + ip_start = ip; + } + machine_uint_t unum; qstr qstr; while (ip - ip_start < len) { @@ -73,6 +86,11 @@ void mp_show_byte_code(const byte *ip, int len) { printf("LOAD_FAST_N " UINT_FMT, unum); break; + case MP_BC_LOAD_DEREF: + DECODE_UINT; + printf("LOAD_DEREF " UINT_FMT, unum); + break; + case MP_BC_LOAD_NAME: DECODE_QSTR; printf("LOAD_NAME %s", qstr_str(qstr)); @@ -114,6 +132,11 @@ void mp_show_byte_code(const byte *ip, int len) { printf("STORE_FAST_N " UINT_FMT, unum); break; + case MP_BC_STORE_DEREF: + DECODE_UINT; + printf("STORE_DEREF " UINT_FMT, unum); + break; + case MP_BC_STORE_NAME: DECODE_QSTR; printf("STORE_NAME %s", qstr_str(qstr)); @@ -301,6 +324,11 @@ void mp_show_byte_code(const byte *ip, int len) { printf("MAKE_FUNCTION " UINT_FMT, unum); break; + case MP_BC_MAKE_CLOSURE: + DECODE_UINT; + printf("MAKE_CLOSURE " UINT_FMT, unum); + break; + case MP_BC_CALL_FUNCTION: DECODE_UINT; printf("CALL_FUNCTION n=" UINT_FMT " nkw=" UINT_FMT, unum & 0xff, (unum >> 8) & 0xff); diff --git a/py/vm.c b/py/vm.c index a23b0e89fb..3c3f398c04 100644 --- a/py/vm.c +++ b/py/vm.c @@ -39,10 +39,25 @@ mp_obj_t mp_execute_byte_code(const byte *code, const mp_obj_t *args, uint n_arg state[i] = args[n_args - 1 - i]; } const byte *ip = code; + + // execute prelude to make any cells (closed over variables) + { + for (uint n_local = *ip++; n_local > 0; n_local--) { + uint local_num = *ip++; + if (local_num < n_args) { + state[local_num] = mp_obj_new_cell(state[local_num]); + } else { + state[local_num] = mp_obj_new_cell(MP_OBJ_NULL); + } + } + } + + // execute the byte code if (mp_execute_byte_code_2(&ip, &state[0], &sp)) { // it shouldn't yield assert(0); } + // TODO check fails if, eg, return from within for loop //assert(sp == &state[17]); return *sp; @@ -127,11 +142,6 @@ bool mp_execute_byte_code_2(const byte **ip_in_out, mp_obj_t *fastn, mp_obj_t ** PUSH(rt_get_cell(fastn[unum])); break; - case MP_BC_LOAD_CLOSURE: - DECODE_UINT; - PUSH(fastn[unum]); - break; - case MP_BC_LOAD_NAME: DECODE_QSTR; PUSH(rt_load_name(qstr)); diff --git a/tests/basics/tests/closure1.py b/tests/basics/tests/closure1.py new file mode 100644 index 0000000000..610cb70020 --- /dev/null +++ b/tests/basics/tests/closure1.py @@ -0,0 +1,16 @@ +# closures + +def f(x): + y = 2 * x + def g(z): + return y + z + return g + +print(f(1)(1)) + +x = f(2) +y = f(3) +print(x(1), x(2), x(3)) +print(y(1), y(2), y(3)) +print(x(1), x(2), x(3)) +print(y(1), y(2), y(3)) diff --git a/tests/basics/tests/closure2.py b/tests/basics/tests/closure2.py new file mode 100644 index 0000000000..e4e5154a94 --- /dev/null +++ b/tests/basics/tests/closure2.py @@ -0,0 +1,16 @@ +# closures; closing over an argument + +def f(x): + y = 2 * x + def g(z): + return x + y + z + return g + +print(f(1)(1)) + +x = f(2) +y = f(3) +print(x(1), x(2), x(3)) +print(y(1), y(2), y(3)) +print(x(1), x(2), x(3)) +print(y(1), y(2), y(3))