diff --git a/electrum/util.py b/electrum/util.py index 300a410a2..1bf6025cd 100644 --- a/electrum/util.py +++ b/electrum/util.py @@ -21,6 +21,7 @@ # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import concurrent.futures +import copy from dataclasses import dataclass import logging import os @@ -56,6 +57,7 @@ import enum from contextlib import nullcontext, suppress import traceback import inspect +import weakref import aiohttp from aiohttp_socks import ProxyConnector, ProxyType @@ -1956,23 +1958,40 @@ class CallbackManager(Logger): def __init__(self): Logger.__init__(self) - self.callback_lock = threading.Lock() - self.callbacks = defaultdict(list) # type: Dict[str, List[Callable]] # note: needs self.callback_lock + self.callback_lock = threading.RLock() + self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock - def register_callback(self, func: Callable, events: Sequence[str]) -> None: + @staticmethod + def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]: + assert callable(cb), type(cb) + if isinstance(cb, weakref.ref): # no-op + return cb + elif inspect.ismethod(cb): # instance method, such as for a subclass of EventListener + return WeakMethodProper(cb) + else: # proper function? e.g. used by lnpeer unit tests + return weakref.ref(cb) + + def register_callback(self, cb: Callable, events: Sequence[str]) -> None: + wcb = self._wcb_from_any_callback(cb) with self.callback_lock: for event in events: - self.callbacks[event].append(func) + self._wcallbacks[event].add(wcb) - def unregister_callback(self, callback: Callable) -> None: + def unregister_callback(self, cb: Callable) -> None: + wcb = self._wcb_from_any_callback(cb) with self.callback_lock: - for callbacks in self.callbacks.values(): - if callback in callbacks: - callbacks.remove(callback) + # note: ^ callback_lock needs to be re-entrant, as we can now trigger __del__, which also takes the lock + for callbacks in self._wcallbacks.values(): + if wcb in callbacks: + callbacks.remove(wcb) + + def count_all_callbacks(self) -> int: + with self.callback_lock: + return sum(len(cbs) for cbs in self._wcallbacks.values()) def clear_all_callbacks(self) -> None: with self.callback_lock: - self.callbacks.clear() + self._wcallbacks.clear() def trigger_callback(self, event: str, *args) -> None: """Trigger a callback with given arguments. @@ -1982,8 +2001,11 @@ class CallbackManager(Logger): loop = get_asyncio_loop() assert loop.is_running(), "event loop not running" with self.callback_lock: - callbacks = self.callbacks[event][:] - for callback in callbacks: + wcallbacks = copy.copy(self._wcallbacks[event]) + for wcb in wcallbacks: + callback = wcb() + if callback is None: + continue if inspect.iscoroutinefunction(callback): # async cb fut = asyncio.run_coroutine_threadsafe(callback(*args), loop) @@ -2005,11 +2027,28 @@ unregister_callback = callback_mgr.unregister_callback _event_listeners = defaultdict(set) # type: Dict[str, Set[str]] +class WeakMethodProper(weakref.WeakMethod): + """Unlike weakref.WeakMethod, this class has an __eq__ I can trust.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + meth = self() + self._my_id = (id(meth.__self__), id(meth.__func__)) + + def __hash__(self): + return hash(self._my_id) + + def __eq__(self, other): + if not isinstance(other, WeakMethodProper): + return False + return self._my_id == other._my_id + + class EventListener: """Use as a mixin for a class that has methods to be triggered on events. - Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener. - - register_callbacks() should be called exactly once per instance of EventListener, e.g. in __init__ + - register_callbacks() should be called once per instance of EventListener, e.g. in __init__ - unregister_callbacks() should be called at least once, e.g. when the instance is destroyed + - as fallback, __del__() also calls unregister_callbacks() """ def _list_callbacks(self): @@ -2031,6 +2070,9 @@ class EventListener: #_logger.debug(f'unregistering callback {method}') unregister_callback(method) + def __del__(self): + self.unregister_callbacks() + def event_listener(func): """To be used in subclasses of EventListener only. (how to enforce this programmatically?)""" diff --git a/electrum/utils/memory_leak.py b/electrum/utils/memory_leak.py index 0d3b463bd..efda41310 100644 --- a/electrum/utils/memory_leak.py +++ b/electrum/utils/memory_leak.py @@ -1,11 +1,32 @@ +import asyncio from collections import defaultdict import datetime import os import time +from typing import Sequence, Mapping, TypeVar, Optional +import weakref +from electrum import util from electrum.util import ThreadJob +_U = TypeVar('_U') + +def count_objects_in_memory(mclasses: Sequence[type[_U]]) -> Mapping[type[_U], Sequence[weakref.ref[_U]]]: + import gc + gc.collect() + objmap = defaultdict(list) + for obj in gc.get_objects(): + for class_ in mclasses: + try: + _isinstance = isinstance(obj, class_) + except AttributeError: + _isinstance = False + if _isinstance: + objmap[class_].append(weakref.ref(obj)) + return objmap + + class DebugMem(ThreadJob): '''A handy class for debugging GC memory leaks @@ -21,18 +42,8 @@ class DebugMem(ThreadJob): self.interval = interval def mem_stats(self): - import gc self.logger.info("Start memscan") - gc.collect() - objmap = defaultdict(list) - for obj in gc.get_objects(): - for class_ in self.classes: - try: - _isinstance = isinstance(obj, class_) - except AttributeError: - _isinstance = False - if _isinstance: - objmap[class_].append(obj) + objmap = count_objects_in_memory(self.classes) for class_, objs in objmap.items(): self.logger.info(f"{class_.__name__}: {len(objs)}") self.logger.info("Finish memscan") @@ -43,6 +54,26 @@ class DebugMem(ThreadJob): self.next_time = time.time() + self.interval +async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None: + """Async wait until the object referenced by `wr` is GC-ed.""" + obj = wr() + if obj is None: + return + evt_gc = asyncio.Event() # set when obj is finally GC-ed. + wr2 = weakref.ref(obj, lambda _x: util.run_sync_function_on_asyncio_thread(evt_gc.set, block=False)) + del obj + while True: + try: + async with util.async_timeout(0.01): + await evt_gc.wait() + except asyncio.TimeoutError: + import gc + gc.collect() + else: + break + assert evt_gc.is_set() + + def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]: """Return a string listing the most common types in memory.""" import objgraph # 3rd-party dependency @@ -52,7 +83,7 @@ def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]: ) -def debug_memusage_dump_random_backref_chain(objtype: str) -> str: +def debug_memusage_dump_random_backref_chain(objtype: str) -> Optional[str]: """Writes a dotfile to cwd, containing the backref chain for a randomly selected object of type objtype. @@ -68,10 +99,14 @@ def debug_memusage_dump_random_backref_chain(objtype: str) -> str: import random timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ") fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot") + objects = objgraph.by_type(objtype) + if not objects: + return None + random_obj = random.choice(objects) with open(fpath, "w") as f: objgraph.show_chain( objgraph.find_backref_chain( - random.choice(objgraph.by_type(objtype)), + random_obj, objgraph.is_proper_module), output=f) return fpath diff --git a/electrum/wallet.py b/electrum/wallet.py index 0c608c9b7..0ab2343ef 100644 --- a/electrum/wallet.py +++ b/electrum/wallet.py @@ -562,11 +562,12 @@ class Abstract_Wallet(ABC, Logger, EventListener): self.unregister_callbacks() try: async with ignore_after(5): + if self.lnworker: + await self.lnworker.stop() + self.lnworker = None if self.network: - if self.lnworker: - await self.lnworker.stop() - self.lnworker = None self.network = None + if self.taskgroup: await self.taskgroup.cancel_remaining() self.taskgroup = None await self.adb.stop() diff --git a/tests/test_callbackmgr.py b/tests/test_callbackmgr.py new file mode 100644 index 000000000..c09347a77 --- /dev/null +++ b/tests/test_callbackmgr.py @@ -0,0 +1,141 @@ +import asyncio +import weakref + +from electrum import util +from electrum.util import EventListener, event_listener, trigger_callback +from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected +from electrum.simple_config import SimpleConfig + +from . import ElectrumTestCase, restore_wallet_from_text__for_unittest + + +class MyEventListener(EventListener): + def __init__(self, *, autostart: bool = False): + self._satoshi_cnt = 0 + self._hal_cnt = 0 + if autostart: + self.start() + + def start(self): + self.register_callbacks() + + def stop(self): + self.unregister_callbacks() + + @event_listener + async def on_event_satoshi_moves_his_coins(self, arg1, arg2): + self._satoshi_cnt += 1 + + @event_listener + def on_event_hal_moves_his_coins(self, arg1, arg2): # non-async + self._hal_cnt += 1 + + +_count_all_callbacks = util.callback_mgr.count_all_callbacks + + +async def fast_sleep(): + # sleep a few event loop iterations + for i in range(5): + await asyncio.sleep(0) + + +class TestCallbackMgr(ElectrumTestCase): + + def test_multiple_calls_to_register_callbacks(self): + self.assertEqual(0, _count_all_callbacks()) + el1 = MyEventListener() + el2 = MyEventListener() + self.assertEqual(0, _count_all_callbacks()) + el1.start() + self.assertEqual(2, _count_all_callbacks()) + el2.start() + self.assertEqual(4, _count_all_callbacks()) + el1.start() + self.assertEqual(4, _count_all_callbacks()) + el1.stop() + self.assertEqual(2, _count_all_callbacks()) + el1.stop() + self.assertEqual(2, _count_all_callbacks()) + el1.stop() + self.assertEqual(2, _count_all_callbacks()) + el2.stop() + self.assertEqual(0, _count_all_callbacks()) + + async def test_trigger_callback(self): + el1 = MyEventListener() + el1.start() + el2 = MyEventListener() + el2.start() + # trigger some cbs + self.assertEqual(el1._satoshi_cnt, 0) + self.assertEqual(el1._hal_cnt, 0) + trigger_callback('satoshi_moves_his_coins', 0, 0) + trigger_callback('satoshi_moves_his_coins', 0, 0) + trigger_callback('satoshi_moves_his_coins', 0, 0) + trigger_callback('hal_moves_his_coins', 0, 0) + await fast_sleep() + self.assertEqual(el1._satoshi_cnt, 3) + self.assertEqual(el2._satoshi_cnt, 3) + self.assertEqual(el1._hal_cnt, 1) + self.assertEqual(el2._hal_cnt, 1) + # stop one listener, see new triggers are only seen by other one still running + el1.stop() + trigger_callback('satoshi_moves_his_coins', 0, 0) + trigger_callback('hal_moves_his_coins', 0, 0) + await fast_sleep() + self.assertEqual(el1._satoshi_cnt, 3) + self.assertEqual(el2._satoshi_cnt, 4) + self.assertEqual(el1._hal_cnt, 1) + self.assertEqual(el2._hal_cnt, 2) + + async def test_gc(self): + objmap = count_objects_in_memory([MyEventListener]) + self.assertEqual(len(objmap[MyEventListener]), 0) + self.assertEqual(_count_all_callbacks(), 0) + el1 = MyEventListener() + el1.start() + el2 = MyEventListener() + el2.start() + objmap = count_objects_in_memory([MyEventListener]) + self.assertEqual(len(objmap[MyEventListener]), 2) + self.assertEqual(_count_all_callbacks(), 4) + # test if we can get GC-ed if we explicitly unregister cbs: + el1.stop() # calls unregister_callbacks + del el1 + objmap = count_objects_in_memory([MyEventListener]) + self.assertEqual(len(objmap[MyEventListener]), 1) + self.assertEqual(_count_all_callbacks(), 2) + # test if we can get GC-ed even without unregistering cbs: + del el2 + objmap = count_objects_in_memory([MyEventListener]) + self.assertEqual(len(objmap[MyEventListener]), 0) + self.assertEqual(_count_all_callbacks(), 0) + + async def test_gc2(self): + def func(): + el1 = MyEventListener(autostart=True) + el1.el2 = MyEventListener(autostart=True) + el1.el2.el3 = MyEventListener(autostart=True) + self.assertEqual(_count_all_callbacks(), 6) + func() + self.assertEqual(_count_all_callbacks(), 0) + + async def test_gc_complex_using_wallet(self): + """This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod. + We need the custom __eq__ for some reason. + """ + self.assertEqual(_count_all_callbacks(), 0) + config = SimpleConfig({'electrum_path': self.electrum_path}) + wallet = restore_wallet_from_text__for_unittest( + "9dk", path=None, config=config, + )["wallet"] + assert wallet.lnworker is not None + # now delete the wallet, and wait for it to get GC-ed + # note: need to wait for cyclic GC. example: wallet.lnworker.wallet + wr = weakref.ref(wallet) + del wallet + async with util.async_timeout(5): + await wait_until_obj_is_garbage_collected(wr) + # by now, all callbacks must have been cleaned up: + self.assertEqual(_count_all_callbacks(), 0) diff --git a/tests/test_daemon.py b/tests/test_daemon.py index f5bbbade2..6da78076a 100644 --- a/tests/test_daemon.py +++ b/tests/test_daemon.py @@ -1,3 +1,5 @@ +import asyncio +from collections import defaultdict import os from typing import Optional, Iterable @@ -5,7 +7,10 @@ from electrum.commands import Commands from electrum.daemon import Daemon from electrum.simple_config import SimpleConfig from electrum.wallet import Abstract_Wallet +from electrum.lnworker import LNWallet, LNPeerManager +from electrum.lnwatcher import LNWatcher from electrum import util +from electrum.utils.memory_leak import count_objects_in_memory from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest @@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase): await self.daemon.stop() await super().asyncTearDown() - def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str: + def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str: """Returns path for created wallet.""" basename = util.get_new_wallet_name(self.wallet_dir) path = os.path.join(self.wallet_dir, basename) @@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase): password=password, encrypt_file=encrypt_file, config=self.config, + **kwargs, ) # We return the path instead of the wallet object, as extreme # care would be needed to use the wallet object directly: @@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase): self.assertTrue(is_unified) self._run_post_unif_sanity_checks(paths, password="123456") + # misc ---> + + async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self): + orig_cb_count = util.callback_mgr.count_all_callbacks() + # GC sanity-check: + mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager] + objmap = count_objects_in_memory(mclasses) + for mcls in mclasses: + self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}") + # restore some wallets + paths = [] + paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True)) + paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=False)) + paths.append(self._restore_wallet_from_text("9dk", password=None)) + paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True, passphrase="hunter2")) + paths.append(self._restore_wallet_from_text("9dk", password="999999", encrypt_file=False, passphrase="hunter2")) + paths.append(self._restore_wallet_from_text("9dk", password=None, passphrase="hunter2")) + # test unification + can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(old_password="123456", wallet_dir=self.wallet_dir) + self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded))) + # gc + try: + async with util.async_timeout(5): + while True: + objmap = count_objects_in_memory(mclasses) + if sum(len(lst) for lst in objmap.values()) == 0: + break # all "mclasses"-type objects have been GC-ed + await asyncio.sleep(0.01) + except asyncio.TimeoutError: + for mcls in mclasses: + self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}") + # also check callbacks have been cleaned up: + self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks()) + class TestCommandsWithDaemon(DaemonTestCase): TESTNET = True