1
0

util.EventListener: store WeakMethods in CallbackManager to avoid leaks

This patch changes the CallbackManager to use WeakMethods (weakrefs) to
break the ref cycle and allow the GC to clean up the wallet objects.
unregister_callbacks() will also get called automatically, from
EventListener.__del__, to clean up the CallbackManager.

I also added a few unit tests for this.

fixes https://github.com/spesmilo/electrum/issues/10427

-----

original problem:

In many subclasses of `EventListener`, such as `Abstract_Wallet`, `LNWatcher`,
`LNPeerManager`, we call `register_callbacks()` in `__init__`.
`unregister_callbacks()` is usually called in the `stop()` method.

Example - consider the wallet object:
- `Abstract_Wallet.__init__()` calls `register_callbacks()`
- there is a `start_network()` method
- there is a `stop()` method, which calls `unregister_callbacks()`
- typically the wallet API user only calls `stop()` if they also called
  `start_network()`.

This means the callbacks are often left registered, leading to the wallet
objects not getting GC-ed. The GC won't clean them up as
`util.callback_mgr.callbacks` stores strong refs to instance methods
of `Abstract_Wallet`, hence strong refs to the `Abstract_Wallet` objects.

An annoying example is `daemon.check_password_for_directory`, which
potentially creates wallet objects for all wallet files in the datadir.
It simply constructs the wallets, does not call `start_network()` and
neither does it call `stop()`.
This commit is contained in:
SomberNight
2026-01-20 22:27:14 +00:00
parent 15067be527
commit 87540dbe3e
4 changed files with 159 additions and 20 deletions

View File

