1
0

Merge pull request #10428 from SomberNight/202601_wallet_objs_not_gced

fix util.CallbackManager memory leak
This commit is contained in:
ghost43
2026-01-22 15:05:22 +00:00
committed by GitHub
5 changed files with 288 additions and 29 deletions

View File

@@ -21,6 +21,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE. # SOFTWARE.
import concurrent.futures import concurrent.futures
import copy
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import os import os
@@ -56,6 +57,7 @@ import enum
from contextlib import nullcontext, suppress from contextlib import nullcontext, suppress
import traceback import traceback
import inspect import inspect
import weakref
import aiohttp import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType from aiohttp_socks import ProxyConnector, ProxyType
@@ -1956,23 +1958,40 @@ class CallbackManager(Logger):
def __init__(self): def __init__(self):
Logger.__init__(self) Logger.__init__(self)
self.callback_lock = threading.Lock() self.callback_lock = threading.RLock()
self.callbacks = defaultdict(list) # type: Dict[str, List[Callable]] # note: needs self.callback_lock self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock
def register_callback(self, func: Callable, events: Sequence[str]) -> None: @staticmethod
def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]:
    """Return a weak reference for the callback `cb`.

    - an existing `weakref.ref` is handed back unchanged,
    - a bound method (e.g. of an EventListener subclass) is wrapped
      in `WeakMethodProper`,
    - any other callable (e.g. a plain function) gets a regular `weakref.ref`.
    """
    assert callable(cb), type(cb)
    if isinstance(cb, weakref.ref):
        # already a weak reference: nothing to wrap
        return cb
    # bound methods need the method-aware wrapper; everything else a plain ref
    make_ref = WeakMethodProper if inspect.ismethod(cb) else weakref.ref
    return make_ref(cb)
def register_callback(self, cb: Callable, events: Sequence[str]) -> None:
wcb = self._wcb_from_any_callback(cb)
with self.callback_lock: with self.callback_lock:
for event in events: for event in events:
self.callbacks[event].append(func) self._wcallbacks[event].add(wcb)
def unregister_callback(self, callback: Callable) -> None: def unregister_callback(self, cb: Callable) -> None:
wcb = self._wcb_from_any_callback(cb)
with self.callback_lock: with self.callback_lock:
for callbacks in self.callbacks.values(): # note: ^ callback_lock needs to be re-entrant, as we can now trigger __del__, which also takes the lock
if callback in callbacks: for callbacks in self._wcallbacks.values():
callbacks.remove(callback) if wcb in callbacks:
callbacks.remove(wcb)
def count_all_callbacks(self) -> int:
    """Return how many weak callback references are stored, summed over all events.

    The count is taken while holding `self.callback_lock`.
    """
    total = 0
    with self.callback_lock:
        for per_event in self._wcallbacks.values():
            total += len(per_event)
    return total
