py/compile: Fix async for's stack handling of iterator expression.

Prior to this fix, async for assumed the iterator expression was a simple
identifier, and used that identifier as a local to store the intermediate
iterator object.  This is incorrect behaviour.

This commit fixes the issue by keeping the iterator object on the stack as
an anonymous local variable.

Fixes issue #11511.

Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
Damien George 2023-05-19 17:00:53 +10:00
parent 14c2b64131
commit 606ec9bfb1
3 changed files with 120 additions and 18 deletions

View File

@ -1768,18 +1768,21 @@ STATIC void compile_await_object_method(compiler_t *comp, qstr method) {
} }
STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) {
// comp->break_label |= MP_EMIT_BREAK_FROM_FOR; // Allocate labels.
qstr context = MP_PARSE_NODE_LEAF_ARG(pns->nodes[1]);
uint while_else_label = comp_next_label(comp); uint while_else_label = comp_next_label(comp);
uint try_exception_label = comp_next_label(comp); uint try_exception_label = comp_next_label(comp);
uint try_else_label = comp_next_label(comp); uint try_else_label = comp_next_label(comp);
uint try_finally_label = comp_next_label(comp); uint try_finally_label = comp_next_label(comp);
// Stack: (...)
// Compile the iterator expression and load and call its __aiter__ method.
compile_node(comp, pns->nodes[1]); // iterator compile_node(comp, pns->nodes[1]); // iterator
// Stack: (..., iterator)
EMIT_ARG(load_method, MP_QSTR___aiter__, false); EMIT_ARG(load_method, MP_QSTR___aiter__, false);
// Stack: (..., iterator, __aiter__)
EMIT_ARG(call_method, 0, 0, 0); EMIT_ARG(call_method, 0, 0, 0);
compile_store_id(comp, context); // Stack: (..., iterable)
START_BREAK_CONTINUE_BLOCK START_BREAK_CONTINUE_BLOCK
@ -1787,9 +1790,15 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns
compile_increase_except_level(comp, try_exception_label, MP_EMIT_SETUP_BLOCK_EXCEPT); compile_increase_except_level(comp, try_exception_label, MP_EMIT_SETUP_BLOCK_EXCEPT);
compile_load_id(comp, context); EMIT(dup_top);
// Stack: (..., iterable, iterable)
// Compile: yield from iterable.__anext__()
compile_await_object_method(comp, MP_QSTR___anext__); compile_await_object_method(comp, MP_QSTR___anext__);
// Stack: (..., iterable, yielded_value)
c_assign(comp, pns->nodes[0], ASSIGN_STORE); // variable c_assign(comp, pns->nodes[0], ASSIGN_STORE); // variable
// Stack: (..., iterable)
EMIT_ARG(pop_except_jump, try_else_label, false); EMIT_ARG(pop_except_jump, try_else_label, false);
EMIT_ARG(label_assign, try_exception_label); EMIT_ARG(label_assign, try_exception_label);
@ -1806,6 +1815,8 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns
compile_decrease_except_level(comp); compile_decrease_except_level(comp);
EMIT(end_except_handler); EMIT(end_except_handler);
// Stack: (..., iterable)
EMIT_ARG(label_assign, try_else_label); EMIT_ARG(label_assign, try_else_label);
compile_node(comp, pns->nodes[2]); // body compile_node(comp, pns->nodes[2]); // body
@ -1817,6 +1828,10 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns
compile_node(comp, pns->nodes[3]); // else compile_node(comp, pns->nodes[3]); // else
EMIT_ARG(label_assign, break_label); EMIT_ARG(label_assign, break_label);
// Stack: (..., iterable)
EMIT(pop_top);
// Stack: (...)
} }
STATIC void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_node_t *nodes, mp_parse_node_t body) { STATIC void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_node_t *nodes, mp_parse_node_t body) {

View File

@ -1,29 +1,75 @@
# test basic async for execution # test basic async for execution
# example taken from PEP0492 # example taken from PEP0492
class AsyncIteratorWrapper: class AsyncIteratorWrapper:
def __init__(self, obj): def __init__(self, obj):
print('init') print("init")
self._it = iter(obj) self._obj = obj
def __repr__(self):
return "AsyncIteratorWrapper-" + self._obj
def __aiter__(self): def __aiter__(self):
print('aiter') print("aiter")
return self return AsyncIteratorWrapperIterator(self._obj)
class AsyncIteratorWrapperIterator:
def __init__(self, obj):
print("init")
self._it = iter(obj)
async def __anext__(self): async def __anext__(self):
print('anext') print("anext")
try: try:
value = next(self._it) value = next(self._it)
except StopIteration: except StopIteration:
raise StopAsyncIteration raise StopAsyncIteration
return value return value
async def coro():
async for letter in AsyncIteratorWrapper('abc'): def run_coro(c):
print("== start ==")
try:
c.send(None)
except StopIteration:
print("== finish ==")
async def coro0():
async for letter in AsyncIteratorWrapper("abc"):
print(letter) print(letter)
o = coro()
try: run_coro(coro0())
o.send(None)
except StopIteration:
print('finished') async def coro1():
a = AsyncIteratorWrapper("def")
async for letter in a:
print(letter)
print(a)
run_coro(coro1())
a_global = AsyncIteratorWrapper("ghi")
async def coro2():
async for letter in a_global:
print(letter)
print(a_global)
run_coro(coro2())
async def coro3(a):
async for letter in a:
print(letter)
print(a)
run_coro(coro3(AsyncIteratorWrapper("jkl")))

View File

@ -1,5 +1,7 @@
== start ==
init init
aiter aiter
init
anext anext
a a
anext anext
@ -7,4 +9,43 @@ b
anext anext
c c
anext anext
finished == finish ==
== start ==
init
aiter
init
anext
d
anext
e
anext
f
anext
AsyncIteratorWrapper-def
== finish ==
init
== start ==
aiter
init
anext
g
anext
h
anext
i
anext
AsyncIteratorWrapper-ghi
== finish ==
init
== start ==
aiter
init
anext
j
anext
k
anext
l
anext
AsyncIteratorWrapper-jkl
== finish ==