diff --git a/electrum/network.py b/electrum/network.py index b860fc226..d871bc42b 100644 --- a/electrum/network.py +++ b/electrum/network.py @@ -755,7 +755,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]): proxy = self.proxy if proxy and proxy.enabled and proxy.mode == 'socks5': - # FIXME GC issues? do we need to store the Future? asyncio.run_coroutine_threadsafe(tor_probe_task(proxy), self.asyncio_loop) @log_exceptions diff --git a/electrum/util.py b/electrum/util.py index 66188fae6..f505a225a 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -1652,6 +1652,7 @@ def create_and_start_event_loop() -> Tuple[asyncio.AbstractEventLoop, _asyncio_event_loop = None loop.set_exception_handler(on_exception) + _set_custom_task_factory(loop) # loop.set_debug(True) stopping_fut = loop.create_future() loop_thread = threading.Thread( @@ -1670,6 +1671,42 @@ def create_and_start_event_loop() -> Tuple[asyncio.AbstractEventLoop, return loop, stopping_fut, loop_thread +_running_asyncio_tasks = set() # type: Set[asyncio.Future] +def _set_custom_task_factory(loop: asyncio.AbstractEventLoop): + """Wrap task creation to track pending and running tasks. + When tasks are created, asyncio only maintains a weak reference to them. + Hence, the garbage collector might destroy the task mid-execution. + To avoid this, we store a strong reference for the task (in the module-level "_running_asyncio_tasks" set) until it completes. + + Without this, a lot of APIs are basically Heisenbug-generators... e.g.: + - "asyncio.create_task" + - "loop.create_task" + - "asyncio.ensure_future" + - what about "asyncio.run_coroutine_threadsafe"? presumably covered too, as it calls ensure_future() on the target loop -- TODO confirm with a test. + + related: + - https://bugs.python.org/issue44665 + - https://github.com/python/cpython/issues/88831 + - https://github.com/python/cpython/issues/91887 + - https://textual.textualize.io/blog/2023/02/11/the-heisenbug-lurking-in-your-async-code/ + - https://github.com/python/cpython/issues/91887#issuecomment-1434816045 + - "Task was destroyed but it is pending!" 
+ """ + + platform_task_factory = loop.get_task_factory() # chain to a previously set factory, if any + + def factory(loop_, coro, **kwargs): + if platform_task_factory is not None: + task = platform_task_factory(loop_, coro, **kwargs) + else: + task = asyncio.Task(coro, loop=loop_, **kwargs) + _running_asyncio_tasks.add(task) + task.add_done_callback(_running_asyncio_tasks.discard) # drop the strong ref once the task is done + return task + + loop.set_task_factory(factory) + + class OrderedDictWithIndex(OrderedDict): """An OrderedDict that keeps track of the positions of keys. diff --git a/tests/test_util.py b/tests/test_util.py index b354e75e8..7f9018233 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime from decimal import Decimal @@ -472,3 +473,30 @@ class TestUtil(ElectrumTestCase): self.assertTrue(ShortID.from_components(3, 30, 300) > ShortID.from_components(3, 1, 999)) self.assertTrue(ShortID.from_components(3, 30, 300) < ShortID.from_components(3, 999, 1)) + async def test_custom_task_factory(self): + loop = util.get_running_loop() + # set our factory. note: this does not leak into other unit tests + util._set_custom_task_factory(loop) + + evt = asyncio.Event() + async def foo(): + await evt.wait() + + fut = asyncio.ensure_future(foo()) + self.assertTrue(fut in util._running_asyncio_tasks) + fut = asyncio.create_task(foo()) + self.assertTrue(fut in util._running_asyncio_tasks) + fut = loop.create_task(foo()) + self.assertTrue(fut in util._running_asyncio_tasks) + #fut = asyncio.run_coroutine_threadsafe(foobar(), loop=loop) + #self.assertTrue(fut in util._running_asyncio_tasks) # FIXME: would not hold as-is: run_coroutine_threadsafe returns a concurrent.futures.Future, not the Task + + # we should have stored one ref for each above. + # (though what if test framework is doing stuff ~concurrently?) + self.assertEqual(3, len(util._running_asyncio_tasks)) + evt.set() + for _ in range(10): # wait a few event loop iterations + await asyncio.sleep(0) + # refs should be cleaned up by now: + self.assertEqual(0, len(util._running_asyncio_tasks)) +