def clear_all_callbacks(self) -> None: def clear_all_callbacks(self) -> None:
with self.callback_lock: with self.callback_lock:
self.callbacks.clear() self._wcallbacks.clear()
def trigger_callback(self, event: str, *args) -> None: def trigger_callback(self, event: str, *args) -> None:
"""Trigger a callback with given arguments. """Trigger a callback with given arguments.
@@ -1982,8 +2001,11 @@ class CallbackManager(Logger):
loop = get_asyncio_loop() loop = get_asyncio_loop()
assert loop.is_running(), "event loop not running" assert loop.is_running(), "event loop not running"
with self.callback_lock: with self.callback_lock:
callbacks = self.callbacks[event][:] wcallbacks = copy.copy(self._wcallbacks[event])
for callback in callbacks: for wcb in wcallbacks:
callback = wcb()
if callback is None:
continue
if inspect.iscoroutinefunction(callback): # async cb if inspect.iscoroutinefunction(callback): # async cb
fut = asyncio.run_coroutine_threadsafe(callback(*args), loop) fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)
@@ -2005,11 +2027,28 @@ unregister_callback = callback_mgr.unregister_callback
_event_listeners = defaultdict(set) # type: Dict[str, Set[str]] _event_listeners = defaultdict(set) # type: Dict[str, Set[str]]
class WeakMethodProper(weakref.WeakMethod):
    """A `weakref.WeakMethod` whose equality and hash do not depend on the referent being alive.

    Identity is captured once, at construction time, as the pair
    ``(id(bound_self), id(function))``. Two instances compare equal iff
    that pair matches, regardless of whether the bound object has since
    been collected. This lets a dead reference still be located in (and
    removed from) a set or dict.

    NOTE(review): since the key is built from `id()` values, a dead reference
    could in principle compare equal to a reference created later for a new
    object that reuses the same addresses. Callers should drop dead entries
    promptly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Dereference right away: the bound method handed to the constructor is
        # expected to still be alive here, so we can record its identity.
        meth = self()
        self._my_id = (id(meth.__self__), id(meth.__func__))

    def __hash__(self):
        return hash(self._my_id)

    def __eq__(self, other):
        if not isinstance(other, WeakMethodProper):
            return False
        return self._my_id == other._my_id

    def __ne__(self, other):
        # weakref.WeakMethod defines its own __ne__, which falls back to an
        # identity check once the referent is dead. Without this override,
        # `a == b` and `a != b` could both be True for two dead references to
        # the same method. Keep != as the exact negation of ==.
        return not self.__eq__(other)
class EventListener: class EventListener:
"""Use as a mixin for a class that has methods to be triggered on events. """Use as a mixin for a class that has methods to be triggered on events.
- Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener. - Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
- register_callbacks() should be called exactly once per instance of EventListener, e.g. in __init__ - register_callbacks() should be called once per instance of EventListener, e.g. in __init__
- unregister_callbacks() should be called at least once, e.g. when the instance is destroyed - unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
- as fallback, __del__() also calls unregister_callbacks()
""" """
def _list_callbacks(self): def _list_callbacks(self):
@@ -2031,6 +2070,9 @@ class EventListener:
#_logger.debug(f'unregistering callback {method}') #_logger.debug(f'unregistering callback {method}')
unregister_callback(method) unregister_callback(method)
def __del__(self):
    # Fallback cleanup: when the instance is finalized, unregister its
    # callbacks in case unregister_callbacks() was never called explicitly.
    self.unregister_callbacks()
def event_listener(func): def event_listener(func):
"""To be used in subclasses of EventListener only. (how to enforce this programmatically?)""" """To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""

View File

@@ -1,11 +1,32 @@
import asyncio
from collections import defaultdict from collections import defaultdict
import datetime import datetime
import os import os
import time import time
from typing import Sequence, Mapping, TypeVar, Optional
import weakref
from electrum import util
from electrum.util import ThreadJob from electrum.util import ThreadJob
_U = TypeVar('_U')


def count_objects_in_memory(mclasses: Sequence[type[_U]]) -> Mapping[type[_U], Sequence[weakref.ref[_U]]]:
    """Find all GC-tracked instances of the given classes that are currently alive.

    Runs a full `gc.collect()` first, then scans `gc.get_objects()`.

    Returns a defaultdict mapping each class in `mclasses` that has at least
    one instance to a list of weak references to those instances. Looking up
    a class without instances yields an empty list. Weak references are
    returned so that the result itself does not keep the objects alive;
    matching instances must therefore support weak references.
    """
    import gc
    gc.collect()
    objmap = defaultdict(list)
    for obj in gc.get_objects():
        for class_ in mclasses:
            try:
                _isinstance = isinstance(obj, class_)
            except (AttributeError, ReferenceError):
                # isinstance() may fall back to reading obj.__class__. For a
                # weakref.proxy whose referent is already gone that lookup
                # raises ReferenceError; such objects are simply not a match.
                _isinstance = False
            if _isinstance:
                objmap[class_].append(weakref.ref(obj))
    return objmap