@@ -57,6 +57,7 @@ import enum
from contextlib import nullcontext, suppress
import traceback
import inspect
import weakref
import aiohttp
from aiohttp_socks import ProxyConnector, ProxyType
@@ -1954,22 +1955,38 @@ class CallbackManager(Logger):
def __init__(self):
Logger.__init__(self)
self.callback_lock = threading.Lock()
self.callbacks = defaultdict(set) # type: Dict[str, Set[Callable]] # note: needs self.callback_lock
self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock
def register_callback(self, func: Callable, events: Sequence[str]) -> None:
@staticmethod
def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]:
assert callable(cb), type(cb)
if isinstance(cb, weakref.ref): # no-op
return cb
elif inspect.ismethod(cb): # instance method, such as for a subclass of EventListener
return WeakMethodProper(cb)
else: # proper function? e.g. used by lnpeer unit tests
return weakref.ref(cb)
def register_callback(self, cb: Callable, events: Sequence[str]) -> None:
wcb = self._wcb_from_any_callback(cb)
with self.callback_lock:
for event in events:
self.callbacks[event].add(func)
self._wcallbacks[event].add(wcb)
def unregister_callback(self, callback: Callable) -> None:
def unregister_callback(self, cb: Callable) -> None:
wcb = self._wcb_from_any_callback(cb)
with self.callback_lock:
for callbacks in self.callbacks.values():
if callback in callbacks:
callbacks.remove(callback)
for callbacks in self._wcallbacks.values():
if wcb in callbacks:
callbacks.remove(wcb)
def count_all_callbacks(self) -> int:
with self.callback_lock:
return sum(len(cbs) for cbs in self._wcallbacks.values())
def clear_all_callbacks(self) -> None:
with self.callback_lock:
self.callbacks.clear()
self._wcallbacks.clear()
def trigger_callback(self, event: str, *args) -> None:
"""Trigger a callback with given arguments.
@@ -1979,8 +1996,11 @@ class CallbackManager(Logger):
loop = get_asyncio_loop()
assert loop.is_running(), "event loop not running"
with self.callback_lock:
callbacks = copy.copy(self.callbacks[event])
for callback in callbacks:
wcallbacks = copy.copy(self._wcallbacks[event])
for wcb in wcallbacks:
callback = wcb()
if callback is None:
continue
if inspect.iscoroutinefunction(callback): # async cb
fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)
@@ -2002,14 +2022,28 @@ unregister_callback = callback_mgr.unregister_callback
_event_listeners = defaultdict(set) # type: Dict[str, Set[str]]
class WeakMethodProper(weakref.WeakMethod):
"""Unlike weakref.WeakMethod, this class has an __eq__ I can trust."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
meth = self()
self._my_id = (id(meth.__self__), id(meth.__func__))
def __hash__(self):
return hash(self._my_id)
def __eq__(self, other):
if not isinstance(other, WeakMethodProper):
return False
return self._my_id == other._my_id
class EventListener:
"""Use as a mixin for a class that has methods to be triggered on events.
- Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
- register_callbacks() should be called once per instance of EventListener, e.g. in __init__
- unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
- if register_callbacks() is called in __init__, as opposed to a separate start() method,
extra care is needed that the call to unregister_callbacks() is not forgotten,
otherwise we will leak memory
- as fallback, __del__() also calls unregister_callbacks()
"""
def _list_callbacks(self):
@@ -2031,6 +2065,9 @@ class EventListener:
#_logger.debug(f'unregistering callback {method}')
unregister_callback(method)
def __del__(self):
self.unregister_callbacks()
def event_listener(func):
"""To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""

View File

@@ -1,3 +1,4 @@
import asyncio
from collections import defaultdict
import datetime
import os
@@ -5,6 +6,7 @@ import time
from typing import Sequence, Mapping, TypeVar, Optional
import weakref
from electrum import util
from electrum.util import ThreadJob
@@ -52,6 +54,26 @@ class DebugMem(ThreadJob):
self.next_time = time.time() + self.interval
async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None:
"""Async wait until the object referenced by `wr` is GC-ed."""
obj = wr()
if obj is None:
return
evt_gc = asyncio.Event() # set when obj is finally GC-ed.
wr2 = weakref.ref(obj, lambda _x: util.run_sync_function_on_asyncio_thread(evt_gc.set, block=False))
del obj
while True:
try:
async with util.async_timeout(0.01):
await evt_gc.wait()
except asyncio.TimeoutError:
import gc
gc.collect()
else:
break
assert evt_gc.is_set()
def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
"""Return a string listing the most common types in memory."""
import objgraph # 3rd-party dependency

View File

@@ -1,16 +1,20 @@
import asyncio
import weakref
from electrum import util
from electrum.util import EventListener, event_listener, trigger_callback
from electrum.utils.memory_leak import count_objects_in_memory
from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected
from electrum.simple_config import SimpleConfig
from . import ElectrumTestCase
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
class MyEventListener(EventListener):
def __init__(self):
def __init__(self, *, autostart: bool = False):
self._satoshi_cnt = 0
self._hal_cnt = 0
if autostart:
self.start()
def start(self):
self.register_callbacks()
@@ -27,8 +31,7 @@ class MyEventListener(EventListener):
self._hal_cnt += 1
def _count_all_callbacks() -> int:
return sum(len(cbs) for cbs in util.callback_mgr.callbacks.values())
_count_all_callbacks = util.callback_mgr.count_all_callbacks
async def fast_sleep():
@@ -89,13 +92,50 @@ class TestCallbackMgr(ElectrumTestCase):
async def test_gc(self):
objmap = count_objects_in_memory([MyEventListener])
self.assertEqual(len(objmap[MyEventListener]), 0)
self.assertEqual(_count_all_callbacks(), 0)
el1 = MyEventListener()
el1.start()
el2 = MyEventListener()
el2.start()
objmap = count_objects_in_memory([MyEventListener])
self.assertEqual(len(objmap[MyEventListener]), 2)
el1.stop()
self.assertEqual(_count_all_callbacks(), 4)
# test if we can get GC-ed if we explicitly unregister cbs:
el1.stop() # calls unregister_callbacks
del el1
objmap = count_objects_in_memory([MyEventListener])
self.assertEqual(len(objmap[MyEventListener]), 1)
self.assertEqual(_count_all_callbacks(), 2)
# test if we can get GC-ed even without unregistering cbs:
del el2
objmap = count_objects_in_memory([MyEventListener])
self.assertEqual(len(objmap[MyEventListener]), 0)
self.assertEqual(_count_all_callbacks(), 0)
async def test_gc2(self):
def func():
el1 = MyEventListener(autostart=True)
el1.el2 = MyEventListener(autostart=True)
el1.el2.el3 = MyEventListener(autostart=True)
self.assertEqual(_count_all_callbacks(), 6)
func()
self.assertEqual(_count_all_callbacks(), 0)
async def test_gc_complex_using_wallet(self):
"""This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod.
We need the custom __eq__ for some reason.
"""
self.assertEqual(_count_all_callbacks(), 0)
config = SimpleConfig({'electrum_path': self.electrum_path})
wallet = restore_wallet_from_text__for_unittest(
"9dk", path=None, config=config,
)["wallet"]
assert wallet.lnworker is not None
# now delete the wallet, and wait for it to get GC-ed
# note: need to wait for cyclic GC. example: wallet.lnworker.wallet
wr = weakref.ref(wallet)
del wallet
async with util.async_timeout(5):
await wait_until_obj_is_garbage_collected(wr)
# by now, all callbacks must have been cleaned up:
self.assertEqual(_count_all_callbacks(), 0)

View File

@@ -1,3 +1,5 @@
import asyncio
from collections import defaultdict
import os
from typing import Optional, Iterable
@@ -5,7 +7,10 @@ from electrum.commands import Commands
from electrum.daemon import Daemon
from electrum.simple_config import SimpleConfig
from electrum.wallet import Abstract_Wallet
from electrum.lnworker import LNWallet, LNPeerManager
from electrum.lnwatcher import LNWatcher
from electrum import util
from electrum.utils.memory_leak import count_objects_in_memory
from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest
@@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase):
await self.daemon.stop()
await super().asyncTearDown()
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str:
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str:
"""Returns path for created wallet."""
basename = util.get_new_wallet_name(self.wallet_dir)
path = os.path.join(self.wallet_dir, basename)
@@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase):
password=password,
encrypt_file=encrypt_file,
config=self.config,
**kwargs,
)
# We return the path instead of the wallet object, as extreme
# care would be needed to use the wallet object directly:
@@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase):
self.assertTrue(is_unified)
self._run_post_unif_sanity_checks(paths, password="123456")
# misc --->
async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self):
orig_cb_count = util.callback_mgr.count_all_callbacks()
# GC sanity-check:
mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager]
objmap = count_objects_in_memory(mclasses)
for mcls in mclasses:
self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
# restore some wallets
paths = []
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True))
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=False))
paths.append(self._restore_wallet_from_text("9dk", password=None))
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True, passphrase="hunter2"))
paths.append(self._restore_wallet_from_text("9dk", password="999999", encrypt_file=False, passphrase="hunter2"))
paths.append(self._restore_wallet_from_text("9dk", password=None, passphrase="hunter2"))
# test unification
can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(old_password="123456", wallet_dir=self.wallet_dir)
self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded)))
# gc
try:
async with util.async_timeout(5):
while True:
objmap = count_objects_in_memory(mclasses)
if sum(len(lst) for lst in objmap.values()) == 0:
break # all "mclasses"-type objects have been GC-ed
await asyncio.sleep(0.01)
except asyncio.TimeoutError:
for mcls in mclasses:
self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
# also check callbacks have been cleaned up:
self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks())
class TestCommandsWithDaemon(DaemonTestCase):
TESTNET = True