Merge pull request #10428 from SomberNight/202601_wallet_objs_not_gced
fix util.CallbackManager memory leak
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import concurrent.futures
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import os
|
||||
@@ -56,6 +57,7 @@ import enum
|
||||
from contextlib import nullcontext, suppress
|
||||
import traceback
|
||||
import inspect
|
||||
import weakref
|
||||
|
||||
import aiohttp
|
||||
from aiohttp_socks import ProxyConnector, ProxyType
|
||||
@@ -1956,23 +1958,40 @@ class CallbackManager(Logger):
|
||||
|
||||
def __init__(self):
|
||||
Logger.__init__(self)
|
||||
self.callback_lock = threading.Lock()
|
||||
self.callbacks = defaultdict(list) # type: Dict[str, List[Callable]] # note: needs self.callback_lock
|
||||
self.callback_lock = threading.RLock()
|
||||
self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock
|
||||
|
||||
def register_callback(self, func: Callable, events: Sequence[str]) -> None:
|
||||
@staticmethod
|
||||
def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]:
|
||||
assert callable(cb), type(cb)
|
||||
if isinstance(cb, weakref.ref): # no-op
|
||||
return cb
|
||||
elif inspect.ismethod(cb): # instance method, such as for a subclass of EventListener
|
||||
return WeakMethodProper(cb)
|
||||
else: # proper function? e.g. used by lnpeer unit tests
|
||||
return weakref.ref(cb)
|
||||
|
||||
def register_callback(self, cb: Callable, events: Sequence[str]) -> None:
|
||||
wcb = self._wcb_from_any_callback(cb)
|
||||
with self.callback_lock:
|
||||
for event in events:
|
||||
self.callbacks[event].append(func)
|
||||
self._wcallbacks[event].add(wcb)
|
||||
|
||||
def unregister_callback(self, callback: Callable) -> None:
|
||||
def unregister_callback(self, cb: Callable) -> None:
|
||||
wcb = self._wcb_from_any_callback(cb)
|
||||
with self.callback_lock:
|
||||
for callbacks in self.callbacks.values():
|
||||
if callback in callbacks:
|
||||
callbacks.remove(callback)
|
||||
# note: ^ callback_lock needs to be re-entrant, as we can now trigger __del__, which also takes the lock
|
||||
for callbacks in self._wcallbacks.values():
|
||||
if wcb in callbacks:
|
||||
callbacks.remove(wcb)
|
||||
|
||||
def count_all_callbacks(self) -> int:
|
||||
with self.callback_lock:
|
||||
return sum(len(cbs) for cbs in self._wcallbacks.values())
|
||||
|
||||
def clear_all_callbacks(self) -> None:
|
||||
with self.callback_lock:
|
||||
self.callbacks.clear()
|
||||
self._wcallbacks.clear()
|
||||
|
||||
def trigger_callback(self, event: str, *args) -> None:
|
||||
"""Trigger a callback with given arguments.
|
||||
@@ -1982,8 +2001,11 @@ class CallbackManager(Logger):
|
||||
loop = get_asyncio_loop()
|
||||
assert loop.is_running(), "event loop not running"
|
||||
with self.callback_lock:
|
||||
callbacks = self.callbacks[event][:]
|
||||
for callback in callbacks:
|
||||
wcallbacks = copy.copy(self._wcallbacks[event])
|
||||
for wcb in wcallbacks:
|
||||
callback = wcb()
|
||||
if callback is None:
|
||||
continue
|
||||
if inspect.iscoroutinefunction(callback): # async cb
|
||||
fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)
|
||||
|
||||
@@ -2005,11 +2027,28 @@ unregister_callback = callback_mgr.unregister_callback
|
||||
_event_listeners = defaultdict(set) # type: Dict[str, Set[str]]
|
||||
|
||||
|
||||
class WeakMethodProper(weakref.WeakMethod):
|
||||
"""Unlike weakref.WeakMethod, this class has an __eq__ I can trust."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
meth = self()
|
||||
self._my_id = (id(meth.__self__), id(meth.__func__))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._my_id)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, WeakMethodProper):
|
||||
return False
|
||||
return self._my_id == other._my_id
|
||||
|
||||
|
||||
class EventListener:
|
||||
"""Use as a mixin for a class that has methods to be triggered on events.
|
||||
- Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
|
||||
- register_callbacks() should be called exactly once per instance of EventListener, e.g. in __init__
|
||||
- register_callbacks() should be called once per instance of EventListener, e.g. in __init__
|
||||
- unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
|
||||
- as fallback, __del__() also calls unregister_callbacks()
|
||||
"""
|
||||
|
||||
def _list_callbacks(self):
|
||||
@@ -2031,6 +2070,9 @@ class EventListener:
|
||||
#_logger.debug(f'unregistering callback {method}')
|
||||
unregister_callback(method)
|
||||
|
||||
def __del__(self):
|
||||
self.unregister_callbacks()
|
||||
|
||||
|
||||
def event_listener(func):
|
||||
"""To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""
|
||||
|
||||
@@ -1,11 +1,32 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from typing import Sequence, Mapping, TypeVar, Optional
|
||||
import weakref
|
||||
|
||||
from electrum import util
|
||||
from electrum.util import ThreadJob
|
||||
|
||||
|
||||
_U = TypeVar('_U')
|
||||
|
||||
def count_objects_in_memory(mclasses: Sequence[type[_U]]) -> Mapping[type[_U], Sequence[weakref.ref[_U]]]:
|
||||
import gc
|
||||
gc.collect()
|
||||
objmap = defaultdict(list)
|
||||
for obj in gc.get_objects():
|
||||
for class_ in mclasses:
|
||||
try:
|
||||
_isinstance = isinstance(obj, class_)
|
||||
except AttributeError:
|
||||
_isinstance = False
|
||||
if _isinstance:
|
||||
objmap[class_].append(weakref.ref(obj))
|
||||
return objmap
|
||||
|
||||
|
||||
class DebugMem(ThreadJob):
|
||||
'''A handy class for debugging GC memory leaks
|
||||
|
||||
@@ -21,18 +42,8 @@ class DebugMem(ThreadJob):
|
||||
self.interval = interval
|
||||
|
||||
def mem_stats(self):
|
||||
import gc
|
||||
self.logger.info("Start memscan")
|
||||
gc.collect()
|
||||
objmap = defaultdict(list)
|
||||
for obj in gc.get_objects():
|
||||
for class_ in self.classes:
|
||||
try:
|
||||
_isinstance = isinstance(obj, class_)
|
||||
except AttributeError:
|
||||
_isinstance = False
|
||||
if _isinstance:
|
||||
objmap[class_].append(obj)
|
||||
objmap = count_objects_in_memory(self.classes)
|
||||
for class_, objs in objmap.items():
|
||||
self.logger.info(f"{class_.__name__}: {len(objs)}")
|
||||
self.logger.info("Finish memscan")
|
||||
@@ -43,6 +54,26 @@ class DebugMem(ThreadJob):
|
||||
self.next_time = time.time() + self.interval
|
||||
|
||||
|
||||
async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None:
|
||||
"""Async wait until the object referenced by `wr` is GC-ed."""
|
||||
obj = wr()
|
||||
if obj is None:
|
||||
return
|
||||
evt_gc = asyncio.Event() # set when obj is finally GC-ed.
|
||||
wr2 = weakref.ref(obj, lambda _x: util.run_sync_function_on_asyncio_thread(evt_gc.set, block=False))
|
||||
del obj
|
||||
while True:
|
||||
try:
|
||||
async with util.async_timeout(0.01):
|
||||
await evt_gc.wait()
|
||||
except asyncio.TimeoutError:
|
||||
import gc
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
assert evt_gc.is_set()
|
||||
|
||||
|
||||
def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
|
||||
"""Return a string listing the most common types in memory."""
|
||||
import objgraph # 3rd-party dependency
|
||||
@@ -52,7 +83,7 @@ def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
|
||||
)
|
||||
|
||||
|
||||
def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
|
||||
def debug_memusage_dump_random_backref_chain(objtype: str) -> Optional[str]:
|
||||
"""Writes a dotfile to cwd, containing the backref chain
|
||||
for a randomly selected object of type objtype.
|
||||
|
||||
@@ -68,10 +99,14 @@ def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
|
||||
import random
|
||||
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot")
|
||||
objects = objgraph.by_type(objtype)
|
||||
if not objects:
|
||||
return None
|
||||
random_obj = random.choice(objects)
|
||||
with open(fpath, "w") as f:
|
||||
objgraph.show_chain(
|
||||
objgraph.find_backref_chain(
|
||||
random.choice(objgraph.by_type(objtype)),
|
||||
random_obj,
|
||||
objgraph.is_proper_module),
|
||||
output=f)
|
||||
return fpath
|
||||
|
||||
@@ -562,11 +562,12 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
||||
self.unregister_callbacks()
|
||||
try:
|
||||
async with ignore_after(5):
|
||||
if self.lnworker:
|
||||
await self.lnworker.stop()
|
||||
self.lnworker = None
|
||||
if self.network:
|
||||
if self.lnworker:
|
||||
await self.lnworker.stop()
|
||||
self.lnworker = None
|
||||
self.network = None
|
||||
if self.taskgroup:
|
||||
await self.taskgroup.cancel_remaining()
|
||||
self.taskgroup = None
|
||||
await self.adb.stop()
|
||||
|
||||
141
tests/test_callbackmgr.py
Normal file
141
tests/test_callbackmgr.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import asyncio
|
||||
import weakref
|
||||
|
||||
from electrum import util
|
||||
from electrum.util import EventListener, event_listener, trigger_callback
|
||||
from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected
|
||||
from electrum.simple_config import SimpleConfig
|
||||
|
||||
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
|
||||
|
||||
|
||||
class MyEventListener(EventListener):
|
||||
def __init__(self, *, autostart: bool = False):
|
||||
self._satoshi_cnt = 0
|
||||
self._hal_cnt = 0
|
||||
if autostart:
|
||||
self.start()
|
||||
|
||||
def start(self):
|
||||
self.register_callbacks()
|
||||
|
||||
def stop(self):
|
||||
self.unregister_callbacks()
|
||||
|
||||
@event_listener
|
||||
async def on_event_satoshi_moves_his_coins(self, arg1, arg2):
|
||||
self._satoshi_cnt += 1
|
||||
|
||||
@event_listener
|
||||
def on_event_hal_moves_his_coins(self, arg1, arg2): # non-async
|
||||
self._hal_cnt += 1
|
||||
|
||||
|
||||
_count_all_callbacks = util.callback_mgr.count_all_callbacks
|
||||
|
||||
|
||||
async def fast_sleep():
|
||||
# sleep a few event loop iterations
|
||||
for i in range(5):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
class TestCallbackMgr(ElectrumTestCase):
|
||||
|
||||
def test_multiple_calls_to_register_callbacks(self):
|
||||
self.assertEqual(0, _count_all_callbacks())
|
||||
el1 = MyEventListener()
|
||||
el2 = MyEventListener()
|
||||
self.assertEqual(0, _count_all_callbacks())
|
||||
el1.start()
|
||||
self.assertEqual(2, _count_all_callbacks())
|
||||
el2.start()
|
||||
self.assertEqual(4, _count_all_callbacks())
|
||||
el1.start()
|
||||
self.assertEqual(4, _count_all_callbacks())
|
||||
el1.stop()
|
||||
self.assertEqual(2, _count_all_callbacks())
|
||||
el1.stop()
|
||||
self.assertEqual(2, _count_all_callbacks())
|
||||
el1.stop()
|
||||
self.assertEqual(2, _count_all_callbacks())
|
||||
el2.stop()
|
||||
self.assertEqual(0, _count_all_callbacks())
|
||||
|
||||
async def test_trigger_callback(self):
|
||||
el1 = MyEventListener()
|
||||
el1.start()
|
||||
el2 = MyEventListener()
|
||||
el2.start()
|
||||
# trigger some cbs
|
||||
self.assertEqual(el1._satoshi_cnt, 0)
|
||||
self.assertEqual(el1._hal_cnt, 0)
|
||||
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||
trigger_callback('hal_moves_his_coins', 0, 0)
|
||||
await fast_sleep()
|
||||
self.assertEqual(el1._satoshi_cnt, 3)
|
||||
self.assertEqual(el2._satoshi_cnt, 3)
|
||||
self.assertEqual(el1._hal_cnt, 1)
|
||||
self.assertEqual(el2._hal_cnt, 1)
|
||||
# stop one listener, see new triggers are only seen by other one still running
|
||||
el1.stop()
|
||||
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||
trigger_callback('hal_moves_his_coins', 0, 0)
|
||||
await fast_sleep()
|
||||
self.assertEqual(el1._satoshi_cnt, 3)
|
||||
self.assertEqual(el2._satoshi_cnt, 4)
|
||||
self.assertEqual(el1._hal_cnt, 1)
|
||||
self.assertEqual(el2._hal_cnt, 2)
|
||||
|
||||
async def test_gc(self):
|
||||
objmap = count_objects_in_memory([MyEventListener])
|
||||
self.assertEqual(len(objmap[MyEventListener]), 0)
|
||||
self.assertEqual(_count_all_callbacks(), 0)
|
||||
el1 = MyEventListener()
|
||||
el1.start()
|
||||
el2 = MyEventListener()
|
||||
el2.start()
|
||||
objmap = count_objects_in_memory([MyEventListener])
|
||||
self.assertEqual(len(objmap[MyEventListener]), 2)
|
||||
self.assertEqual(_count_all_callbacks(), 4)
|
||||
# test if we can get GC-ed if we explicitly unregister cbs:
|
||||
el1.stop() # calls unregister_callbacks
|
||||
del el1
|
||||
objmap = count_objects_in_memory([MyEventListener])
|
||||
self.assertEqual(len(objmap[MyEventListener]), 1)
|
||||
self.assertEqual(_count_all_callbacks(), 2)
|
||||
# test if we can get GC-ed even without unregistering cbs:
|
||||
del el2
|
||||
objmap = count_objects_in_memory([MyEventListener])
|
||||
self.assertEqual(len(objmap[MyEventListener]), 0)
|
||||
self.assertEqual(_count_all_callbacks(), 0)
|
||||
|
||||
async def test_gc2(self):
|
||||
def func():
|
||||
el1 = MyEventListener(autostart=True)
|
||||
el1.el2 = MyEventListener(autostart=True)
|
||||
el1.el2.el3 = MyEventListener(autostart=True)
|
||||
self.assertEqual(_count_all_callbacks(), 6)
|
||||
func()
|
||||
self.assertEqual(_count_all_callbacks(), 0)
|
||||
|
||||
async def test_gc_complex_using_wallet(self):
|
||||
"""This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod.
|
||||
We need the custom __eq__ for some reason.
|
||||
"""
|
||||
self.assertEqual(_count_all_callbacks(), 0)
|
||||
config = SimpleConfig({'electrum_path': self.electrum_path})
|
||||
wallet = restore_wallet_from_text__for_unittest(
|
||||
"9dk", path=None, config=config,
|
||||
)["wallet"]
|
||||
assert wallet.lnworker is not None
|
||||
# now delete the wallet, and wait for it to get GC-ed
|
||||
# note: need to wait for cyclic GC. example: wallet.lnworker.wallet
|
||||
wr = weakref.ref(wallet)
|
||||
del wallet
|
||||
async with util.async_timeout(5):
|
||||
await wait_until_obj_is_garbage_collected(wr)
|
||||
# by now, all callbacks must have been cleaned up:
|
||||
self.assertEqual(_count_all_callbacks(), 0)
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from typing import Optional, Iterable
|
||||
|
||||
@@ -5,7 +7,10 @@ from electrum.commands import Commands
|
||||
from electrum.daemon import Daemon
|
||||
from electrum.simple_config import SimpleConfig
|
||||
from electrum.wallet import Abstract_Wallet
|
||||
from electrum.lnworker import LNWallet, LNPeerManager
|
||||
from electrum.lnwatcher import LNWatcher
|
||||
from electrum import util
|
||||
from electrum.utils.memory_leak import count_objects_in_memory
|
||||
|
||||
from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest
|
||||
|
||||
@@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase):
|
||||
await self.daemon.stop()
|
||||
await super().asyncTearDown()
|
||||
|
||||
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str:
|
||||
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str:
|
||||
"""Returns path for created wallet."""
|
||||
basename = util.get_new_wallet_name(self.wallet_dir)
|
||||
path = os.path.join(self.wallet_dir, basename)
|
||||
@@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase):
|
||||
password=password,
|
||||
encrypt_file=encrypt_file,
|
||||
config=self.config,
|
||||
**kwargs,
|
||||
)
|
||||
# We return the path instead of the wallet object, as extreme
|
||||
# care would be needed to use the wallet object directly:
|
||||
@@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase):
|
||||
self.assertTrue(is_unified)
|
||||
self._run_post_unif_sanity_checks(paths, password="123456")
|
||||
|
||||
# misc --->
|
||||
|
||||
async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self):
|
||||
orig_cb_count = util.callback_mgr.count_all_callbacks()
|
||||
# GC sanity-check:
|
||||
mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager]
|
||||
objmap = count_objects_in_memory(mclasses)
|
||||
for mcls in mclasses:
|
||||
self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
|
||||
# restore some wallets
|
||||
paths = []
|
||||
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True))
|
||||
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=False))
|
||||
paths.append(self._restore_wallet_from_text("9dk", password=None))
|
||||
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True, passphrase="hunter2"))
|
||||
paths.append(self._restore_wallet_from_text("9dk", password="999999", encrypt_file=False, passphrase="hunter2"))
|
||||
paths.append(self._restore_wallet_from_text("9dk", password=None, passphrase="hunter2"))
|
||||
# test unification
|
||||
can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(old_password="123456", wallet_dir=self.wallet_dir)
|
||||
self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded)))
|
||||
# gc
|
||||
try:
|
||||
async with util.async_timeout(5):
|
||||
while True:
|
||||
objmap = count_objects_in_memory(mclasses)
|
||||
if sum(len(lst) for lst in objmap.values()) == 0:
|
||||
break # all "mclasses"-type objects have been GC-ed
|
||||
await asyncio.sleep(0.01)
|
||||
except asyncio.TimeoutError:
|
||||
for mcls in mclasses:
|
||||
self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
|
||||
# also check callbacks have been cleaned up:
|
||||
self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks())
|
||||
|
||||
|
||||
class TestCommandsWithDaemon(DaemonTestCase):
|
||||
TESTNET = True
|
||||
|
||||
Reference in New Issue
Block a user