class DebugMem(ThreadJob): class DebugMem(ThreadJob):
'''A handy class for debugging GC memory leaks '''A handy class for debugging GC memory leaks
@@ -21,18 +42,8 @@ class DebugMem(ThreadJob):
self.interval = interval self.interval = interval
def mem_stats(self): def mem_stats(self):
import gc
self.logger.info("Start memscan") self.logger.info("Start memscan")
gc.collect() objmap = count_objects_in_memory(self.classes)
objmap = defaultdict(list)
for obj in gc.get_objects():
for class_ in self.classes:
try:
_isinstance = isinstance(obj, class_)
except AttributeError:
_isinstance = False
if _isinstance:
objmap[class_].append(obj)
for class_, objs in objmap.items(): for class_, objs in objmap.items():
self.logger.info(f"{class_.__name__}: {len(objs)}") self.logger.info(f"{class_.__name__}: {len(objs)}")
self.logger.info("Finish memscan") self.logger.info("Finish memscan")
@@ -43,6 +54,26 @@ class DebugMem(ThreadJob):
self.next_time = time.time() + self.interval self.next_time = time.time() + self.interval
async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None:
    """Suspend until the object behind the weak reference `wr` has been collected.

    Returns immediately if the referent is already gone. Otherwise waits in
    short slices, forcing a `gc.collect()` whenever a slice times out.
    """
    import gc

    target = wr()
    if target is None:
        return

    collected = asyncio.Event()  # set once the target is finally GC-ed

    def _on_collected(_dead_ref):
        util.run_sync_function_on_asyncio_thread(collected.set, block=False)

    # This second weakref must stay referenced until we return, otherwise its
    # callback would never fire.
    watcher = weakref.ref(target, _on_collected)
    del target  # drop our own strong reference, or the object could never die
    while True:
        try:
            async with util.async_timeout(0.01):
                await collected.wait()
            break
        except asyncio.TimeoutError:
            gc.collect()
    assert collected.is_set()
def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]: def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
"""Return a string listing the most common types in memory.""" """Return a string listing the most common types in memory."""
import objgraph # 3rd-party dependency import objgraph # 3rd-party dependency
@@ -52,7 +83,7 @@ def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
) )
def debug_memusage_dump_random_backref_chain(objtype: str) -> str: def debug_memusage_dump_random_backref_chain(objtype: str) -> Optional[str]:
"""Writes a dotfile to cwd, containing the backref chain """Writes a dotfile to cwd, containing the backref chain
for a randomly selected object of type objtype. for a randomly selected object of type objtype.
@@ -68,10 +99,14 @@ def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
import random import random
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ") timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot") fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot")
objects = objgraph.by_type(objtype)
if not objects:
return None
random_obj = random.choice(objects)
with open(fpath, "w") as f: with open(fpath, "w") as f:
objgraph.show_chain( objgraph.show_chain(
objgraph.find_backref_chain( objgraph.find_backref_chain(
random.choice(objgraph.by_type(objtype)), random_obj,
objgraph.is_proper_module), objgraph.is_proper_module),
output=f) output=f)
return fpath return fpath

View File

@@ -562,11 +562,12 @@ class Abstract_Wallet(ABC, Logger, EventListener):
self.unregister_callbacks() self.unregister_callbacks()
try: try:
async with ignore_after(5): async with ignore_after(5):
if self.lnworker:
await self.lnworker.stop()
self.lnworker = None
if self.network: if self.network:
if self.lnworker:
await self.lnworker.stop()
self.lnworker = None
self.network = None self.network = None
if self.taskgroup:
await self.taskgroup.cancel_remaining() await self.taskgroup.cancel_remaining()
self.taskgroup = None self.taskgroup = None
await self.adb.stop() await self.adb.stop()

141
tests/test_callbackmgr.py Normal file
View File

@@ -0,0 +1,141 @@
import asyncio
import weakref
from electrum import util
from electrum.util import EventListener, event_listener, trigger_callback
from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected
from electrum.simple_config import SimpleConfig
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
class MyEventListener(EventListener):
    """Small EventListener for the tests: counts how often each of its two handlers ran."""

    def __init__(self, *, autostart: bool = False):
        self._satoshi_cnt = self._hal_cnt = 0
        if not autostart:
            return
        self.start()

    def start(self):
        """Register this instance's event callbacks."""
        self.register_callbacks()

    def stop(self):
        """Unregister this instance's event callbacks."""
        self.unregister_callbacks()

    @event_listener
    async def on_event_satoshi_moves_his_coins(self, arg1, arg2):
        # async handler
        self._satoshi_cnt = self._satoshi_cnt + 1

    @event_listener
    def on_event_hal_moves_his_coins(self, arg1, arg2):
        # non-async handler
        self._hal_cnt = self._hal_cnt + 1
