1
0

Merge pull request #10428 from SomberNight/202601_wallet_objs_not_gced

fix util.CallbackManager memory leak
This commit is contained in:
ghost43
2026-01-22 15:05:22 +00:00
committed by GitHub
5 changed files with 288 additions and 29 deletions

View File

@@ -21,6 +21,7 @@
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import concurrent.futures
import copy
from dataclasses import dataclass
import logging
import os
@@ -56,6 +57,7 @@ import enum
from contextlib import nullcontext, suppress
import traceback
import inspect
import weakref
import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType
@@ -1956,23 +1958,40 @@ class CallbackManager(Logger):
def __init__(self):
Logger.__init__(self)
self.callback_lock = threading.Lock()
self.callbacks = defaultdict(list) # type: Dict[str, List[Callable]] # note: needs self.callback_lock
self.callback_lock = threading.RLock()
self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock
def register_callback(self, func: Callable, events: Sequence[str]) -> None:
@staticmethod
def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]:
    """Return a weak reference wrapping the callback `cb`.

    - an existing weakref.ref is returned unchanged,
    - a bound method is wrapped in a WeakMethodProper,
    - anything else (e.g. a plain function) is wrapped in a weakref.ref.

    Note: only a weak reference is created here, so the caller has to keep
    the callback itself alive for as long as it should stay reachable.
    """
    assert callable(cb), type(cb)
    # already a weak reference: nothing to do
    if isinstance(cb, weakref.ref):
        return cb
    # instance method, such as for a subclass of EventListener
    if inspect.ismethod(cb):
        return WeakMethodProper(cb)
    # proper function? e.g. used by lnpeer unit tests
    return weakref.ref(cb)
def register_callback(self, cb: Callable, events: Sequence[str]) -> None:
wcb = self._wcb_from_any_callback(cb)
with self.callback_lock:
for event in events:
self.callbacks[event].append(func)
self._wcallbacks[event].add(wcb)
def unregister_callback(self, callback: Callable) -> None:
def unregister_callback(self, cb: Callable) -> None:
wcb = self._wcb_from_any_callback(cb)
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
# note: ^ callback_lock needs to be re-entrant, as we can now trigger __del__, which also takes the lock
for callbacks in self._wcallbacks.values():
if wcb in callbacks:
callbacks.remove(wcb)
def count_all_callbacks(self) -> int:
    """Return the total number of registered (event, callback) entries.

    Entries are counted as stored, summed over all events; this does not
    check whether the weakly-referenced callbacks are still alive.
    """
    total = 0
    with self.callback_lock:
        for registered in self._wcallbacks.values():
            total += len(registered)
    return total
def clear_all_callbacks(self) -> None:
with self.callback_lock:
self.callbacks.clear()
self._wcallbacks.clear()
def trigger_callback(self, event: str, *args) -> None:
"""Trigger a callback with given arguments.
@@ -1982,8 +2001,11 @@ class CallbackManager(Logger):
loop = get_asyncio_loop()
assert loop.is_running(), "event loop not running"
with self.callback_lock:
callbacks = self.callbacks[event][:]
for callback in callbacks:
wcallbacks = copy.copy(self._wcallbacks[event])
for wcb in wcallbacks:
callback = wcb()
if callback is None:
continue
if inspect.iscoroutinefunction(callback): # async cb
fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)
@@ -2005,11 +2027,28 @@ unregister_callback = callback_mgr.unregister_callback
_event_listeners = defaultdict(set) # type: Dict[str, Set[str]]
class WeakMethodProper(weakref.WeakMethod):
    """Unlike weakref.WeakMethod, this class has an __eq__ I can trust.

    Equality and hashing are based purely on the identity of the bound
    method's parts, i.e. (id(instance), id(function)), captured once at
    construction time. They never consult the referent's own __eq__/__hash__
    and keep working after the referent has been garbage collected, which is
    what allows dead entries to be found in (and removed from) a set.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The method is necessarily still alive here: the caller holds it.
        meth = self()
        self._my_id = (id(meth.__self__), id(meth.__func__))

    def __hash__(self):
        return hash(self._my_id)

    def __eq__(self, other):
        if not isinstance(other, WeakMethodProper):
            return False
        return self._my_id == other._my_id

    def __ne__(self, other):
        # weakref.WeakMethod defines its own __ne__, which would otherwise be
        # inherited and could disagree with the __eq__ above (e.g. once the
        # referent is dead). Keep both operators consistent.
        return not self.__eq__(other)
