diff --git a/electrum/util.py b/electrum/util.py index 45608bdc1..72ff594cc 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -57,6 +57,7 @@ import enum from contextlib import nullcontext, suppress import traceback import inspect +import weakref import aiohttp from aiohttp_socks import ProxyConnector, ProxyType @@ -1954,22 +1955,38 @@ class CallbackManager(Logger): def __init__(self): Logger.__init__(self) self.callback_lock = threading.Lock() - self.callbacks = defaultdict(set) # type: Dict[str, Set[Callable]] # note: needs self.callback_lock + self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock - def register_callback(self, func: Callable, events: Sequence[str]) -> None: + @staticmethod + def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]: + assert callable(cb), type(cb) + if isinstance(cb, weakref.ref): # no-op + return cb + elif inspect.ismethod(cb): # instance method, such as for a subclass of EventListener + return WeakMethodProper(cb) + else: # proper function? e.g. used by lnpeer unit tests + return weakref.ref(cb) + + def register_callback(self, cb: Callable, events: Sequence[str]) -> None: + wcb = self._wcb_from_any_callback(cb) with self.callback_lock: for event in events: - self.callbacks[event].add(func) + self._wcallbacks[event].add(wcb) - def unregister_callback(self, callback: Callable) -> None: + def unregister_callback(self, cb: Callable) -> None: + wcb = self._wcb_from_any_callback(cb) with self.callback_lock: - for callbacks in self.callbacks.values(): - if callback in callbacks: - callbacks.remove(callback) + for callbacks in self._wcallbacks.values(): + if wcb in callbacks: + callbacks.remove(wcb) + + def count_all_callbacks(self) -> int: + with self.callback_lock: + return sum(len(cbs) for cbs in self._wcallbacks.values()) def clear_all_callbacks(self) -> None: with self.callback_lock: - self.callbacks.clear() + self._wcallbacks.clear() def trigger_callback(self, event: str, *args) -> None: """Trigger a callback with given arguments. @@ -1979,8 +1996,11 @@ class CallbackManager(Logger): loop = get_asyncio_loop() assert loop.is_running(), "event loop not running" with self.callback_lock: - callbacks = copy.copy(self.callbacks[event]) - for callback in callbacks: + wcallbacks = copy.copy(self._wcallbacks[event]) + for wcb in wcallbacks: + callback = wcb() + if callback is None: + continue if inspect.iscoroutinefunction(callback): # async cb fut = asyncio.run_coroutine_threadsafe(callback(*args), loop) @@ -2002,14 +2022,28 @@ unregister_callback = callback_mgr.unregister_callback _event_listeners = defaultdict(set) # type: Dict[str, Set[str]] +class WeakMethodProper(weakref.WeakMethod): + """Unlike weakref.WeakMethod, this class has an __eq__ I can trust.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + meth = self() + self._my_id = (id(meth.__self__), id(meth.__func__)) + + def __hash__(self): + return hash(self._my_id) + + def __eq__(self, other): + if not isinstance(other, WeakMethodProper): + return False + return self._my_id == other._my_id + + class EventListener: """Use as a mixin for a class that has methods to be triggered on events. - Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener. - register_callbacks() should be called once per instance of EventListener, e.g. in __init__ - unregister_callbacks() should be called at least once, e.g. when the instance is destroyed - - if register_callbacks() is called in __init__, as opposed to a separate start() method, - extra care is needed that the call to unregister_callbacks() is not forgotten, - otherwise we will leak memory + - as fallback, __del__() also calls unregister_callbacks() """ def _list_callbacks(self): @@ -2031,6 +2065,9 @@ class EventListener: #_logger.debug(f'unregistering callback {method}') unregister_callback(method) + def __del__(self): + self.unregister_callbacks() + def event_listener(func): """To be used in subclasses of EventListener only. (how to enforce this programmatically?)""" diff --git a/electrum/utils/memory_leak.py b/electrum/utils/memory_leak.py index 06cc3ebca..efda41310 100644 --- a/electrum/utils/memory_leak.py +++ b/electrum/utils/memory_leak.py @@ -1,3 +1,4 @@ +import asyncio from collections import defaultdict import datetime import os @@ -5,6 +6,7 @@ import time from typing import Sequence, Mapping, TypeVar, Optional import weakref +from electrum import util from electrum.util import ThreadJob @@ -52,6 +54,26 @@ class DebugMem(ThreadJob): self.next_time = time.time() + self.interval +async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None: + """Async wait until the object referenced by `wr` is GC-ed.""" + obj = wr() + if obj is None: + return + evt_gc = asyncio.Event() # set when obj is finally GC-ed. + wr2 = weakref.ref(obj, lambda _x: util.run_sync_function_on_asyncio_thread(evt_gc.set, block=False)) + del obj + while True: + try: + async with util.async_timeout(0.01): + await evt_gc.wait() + except asyncio.TimeoutError: + import gc + gc.collect() + else: + break + assert evt_gc.is_set() + + def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]: """Return a string listing the most common types in memory.""" import objgraph # 3rd-party dependency diff --git a/tests/test_callbackmgr.py b/tests/test_callbackmgr.py index cd5cd10ef..c09347a77 100644 --- a/tests/test_callbackmgr.py +++ b/tests/test_callbackmgr.py @@ -1,16 +1,20 @@ import asyncio +import weakref from electrum import util from electrum.util import EventListener, event_listener, trigger_callback -from electrum.utils.memory_leak import count_objects_in_memory +from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected +from electrum.simple_config import SimpleConfig -from . import ElectrumTestCase +from . import ElectrumTestCase, restore_wallet_from_text__for_unittest class MyEventListener(EventListener): - def __init__(self): + def __init__(self, *, autostart: bool = False): self._satoshi_cnt = 0 self._hal_cnt = 0 + if autostart: + self.start() def start(self): self.register_callbacks() @@ -27,8 +31,7 @@ class MyEventListener(EventListener): self._hal_cnt += 1 -def _count_all_callbacks() -> int: - return sum(len(cbs) for cbs in util.callback_mgr.callbacks.values()) +_count_all_callbacks = util.callback_mgr.count_all_callbacks async def fast_sleep(): @@ -89,13 +92,50 @@ class TestCallbackMgr(ElectrumTestCase): async def test_gc(self): objmap = count_objects_in_memory([MyEventListener]) self.assertEqual(len(objmap[MyEventListener]), 0) + self.assertEqual(_count_all_callbacks(), 0) el1 = MyEventListener() el1.start() el2 = MyEventListener() el2.start() objmap = count_objects_in_memory([MyEventListener]) self.assertEqual(len(objmap[MyEventListener]), 2) - el1.stop() + self.assertEqual(_count_all_callbacks(), 4) + # test if we can get GC-ed if we explicitly unregister cbs: + el1.stop() # calls unregister_callbacks del el1 objmap = count_objects_in_memory([MyEventListener]) self.assertEqual(len(objmap[MyEventListener]), 1) + self.assertEqual(_count_all_callbacks(), 2) + # test if we can get GC-ed even without unregistering cbs: + del el2 + objmap = count_objects_in_memory([MyEventListener]) + self.assertEqual(len(objmap[MyEventListener]), 0) + self.assertEqual(_count_all_callbacks(), 0) + + async def test_gc2(self): + def func(): + el1 = MyEventListener(autostart=True) + el1.el2 = MyEventListener(autostart=True) + el1.el2.el3 = MyEventListener(autostart=True) + self.assertEqual(_count_all_callbacks(), 6) + func() + self.assertEqual(_count_all_callbacks(), 0) + + async def test_gc_complex_using_wallet(self): + """This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod. + We need the custom __eq__ for some reason. + """ + self.assertEqual(_count_all_callbacks(), 0) + config = SimpleConfig({'electrum_path': self.electrum_path}) + wallet = restore_wallet_from_text__for_unittest( + "9dk", path=None, config=config, + )["wallet"] + assert wallet.lnworker is not None + # now delete the wallet, and wait for it to get GC-ed + # note: need to wait for cyclic GC. example: wallet.lnworker.wallet + wr = weakref.ref(wallet) + del wallet + async with util.async_timeout(5): + await wait_until_obj_is_garbage_collected(wr) + # by now, all callbacks must have been cleaned up: + self.assertEqual(_count_all_callbacks(), 0) diff --git a/tests/test_daemon.py b/tests/test_daemon.py index f5bbbade2..6da78076a 100644 --- a/tests/test_daemon.py +++ b/tests/test_daemon.py @@ -1,3 +1,5 @@ +import asyncio +from collections import defaultdict import os from typing import Optional, Iterable @@ -5,7 +7,10 @@ from electrum.commands import Commands from electrum.daemon import Daemon from electrum.simple_config import SimpleConfig from electrum.wallet import Abstract_Wallet +from electrum.lnworker import LNWallet, LNPeerManager +from electrum.lnwatcher import LNWatcher from electrum import util +from electrum.utils.memory_leak import count_objects_in_memory from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest @@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase): await self.daemon.stop() await super().asyncTearDown() - def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str: + def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str: """Returns path for created wallet.""" basename = util.get_new_wallet_name(self.wallet_dir) path = os.path.join(self.wallet_dir, basename) @@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase): password=password, encrypt_file=encrypt_file, config=self.config, + **kwargs, ) # We return the path instead of the wallet object, as extreme # care would be needed to use the wallet object directly: @@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase): self.assertTrue(is_unified) self._run_post_unif_sanity_checks(paths, password="123456") + # misc ---> + + async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self): + orig_cb_count = util.callback_mgr.count_all_callbacks() + # GC sanity-check: + mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager] + objmap = count_objects_in_memory(mclasses) + for mcls in mclasses: + self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}") + # restore some wallets + paths = [] + paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True)) + paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=False)) + paths.append(self._restore_wallet_from_text("9dk", password=None)) + paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True, passphrase="hunter2")) + paths.append(self._restore_wallet_from_text("9dk", password="999999", encrypt_file=False, passphrase="hunter2")) + paths.append(self._restore_wallet_from_text("9dk", password=None, passphrase="hunter2")) + # test unification + can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(old_password="123456", wallet_dir=self.wallet_dir) + self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded))) + # gc + try: + async with util.async_timeout(5): + while True: + objmap = count_objects_in_memory(mclasses) + if sum(len(lst) for lst in objmap.values()) == 0: + break # all "mclasses"-type objects have been GC-ed + await asyncio.sleep(0.01) + except asyncio.TimeoutError: + for mcls in mclasses: + self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}") + # also check callbacks have been cleaned up: + self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks()) + class TestCommandsWithDaemon(DaemonTestCase): TESTNET = True