_count_all_callbacks = util.callback_mgr.count_all_callbacks
async def fast_sleep():
    """Yield to the event loop five times in a row, without any real delay."""
    remaining = 5
    while remaining > 0:
        await asyncio.sleep(0)
        remaining -= 1
class TestCallbackMgr(ElectrumTestCase):
    """Tests for callback registration, triggering, and cleanup of callbacks when listeners are GC-ed."""

    def test_multiple_calls_to_register_callbacks(self):
        # Each MyEventListener has two @event_listener methods, so a started
        # listener accounts for two registered callbacks. Repeated start()/stop()
        # calls on the same instance must not change the count again.
        self.assertEqual(0, _count_all_callbacks())
        el1 = MyEventListener()
        el2 = MyEventListener()
        self.assertEqual(0, _count_all_callbacks())
        el1.start()
        self.assertEqual(2, _count_all_callbacks())
        el2.start()
        self.assertEqual(4, _count_all_callbacks())
        el1.start()  # second start() of el1: no duplicates expected
        self.assertEqual(4, _count_all_callbacks())
        el1.stop()
        self.assertEqual(2, _count_all_callbacks())
        el1.stop()  # further stop() calls are expected to be no-ops
        self.assertEqual(2, _count_all_callbacks())
        el1.stop()
        self.assertEqual(2, _count_all_callbacks())
        el2.stop()
        self.assertEqual(0, _count_all_callbacks())

    async def test_trigger_callback(self):
        # Both the async and the non-async handler must run, once per trigger,
        # on every listener that is currently started.
        el1 = MyEventListener()
        el1.start()
        el2 = MyEventListener()
        el2.start()
        # trigger some cbs
        self.assertEqual(el1._satoshi_cnt, 0)
        self.assertEqual(el1._hal_cnt, 0)
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('hal_moves_his_coins', 0, 0)
        await fast_sleep()  # give the event loop a few iterations to run the callbacks
        self.assertEqual(el1._satoshi_cnt, 3)
        self.assertEqual(el2._satoshi_cnt, 3)
        self.assertEqual(el1._hal_cnt, 1)
        self.assertEqual(el2._hal_cnt, 1)
        # stop one listener, see new triggers are only seen by other one still running
        el1.stop()
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('hal_moves_his_coins', 0, 0)
        await fast_sleep()
        self.assertEqual(el1._satoshi_cnt, 3)
        self.assertEqual(el2._satoshi_cnt, 4)
        self.assertEqual(el1._hal_cnt, 1)
        self.assertEqual(el2._hal_cnt, 2)

    async def test_gc(self):
        # count_objects_in_memory() returns weak references, so holding on to
        # `objmap` does not keep any MyEventListener alive.
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 0)
        self.assertEqual(_count_all_callbacks(), 0)
        el1 = MyEventListener()
        el1.start()
        el2 = MyEventListener()
        el2.start()
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 2)
        self.assertEqual(_count_all_callbacks(), 4)
        # test if we can get GC-ed if we explicitly unregister cbs:
        el1.stop()  # calls unregister_callbacks
        del el1
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 1)
        self.assertEqual(_count_all_callbacks(), 2)
        # test if we can get GC-ed even without unregistering cbs:
        del el2
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 0)
        self.assertEqual(_count_all_callbacks(), 0)

    async def test_gc2(self):
        # Three listeners chained through attributes and only reachable from a
        # local variable: once func() returns, all of them (and their six
        # callbacks) must be gone.
        def func():
            el1 = MyEventListener(autostart=True)
            el1.el2 = MyEventListener(autostart=True)
            el1.el2.el3 = MyEventListener(autostart=True)
            self.assertEqual(_count_all_callbacks(), 6)
        func()
        self.assertEqual(_count_all_callbacks(), 0)

    async def test_gc_complex_using_wallet(self):
        """This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod.
        We need the custom __eq__ for some reason.
        """
        # NOTE(review): the likely reason is that the wallet sits in a reference
        # cycle, so its weak references are already dead when the cleanup tries to
        # find and remove them -- TODO confirm.
        self.assertEqual(_count_all_callbacks(), 0)
        config = SimpleConfig({'electrum_path': self.electrum_path})
        wallet = restore_wallet_from_text__for_unittest(
            "9dk", path=None, config=config,
        )["wallet"]
        assert wallet.lnworker is not None
        # now delete the wallet, and wait for it to get GC-ed
        # note: need to wait for cyclic GC. example: wallet.lnworker.wallet
        wr = weakref.ref(wallet)
        del wallet
        async with util.async_timeout(5):
            await wait_until_obj_is_garbage_collected(wr)
        # by now, all callbacks must have been cleaned up:
        self.assertEqual(_count_all_callbacks(), 0)