class EventListener:
"""Use as a mixin for a class that has methods to be triggered on events.
- Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
- register_callbacks() should be called exactly once per instance of EventListener, e.g. in __init__
- register_callbacks() should be called once per instance of EventListener, e.g. in __init__
- unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
- as fallback, __del__() also calls unregister_callbacks()
"""
def _list_callbacks(self):
@@ -2031,6 +2070,9 @@ class EventListener:
#_logger.debug(f'unregistering callback {method}')
unregister_callback(method)
def __del__(self):
    """Fallback cleanup: unregister this instance's callbacks on destruction."""
    self.unregister_callbacks()
def event_listener(func):
"""To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""

View File

@@ -1,11 +1,32 @@
import asyncio
from collections import defaultdict
import datetime
import os
import time
from typing import Sequence, Mapping, TypeVar, Optional
import weakref
from electrum import util
from electrum.util import ThreadJob
_U = TypeVar('_U')


def count_objects_in_memory(mclasses: Sequence[type[_U]]) -> Mapping[type[_U], Sequence[weakref.ref[_U]]]:
    """Find all GC-tracked objects that are instances of any of `mclasses`.

    Runs a full `gc.collect()` first, then scans `gc.get_objects()`.

    Returns a mapping class -> list of weak references to the matching
    objects. Weak references are returned (instead of the objects) so that
    the result itself does not keep anything alive. The mapping is a
    defaultdict, so looking up a class without any match yields an empty list.
    """
    import gc
    gc.collect()
    objmap = defaultdict(list)
    for obj in gc.get_objects():
        for class_ in mclasses:
            try:
                matches = isinstance(obj, class_)
            except (AttributeError, ReferenceError):
                # AttributeError: objects with a broken/unusual __class__.
                # ReferenceError: a weakref.proxy whose referent is already
                # dead; such an object cannot be an instance we care about,
                # and it must not abort the whole scan.
                matches = False
            if matches:
                objmap[class_].append(weakref.ref(obj))
    return objmap
class DebugMem(ThreadJob):
'''A handy class for debugging GC memory leaks
@@ -21,18 +42,8 @@ class DebugMem(ThreadJob):
self.interval = interval
def mem_stats(self):
import gc
self.logger.info("Start memscan")
gc.collect()
objmap = defaultdict(list)
for obj in gc.get_objects():
for class_ in self.classes:
try:
_isinstance = isinstance(obj, class_)
except AttributeError:
_isinstance = False
if _isinstance:
objmap[class_].append(obj)
objmap = count_objects_in_memory(self.classes)
for class_, objs in objmap.items():
self.logger.info(f"{class_.__name__}: {len(objs)}")
self.logger.info("Finish memscan")
@@ -43,6 +54,26 @@ class DebugMem(ThreadJob):
self.next_time = time.time() + self.interval
async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None:
    """Async wait until the object referenced by `wr` is GC-ed.

    Returns immediately if `wr` is already dead. Otherwise it loops: wait up
    to 0.01 seconds for the "collected" event, and on every timeout force a
    `gc.collect()` so that objects kept alive only by reference cycles get
    reclaimed as well.

    There is no overall deadline in here: callers that need one have to wrap
    the call in their own timeout.
    """
    obj = wr()
    if obj is None:
        return
    evt_gc = asyncio.Event()  # set when obj is finally GC-ed.
    # wr2 is never read, but it must stay referenced until we return: a weakref
    # callback only fires while the weakref object itself is still alive.
    # NOTE(review): presumably run_sync_function_on_asyncio_thread hands evt_gc.set
    # over to the event loop thread without blocking -- confirm in util.
    wr2 = weakref.ref(obj, lambda _x: util.run_sync_function_on_asyncio_thread(evt_gc.set, block=False))
    # drop our own strong reference, otherwise we would keep obj alive ourselves
    del obj
    while True:
        try:
            async with util.async_timeout(0.01):
                await evt_gc.wait()
        except asyncio.TimeoutError:
            # not collected yet: run the cyclic garbage collector and retry
            import gc
            gc.collect()
        else:
            break
    assert evt_gc.is_set()
def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
"""Return a string listing the most common types in memory."""
import objgraph # 3rd-party dependency
@@ -52,7 +83,7 @@ def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
)
def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
def debug_memusage_dump_random_backref_chain(objtype: str) -> Optional[str]:
"""Writes a dotfile to cwd, containing the backref chain
for a randomly selected object of type objtype.
@@ -68,10 +99,14 @@ def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
import random
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot")
objects = objgraph.by_type(objtype)
if not objects:
return None
random_obj = random.choice(objects)
with open(fpath, "w") as f:
objgraph.show_chain(
objgraph.find_backref_chain(
random.choice(objgraph.by_type(objtype)),
random_obj,
objgraph.is_proper_module),
output=f)
return fpath

View File

