From a1afb337d2629a781cf4e171b7db7f05eeacc78f Mon Sep 17 00:00:00 2001 From: Damien George Date: Wed, 1 Jun 2022 14:52:38 +1000 Subject: [PATCH] extmod/uasyncio: Fix edge case for cancellation of wait_for. This fixes the cases where the task being waited on finishes just before or just after the wait_for itself is cancelled. Fixes issue #8717. Signed-off-by: Damien George --- extmod/uasyncio/funcs.py | 46 ++++++++++++++------------- tests/extmod/uasyncio_wait_for.py | 15 +++++++++ tests/extmod/uasyncio_wait_for.py.exp | 27 ++++++++++++++++ 3 files changed, 66 insertions(+), 22 deletions(-) diff --git a/extmod/uasyncio/funcs.py b/extmod/uasyncio/funcs.py index 258948f73e..a1d38fbcbf 100644 --- a/extmod/uasyncio/funcs.py +++ b/extmod/uasyncio/funcs.py @@ -1,49 +1,51 @@ # MicroPython uasyncio module -# MIT license; Copyright (c) 2019-2020 Damien P. George +# MIT license; Copyright (c) 2019-2022 Damien P. George from . import core +def _run(waiter, aw): + try: + result = await aw + status = True + except BaseException as er: + result = None + status = er + if waiter.data is None: + # The waiter is still waiting, cancel it. + if waiter.cancel(): + # Waiter was cancelled by us, change its CancelledError to an instance of + # CancelledError that contains the status and result of waiting on aw. + # If the wait_for task subsequently gets cancelled externally then this + # instance will be reset to a CancelledError instance without arguments. + waiter.data = core.CancelledError(status, result) + + async def wait_for(aw, timeout, sleep=core.sleep): aw = core._promote_to_task(aw) if timeout is None: return await aw - def runner(waiter, aw): - nonlocal status, result - try: - result = await aw - s = True - except BaseException as er: - s = er - if status is None: - # The waiter is still waiting, set status for it and cancel it. - status = s - waiter.cancel() - # Run aw in a separate runner task that manages its exceptions. - status = None - result = None - runner_task = core.create_task(runner(core.cur_task, aw)) + runner_task = core.create_task(_run(core.cur_task, aw)) try: # Wait for the timeout to elapse. await sleep(timeout) except core.CancelledError as er: - if status is True: - # aw completed successfully and cancelled the sleep, so return aw's result. - return result - elif status is None: + status = er.value + if status is None: # This wait_for was cancelled externally, so cancel aw and re-raise. - status = True runner_task.cancel() raise er + elif status is True: + # aw completed successfully and cancelled the sleep, so return aw's result. + return er.args[1] else: # aw raised an exception, propagate it out to the caller. raise status # The sleep finished before aw, so cancel aw and raise TimeoutError. - status = True runner_task.cancel() await runner_task raise core.TimeoutError diff --git a/tests/extmod/uasyncio_wait_for.py b/tests/extmod/uasyncio_wait_for.py index 9612d16204..c636c7dd74 100644 --- a/tests/extmod/uasyncio_wait_for.py +++ b/tests/extmod/uasyncio_wait_for.py @@ -111,6 +111,21 @@ async def main(): await asyncio.sleep(0.01) print(sep) + # When wait_for gets cancelled and the task it's waiting on finishes around the + # same time as the cancellation of the wait_for + for num_sleep in range(1, 5): + t = asyncio.create_task(task_wait_for_cancel(4 + num_sleep, 0, 2)) + for _ in range(num_sleep): + await asyncio.sleep(0) + assert not t.done() + print("cancel wait_for") + t.cancel() + try: + await t + except asyncio.CancelledError as er: + print(repr(er)) + print(sep) + print("finish") diff --git a/tests/extmod/uasyncio_wait_for.py.exp b/tests/extmod/uasyncio_wait_for.py.exp index a4201d31ff..1bbe3d0658 100644 --- a/tests/extmod/uasyncio_wait_for.py.exp +++ b/tests/extmod/uasyncio_wait_for.py.exp @@ -32,4 +32,31 @@ task_wait_for_cancel_ignore cancelled ignore cancel task_catch done ---------- +task_wait_for_cancel start +cancel wait_for +task start 5 +task_wait_for_cancel cancelled +CancelledError() +---------- +task_wait_for_cancel start +task start 6 +cancel wait_for +task end 6 +task_wait_for_cancel cancelled +CancelledError() +---------- +task_wait_for_cancel start +task start 7 +task end 7 +cancel wait_for +task_wait_for_cancel cancelled +CancelledError() +---------- +task_wait_for_cancel start +task start 8 +task end 8 +cancel wait_for +task_wait_for_cancel cancelled +CancelledError() +---------- finish