View File

@@ -1,3 +1,5 @@
import asyncio
from collections import defaultdict
import os import os
from typing import Optional, Iterable from typing import Optional, Iterable
@@ -5,7 +7,10 @@ from electrum.commands import Commands
from electrum.daemon import Daemon from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig from electrum.simple_config import SimpleConfig
from electrum.wallet import Abstract_Wallet from electrum.wallet import Abstract_Wallet
from electrum.lnworker import LNWallet, LNPeerManager
from electrum.lnwatcher import LNWatcher
from electrum import util from electrum import util
from electrum.utils.memory_leak import count_objects_in_memory
from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest
@@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase):
await self.daemon.stop() await self.daemon.stop()
await super().asyncTearDown() await super().asyncTearDown()
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str: def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str:
"""Returns path for created wallet.""" """Returns path for created wallet."""
basename = util.get_new_wallet_name(self.wallet_dir) basename = util.get_new_wallet_name(self.wallet_dir)
path = os.path.join(self.wallet_dir, basename) path = os.path.join(self.wallet_dir, basename)
@@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase):
password=password, password=password,
encrypt_file=encrypt_file, encrypt_file=encrypt_file,
config=self.config, config=self.config,
**kwargs,
) )
# We return the path instead of the wallet object, as extreme # We return the path instead of the wallet object, as extreme
# care would be needed to use the wallet object directly: # care would be needed to use the wallet object directly:
@@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase):
self.assertTrue(is_unified) self.assertTrue(is_unified)
self._run_post_unif_sanity_checks(paths, password="123456") self._run_post_unif_sanity_checks(paths, password="123456")
# misc --->
async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self):
    """check_password_for_directory() must not leave wallet-related objects or callbacks behind."""
    orig_cb_count = util.callback_mgr.count_all_callbacks()
    mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager]

    # GC sanity-check: nothing of these types may be alive before we start
    objmap = count_objects_in_memory(mclasses)
    for mcls in mclasses:
        self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")

    # restore some wallets (same seed, different password/encryption/passphrase combos)
    restore_variants = [
        dict(password="123456", encrypt_file=True),
        dict(password="123456", encrypt_file=False),
        dict(password=None),
        dict(password="123456", encrypt_file=True, passphrase="hunter2"),
        dict(password="999999", encrypt_file=False, passphrase="hunter2"),
        dict(password=None, passphrase="hunter2"),
    ]
    paths = []
    for variant in restore_variants:
        paths.append(self._restore_wallet_from_text("9dk", **variant))

    # test unification
    can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(
        old_password="123456", wallet_dir=self.wallet_dir)
    self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded)))

    # gc: poll until no object of the monitored types is left, or give up after 5 seconds
    try:
        async with util.async_timeout(5):
            while True:
                objmap = count_objects_in_memory(mclasses)
                if not any(objmap.values()):
                    break  # all "mclasses"-type objects have been GC-ed
                await asyncio.sleep(0.01)
    except asyncio.TimeoutError:
        # report which type is still lingering
        for mcls in mclasses:
            self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")

    # also check callbacks have been cleaned up:
    self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks())
class TestCommandsWithDaemon(DaemonTestCase): class TestCommandsWithDaemon(DaemonTestCase):
TESTNET = True TESTNET = True