@@ -562,11 +562,12 @@ class Abstract_Wallet(ABC, Logger, EventListener):
self.unregister_callbacks()
try:
async with ignore_after(5):
if self.lnworker:
await self.lnworker.stop()
self.lnworker = None
if self.network:
if self.lnworker:
await self.lnworker.stop()
self.lnworker = None
self.network = None
if self.taskgroup:
await self.taskgroup.cancel_remaining()
self.taskgroup = None
await self.adb.stop()

141
tests/test_callbackmgr.py Normal file
View File

@@ -0,0 +1,141 @@
import asyncio
import weakref
from electrum import util
from electrum.util import EventListener, event_listener, trigger_callback
from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected
from electrum.simple_config import SimpleConfig
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
class MyEventListener(EventListener):
    """Minimal EventListener used by the tests below.

    It listens for two events (one async handler, one non-async handler) and
    counts how often each handler ran.
    """

    def __init__(self, *, autostart: bool = False):
        self._satoshi_cnt = 0  # number of times the async handler ran
        self._hal_cnt = 0  # number of times the non-async handler ran
        if autostart:
            self.start()

    def start(self):
        """Register the event handlers of this instance."""
        self.register_callbacks()

    def stop(self):
        """Unregister the event handlers of this instance."""
        self.unregister_callbacks()

    # note: handlers must keep the "on_event_*" name and the decorator;
    #       the tests trigger them via the event name that follows the prefix.
    @event_listener
    async def on_event_satoshi_moves_his_coins(self, arg1, arg2):
        self._satoshi_cnt += 1

    @event_listener
    def on_event_hal_moves_his_coins(self, arg1, arg2):  # non-async
        self._hal_cnt += 1
_count_all_callbacks = util.callback_mgr.count_all_callbacks
async def fast_sleep():
    """Yield control to the event loop for a few (five) iterations."""
    remaining = 5
    while remaining > 0:
        await asyncio.sleep(0)
        remaining -= 1
class TestCallbackMgr(ElectrumTestCase):
    """Tests for callback registration, triggering, and garbage collection of listeners.

    All callback counts are read from the global util.callback_mgr, so each
    test expects to start with zero registered callbacks.
    """

    def test_multiple_calls_to_register_callbacks(self):
        """Registering or unregistering the same listener repeatedly must be idempotent."""
        self.assertEqual(0, _count_all_callbacks())
        el1 = MyEventListener()
        el2 = MyEventListener()
        # constructing a listener (without autostart) registers nothing
        self.assertEqual(0, _count_all_callbacks())
        el1.start()
        self.assertEqual(2, _count_all_callbacks())  # two handlers per listener
        el2.start()
        self.assertEqual(4, _count_all_callbacks())
        el1.start()  # second start() of el1: count must not change
        self.assertEqual(4, _count_all_callbacks())
        el1.stop()
        self.assertEqual(2, _count_all_callbacks())
        el1.stop()  # repeated stop() must be a no-op
        self.assertEqual(2, _count_all_callbacks())
        el1.stop()
        self.assertEqual(2, _count_all_callbacks())
        el2.stop()
        self.assertEqual(0, _count_all_callbacks())

    async def test_trigger_callback(self):
        """Triggered events reach every started listener, and no stopped one."""
        el1 = MyEventListener()
        el1.start()
        el2 = MyEventListener()
        el2.start()
        # trigger some cbs
        self.assertEqual(el1._satoshi_cnt, 0)
        self.assertEqual(el1._hal_cnt, 0)
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('hal_moves_his_coins', 0, 0)
        # give the event loop a few iterations so the handlers get to run
        await fast_sleep()
        self.assertEqual(el1._satoshi_cnt, 3)
        self.assertEqual(el2._satoshi_cnt, 3)
        self.assertEqual(el1._hal_cnt, 1)
        self.assertEqual(el2._hal_cnt, 1)
        # stop one listener, see new triggers are only seen by other one still running
        el1.stop()
        trigger_callback('satoshi_moves_his_coins', 0, 0)
        trigger_callback('hal_moves_his_coins', 0, 0)
        await fast_sleep()
        self.assertEqual(el1._satoshi_cnt, 3)
        self.assertEqual(el2._satoshi_cnt, 4)
        self.assertEqual(el1._hal_cnt, 1)
        self.assertEqual(el2._hal_cnt, 2)

    async def test_gc(self):
        """Listeners get GC-ed, and their callbacks dropped, with or without an explicit stop()."""
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 0)
        self.assertEqual(_count_all_callbacks(), 0)
        el1 = MyEventListener()
        el1.start()
        el2 = MyEventListener()
        el2.start()
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 2)
        self.assertEqual(_count_all_callbacks(), 4)
        # test if we can get GC-ed if we explicitly unregister cbs:
        el1.stop()  # calls unregister_callbacks
        del el1
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 1)
        self.assertEqual(_count_all_callbacks(), 2)
        # test if we can get GC-ed even without unregistering cbs:
        del el2
        objmap = count_objects_in_memory([MyEventListener])
        self.assertEqual(len(objmap[MyEventListener]), 0)
        self.assertEqual(_count_all_callbacks(), 0)

    async def test_gc2(self):
        """A chain of listeners referencing each other is cleaned up once it goes out of scope."""
        def func():
            # el1 -> el2 -> el3: only el1 is bound to a local name
            el1 = MyEventListener(autostart=True)
            el1.el2 = MyEventListener(autostart=True)
            el1.el2.el3 = MyEventListener(autostart=True)
            self.assertEqual(_count_all_callbacks(), 6)
        func()
        # all three listeners became unreachable when func() returned
        self.assertEqual(_count_all_callbacks(), 0)

    async def test_gc_complex_using_wallet(self):
        """This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod.
        We need the custom __eq__ for some reason.
        """
        self.assertEqual(_count_all_callbacks(), 0)
        config = SimpleConfig({'electrum_path': self.electrum_path})
        wallet = restore_wallet_from_text__for_unittest(
            "9dk", path=None, config=config,
        )["wallet"]
        assert wallet.lnworker is not None
        # now delete the wallet, and wait for it to get GC-ed
        # note: need to wait for cyclic GC. example: wallet.lnworker.wallet
        wr = weakref.ref(wallet)
        del wallet
        # fail the test (via timeout) if the wallet is still alive after 5 seconds
        async with util.async_timeout(5):
            await wait_until_obj_is_garbage_collected(wr)
        # by now, all callbacks must have been cleaned up:
        self.assertEqual(_count_all_callbacks(), 0)

View File

@@ -1,3 +1,5 @@
import asyncio
from collections import defaultdict
import os
from typing import Optional, Iterable
@@ -5,7 +7,10 @@ from electrum.commands import Commands
from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig
from electrum.wallet import Abstract_Wallet
from electrum.lnworker import LNWallet, LNPeerManager
from electrum.lnwatcher import LNWatcher
from electrum import util
from electrum.utils.memory_leak import count_objects_in_memory
from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest
@@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase):
await self.daemon.stop()
await super().asyncTearDown()
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str:
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str:
"""Returns path for created wallet."""
basename = util.get_new_wallet_name(self.wallet_dir)
path = os.path.join(self.wallet_dir, basename)
@@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase):
password=password,
encrypt_file=encrypt_file,
config=self.config,
**kwargs,
)
# We return the path instead of the wallet object, as extreme
# care would be needed to use the wallet object directly:
@@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase):
self.assertTrue(is_unified)
self._run_post_unif_sanity_checks(paths, password="123456")
# misc --->
async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self):
    """check_password_for_directory() must not leave wallet-related objects or callbacks behind.

    The wallets are only restored to files here (the helper returns paths, not
    wallet objects), so any Abstract_Wallet/LNWallet/LNWatcher/LNPeerManager
    instance found afterwards is considered a leak.
    """
    # callbacks registered before this test started (baseline for the final check)
    orig_cb_count = util.callback_mgr.count_all_callbacks()
    # GC sanity-check:
    mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager]
    objmap = count_objects_in_memory(mclasses)
    for mcls in mclasses:
        self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
    # restore some wallets
    paths = []
    paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True))
    paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=False))
    paths.append(self._restore_wallet_from_text("9dk", password=None))
    paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True, passphrase="hunter2"))
    paths.append(self._restore_wallet_from_text("9dk", password="999999", encrypt_file=False, passphrase="hunter2"))
    paths.append(self._restore_wallet_from_text("9dk", password=None, passphrase="hunter2"))
    # test unification
    # NOTE(review): 5 of the 6 wallets are expected to succeed; presumably the one
    #               created with password "999999" is the one that fails -- confirm.
    can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(old_password="123456", wallet_dir=self.wallet_dir)
    self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded)))
    # gc
    # poll for up to 5 seconds until no object of any of the classes is left
    try:
        async with util.async_timeout(5):
            while True:
                objmap = count_objects_in_memory(mclasses)
                if sum(len(lst) for lst in objmap.values()) == 0:
                    break  # all "mclasses"-type objects have been GC-ed
                await asyncio.sleep(0.01)
    except asyncio.TimeoutError:
        # timed out: assert per class, so the failure message names the lingering type
        for mcls in mclasses:
            self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
    # also check callbacks have been cleaned up:
    self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks())
class TestCommandsWithDaemon(DaemonTestCase):
TESTNET = True