Merge pull request #10428 from SomberNight/202601_wallet_objs_not_gced
fix util.CallbackManager memory leak
This commit is contained in:
@@ -21,6 +21,7 @@
|
|||||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -56,6 +57,7 @@ import enum
|
|||||||
from contextlib import nullcontext, suppress
|
from contextlib import nullcontext, suppress
|
||||||
import traceback
|
import traceback
|
||||||
import inspect
|
import inspect
|
||||||
|
import weakref
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from aiohttp_socks import ProxyConnector, ProxyType
|
from aiohttp_socks import ProxyConnector, ProxyType
|
||||||
@@ -1956,23 +1958,40 @@ class CallbackManager(Logger):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
Logger.__init__(self)
|
Logger.__init__(self)
|
||||||
self.callback_lock = threading.Lock()
|
self.callback_lock = threading.RLock()
|
||||||
self.callbacks = defaultdict(list) # type: Dict[str, List[Callable]] # note: needs self.callback_lock
|
self._wcallbacks = defaultdict(set) # type: Dict[str, Set[weakref.ref[Callable]]] # note: needs self.callback_lock
|
||||||
|
|
||||||
def register_callback(self, func: Callable, events: Sequence[str]) -> None:
|
@staticmethod
|
||||||
|
def _wcb_from_any_callback(cb: Callable) -> weakref.ref[Callable]:
|
||||||
|
assert callable(cb), type(cb)
|
||||||
|
if isinstance(cb, weakref.ref): # no-op
|
||||||
|
return cb
|
||||||
|
elif inspect.ismethod(cb): # instance method, such as for a subclass of EventListener
|
||||||
|
return WeakMethodProper(cb)
|
||||||
|
else: # proper function? e.g. used by lnpeer unit tests
|
||||||
|
return weakref.ref(cb)
|
||||||
|
|
||||||
|
def register_callback(self, cb: Callable, events: Sequence[str]) -> None:
|
||||||
|
wcb = self._wcb_from_any_callback(cb)
|
||||||
with self.callback_lock:
|
with self.callback_lock:
|
||||||
for event in events:
|
for event in events:
|
||||||
self.callbacks[event].append(func)
|
self._wcallbacks[event].add(wcb)
|
||||||
|
|
||||||
def unregister_callback(self, callback: Callable) -> None:
|
def unregister_callback(self, cb: Callable) -> None:
|
||||||
|
wcb = self._wcb_from_any_callback(cb)
|
||||||
with self.callback_lock:
|
with self.callback_lock:
|
||||||
for callbacks in self.callbacks.values():
|
# note: ^ callback_lock needs to be re-entrant, as we can now trigger __del__, which also takes the lock
|
||||||
if callback in callbacks:
|
for callbacks in self._wcallbacks.values():
|
||||||
callbacks.remove(callback)
|
if wcb in callbacks:
|
||||||
|
callbacks.remove(wcb)
|
||||||
|
|
||||||
|
def count_all_callbacks(self) -> int:
|
||||||
|
with self.callback_lock:
|
||||||
|
return sum(len(cbs) for cbs in self._wcallbacks.values())
|
||||||
|
|
||||||
def clear_all_callbacks(self) -> None:
|
def clear_all_callbacks(self) -> None:
|
||||||
with self.callback_lock:
|
with self.callback_lock:
|
||||||
self.callbacks.clear()
|
self._wcallbacks.clear()
|
||||||
|
|
||||||
def trigger_callback(self, event: str, *args) -> None:
|
def trigger_callback(self, event: str, *args) -> None:
|
||||||
"""Trigger a callback with given arguments.
|
"""Trigger a callback with given arguments.
|
||||||
@@ -1982,8 +2001,11 @@ class CallbackManager(Logger):
|
|||||||
loop = get_asyncio_loop()
|
loop = get_asyncio_loop()
|
||||||
assert loop.is_running(), "event loop not running"
|
assert loop.is_running(), "event loop not running"
|
||||||
with self.callback_lock:
|
with self.callback_lock:
|
||||||
callbacks = self.callbacks[event][:]
|
wcallbacks = copy.copy(self._wcallbacks[event])
|
||||||
for callback in callbacks:
|
for wcb in wcallbacks:
|
||||||
|
callback = wcb()
|
||||||
|
if callback is None:
|
||||||
|
continue
|
||||||
if inspect.iscoroutinefunction(callback): # async cb
|
if inspect.iscoroutinefunction(callback): # async cb
|
||||||
fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)
|
fut = asyncio.run_coroutine_threadsafe(callback(*args), loop)
|
||||||
|
|
||||||
@@ -2005,11 +2027,28 @@ unregister_callback = callback_mgr.unregister_callback
|
|||||||
_event_listeners = defaultdict(set) # type: Dict[str, Set[str]]
|
_event_listeners = defaultdict(set) # type: Dict[str, Set[str]]
|
||||||
|
|
||||||
|
|
||||||
|
class WeakMethodProper(weakref.WeakMethod):
|
||||||
|
"""Unlike weakref.WeakMethod, this class has an __eq__ I can trust."""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
meth = self()
|
||||||
|
self._my_id = (id(meth.__self__), id(meth.__func__))
|
||||||
|
|
||||||
|
def __hash__(self):
|
||||||
|
return hash(self._my_id)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
if not isinstance(other, WeakMethodProper):
|
||||||
|
return False
|
||||||
|
return self._my_id == other._my_id
|
||||||
|
|
||||||
|
|
||||||
class EventListener:
|
class EventListener:
|
||||||
"""Use as a mixin for a class that has methods to be triggered on events.
|
"""Use as a mixin for a class that has methods to be triggered on events.
|
||||||
- Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
|
- Methods that receive the callbacks should be named "on_event_*" and decorated with @event_listener.
|
||||||
- register_callbacks() should be called exactly once per instance of EventListener, e.g. in __init__
|
- register_callbacks() should be called once per instance of EventListener, e.g. in __init__
|
||||||
- unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
|
- unregister_callbacks() should be called at least once, e.g. when the instance is destroyed
|
||||||
|
- as fallback, __del__() also calls unregister_callbacks()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _list_callbacks(self):
|
def _list_callbacks(self):
|
||||||
@@ -2031,6 +2070,9 @@ class EventListener:
|
|||||||
#_logger.debug(f'unregistering callback {method}')
|
#_logger.debug(f'unregistering callback {method}')
|
||||||
unregister_callback(method)
|
unregister_callback(method)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.unregister_callbacks()
|
||||||
|
|
||||||
|
|
||||||
def event_listener(func):
|
def event_listener(func):
|
||||||
"""To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""
|
"""To be used in subclasses of EventListener only. (how to enforce this programmatically?)"""
|
||||||
|
|||||||
@@ -1,11 +1,32 @@
|
|||||||
|
import asyncio
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import datetime
|
import datetime
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import Sequence, Mapping, TypeVar, Optional
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
from electrum import util
|
||||||
from electrum.util import ThreadJob
|
from electrum.util import ThreadJob
|
||||||
|
|
||||||
|
|
||||||
|
_U = TypeVar('_U')
|
||||||
|
|
||||||
|
def count_objects_in_memory(mclasses: Sequence[type[_U]]) -> Mapping[type[_U], Sequence[weakref.ref[_U]]]:
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
objmap = defaultdict(list)
|
||||||
|
for obj in gc.get_objects():
|
||||||
|
for class_ in mclasses:
|
||||||
|
try:
|
||||||
|
_isinstance = isinstance(obj, class_)
|
||||||
|
except AttributeError:
|
||||||
|
_isinstance = False
|
||||||
|
if _isinstance:
|
||||||
|
objmap[class_].append(weakref.ref(obj))
|
||||||
|
return objmap
|
||||||
|
|
||||||
|
|
||||||
class DebugMem(ThreadJob):
|
class DebugMem(ThreadJob):
|
||||||
'''A handy class for debugging GC memory leaks
|
'''A handy class for debugging GC memory leaks
|
||||||
|
|
||||||
@@ -21,18 +42,8 @@ class DebugMem(ThreadJob):
|
|||||||
self.interval = interval
|
self.interval = interval
|
||||||
|
|
||||||
def mem_stats(self):
|
def mem_stats(self):
|
||||||
import gc
|
|
||||||
self.logger.info("Start memscan")
|
self.logger.info("Start memscan")
|
||||||
gc.collect()
|
objmap = count_objects_in_memory(self.classes)
|
||||||
objmap = defaultdict(list)
|
|
||||||
for obj in gc.get_objects():
|
|
||||||
for class_ in self.classes:
|
|
||||||
try:
|
|
||||||
_isinstance = isinstance(obj, class_)
|
|
||||||
except AttributeError:
|
|
||||||
_isinstance = False
|
|
||||||
if _isinstance:
|
|
||||||
objmap[class_].append(obj)
|
|
||||||
for class_, objs in objmap.items():
|
for class_, objs in objmap.items():
|
||||||
self.logger.info(f"{class_.__name__}: {len(objs)}")
|
self.logger.info(f"{class_.__name__}: {len(objs)}")
|
||||||
self.logger.info("Finish memscan")
|
self.logger.info("Finish memscan")
|
||||||
@@ -43,6 +54,26 @@ class DebugMem(ThreadJob):
|
|||||||
self.next_time = time.time() + self.interval
|
self.next_time = time.time() + self.interval
|
||||||
|
|
||||||
|
|
||||||
|
async def wait_until_obj_is_garbage_collected(wr: weakref.ref) -> None:
|
||||||
|
"""Async wait until the object referenced by `wr` is GC-ed."""
|
||||||
|
obj = wr()
|
||||||
|
if obj is None:
|
||||||
|
return
|
||||||
|
evt_gc = asyncio.Event() # set when obj is finally GC-ed.
|
||||||
|
wr2 = weakref.ref(obj, lambda _x: util.run_sync_function_on_asyncio_thread(evt_gc.set, block=False))
|
||||||
|
del obj
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
async with util.async_timeout(0.01):
|
||||||
|
await evt_gc.wait()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
import gc
|
||||||
|
gc.collect()
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
assert evt_gc.is_set()
|
||||||
|
|
||||||
|
|
||||||
def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
|
def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
|
||||||
"""Return a string listing the most common types in memory."""
|
"""Return a string listing the most common types in memory."""
|
||||||
import objgraph # 3rd-party dependency
|
import objgraph # 3rd-party dependency
|
||||||
@@ -52,7 +83,7 @@ def debug_memusage_list_all_objects(limit: int = 50) -> list[tuple[str, int]]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
|
def debug_memusage_dump_random_backref_chain(objtype: str) -> Optional[str]:
|
||||||
"""Writes a dotfile to cwd, containing the backref chain
|
"""Writes a dotfile to cwd, containing the backref chain
|
||||||
for a randomly selected object of type objtype.
|
for a randomly selected object of type objtype.
|
||||||
|
|
||||||
@@ -68,10 +99,14 @@ def debug_memusage_dump_random_backref_chain(objtype: str) -> str:
|
|||||||
import random
|
import random
|
||||||
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||||
fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot")
|
fpath = os.path.abspath(f"electrum_backref_chain_{timestamp}.dot")
|
||||||
|
objects = objgraph.by_type(objtype)
|
||||||
|
if not objects:
|
||||||
|
return None
|
||||||
|
random_obj = random.choice(objects)
|
||||||
with open(fpath, "w") as f:
|
with open(fpath, "w") as f:
|
||||||
objgraph.show_chain(
|
objgraph.show_chain(
|
||||||
objgraph.find_backref_chain(
|
objgraph.find_backref_chain(
|
||||||
random.choice(objgraph.by_type(objtype)),
|
random_obj,
|
||||||
objgraph.is_proper_module),
|
objgraph.is_proper_module),
|
||||||
output=f)
|
output=f)
|
||||||
return fpath
|
return fpath
|
||||||
|
|||||||
@@ -562,11 +562,12 @@ class Abstract_Wallet(ABC, Logger, EventListener):
|
|||||||
self.unregister_callbacks()
|
self.unregister_callbacks()
|
||||||
try:
|
try:
|
||||||
async with ignore_after(5):
|
async with ignore_after(5):
|
||||||
|
if self.lnworker:
|
||||||
|
await self.lnworker.stop()
|
||||||
|
self.lnworker = None
|
||||||
if self.network:
|
if self.network:
|
||||||
if self.lnworker:
|
|
||||||
await self.lnworker.stop()
|
|
||||||
self.lnworker = None
|
|
||||||
self.network = None
|
self.network = None
|
||||||
|
if self.taskgroup:
|
||||||
await self.taskgroup.cancel_remaining()
|
await self.taskgroup.cancel_remaining()
|
||||||
self.taskgroup = None
|
self.taskgroup = None
|
||||||
await self.adb.stop()
|
await self.adb.stop()
|
||||||
|
|||||||
141
tests/test_callbackmgr.py
Normal file
141
tests/test_callbackmgr.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
import asyncio
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
from electrum import util
|
||||||
|
from electrum.util import EventListener, event_listener, trigger_callback
|
||||||
|
from electrum.utils.memory_leak import count_objects_in_memory, wait_until_obj_is_garbage_collected
|
||||||
|
from electrum.simple_config import SimpleConfig
|
||||||
|
|
||||||
|
from . import ElectrumTestCase, restore_wallet_from_text__for_unittest
|
||||||
|
|
||||||
|
|
||||||
|
class MyEventListener(EventListener):
|
||||||
|
def __init__(self, *, autostart: bool = False):
|
||||||
|
self._satoshi_cnt = 0
|
||||||
|
self._hal_cnt = 0
|
||||||
|
if autostart:
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.register_callbacks()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.unregister_callbacks()
|
||||||
|
|
||||||
|
@event_listener
|
||||||
|
async def on_event_satoshi_moves_his_coins(self, arg1, arg2):
|
||||||
|
self._satoshi_cnt += 1
|
||||||
|
|
||||||
|
@event_listener
|
||||||
|
def on_event_hal_moves_his_coins(self, arg1, arg2): # non-async
|
||||||
|
self._hal_cnt += 1
|
||||||
|
|
||||||
|
|
||||||
|
_count_all_callbacks = util.callback_mgr.count_all_callbacks
|
||||||
|
|
||||||
|
|
||||||
|
async def fast_sleep():
|
||||||
|
# sleep a few event loop iterations
|
||||||
|
for i in range(5):
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackMgr(ElectrumTestCase):
|
||||||
|
|
||||||
|
def test_multiple_calls_to_register_callbacks(self):
|
||||||
|
self.assertEqual(0, _count_all_callbacks())
|
||||||
|
el1 = MyEventListener()
|
||||||
|
el2 = MyEventListener()
|
||||||
|
self.assertEqual(0, _count_all_callbacks())
|
||||||
|
el1.start()
|
||||||
|
self.assertEqual(2, _count_all_callbacks())
|
||||||
|
el2.start()
|
||||||
|
self.assertEqual(4, _count_all_callbacks())
|
||||||
|
el1.start()
|
||||||
|
self.assertEqual(4, _count_all_callbacks())
|
||||||
|
el1.stop()
|
||||||
|
self.assertEqual(2, _count_all_callbacks())
|
||||||
|
el1.stop()
|
||||||
|
self.assertEqual(2, _count_all_callbacks())
|
||||||
|
el1.stop()
|
||||||
|
self.assertEqual(2, _count_all_callbacks())
|
||||||
|
el2.stop()
|
||||||
|
self.assertEqual(0, _count_all_callbacks())
|
||||||
|
|
||||||
|
async def test_trigger_callback(self):
|
||||||
|
el1 = MyEventListener()
|
||||||
|
el1.start()
|
||||||
|
el2 = MyEventListener()
|
||||||
|
el2.start()
|
||||||
|
# trigger some cbs
|
||||||
|
self.assertEqual(el1._satoshi_cnt, 0)
|
||||||
|
self.assertEqual(el1._hal_cnt, 0)
|
||||||
|
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||||
|
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||||
|
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||||
|
trigger_callback('hal_moves_his_coins', 0, 0)
|
||||||
|
await fast_sleep()
|
||||||
|
self.assertEqual(el1._satoshi_cnt, 3)
|
||||||
|
self.assertEqual(el2._satoshi_cnt, 3)
|
||||||
|
self.assertEqual(el1._hal_cnt, 1)
|
||||||
|
self.assertEqual(el2._hal_cnt, 1)
|
||||||
|
# stop one listener, see new triggers are only seen by other one still running
|
||||||
|
el1.stop()
|
||||||
|
trigger_callback('satoshi_moves_his_coins', 0, 0)
|
||||||
|
trigger_callback('hal_moves_his_coins', 0, 0)
|
||||||
|
await fast_sleep()
|
||||||
|
self.assertEqual(el1._satoshi_cnt, 3)
|
||||||
|
self.assertEqual(el2._satoshi_cnt, 4)
|
||||||
|
self.assertEqual(el1._hal_cnt, 1)
|
||||||
|
self.assertEqual(el2._hal_cnt, 2)
|
||||||
|
|
||||||
|
async def test_gc(self):
|
||||||
|
objmap = count_objects_in_memory([MyEventListener])
|
||||||
|
self.assertEqual(len(objmap[MyEventListener]), 0)
|
||||||
|
self.assertEqual(_count_all_callbacks(), 0)
|
||||||
|
el1 = MyEventListener()
|
||||||
|
el1.start()
|
||||||
|
el2 = MyEventListener()
|
||||||
|
el2.start()
|
||||||
|
objmap = count_objects_in_memory([MyEventListener])
|
||||||
|
self.assertEqual(len(objmap[MyEventListener]), 2)
|
||||||
|
self.assertEqual(_count_all_callbacks(), 4)
|
||||||
|
# test if we can get GC-ed if we explicitly unregister cbs:
|
||||||
|
el1.stop() # calls unregister_callbacks
|
||||||
|
del el1
|
||||||
|
objmap = count_objects_in_memory([MyEventListener])
|
||||||
|
self.assertEqual(len(objmap[MyEventListener]), 1)
|
||||||
|
self.assertEqual(_count_all_callbacks(), 2)
|
||||||
|
# test if we can get GC-ed even without unregistering cbs:
|
||||||
|
del el2
|
||||||
|
objmap = count_objects_in_memory([MyEventListener])
|
||||||
|
self.assertEqual(len(objmap[MyEventListener]), 0)
|
||||||
|
self.assertEqual(_count_all_callbacks(), 0)
|
||||||
|
|
||||||
|
async def test_gc2(self):
|
||||||
|
def func():
|
||||||
|
el1 = MyEventListener(autostart=True)
|
||||||
|
el1.el2 = MyEventListener(autostart=True)
|
||||||
|
el1.el2.el3 = MyEventListener(autostart=True)
|
||||||
|
self.assertEqual(_count_all_callbacks(), 6)
|
||||||
|
func()
|
||||||
|
self.assertEqual(_count_all_callbacks(), 0)
|
||||||
|
|
||||||
|
async def test_gc_complex_using_wallet(self):
|
||||||
|
"""This test showcases why EventListener uses WeakMethodProper instead of weakref.WeakMethod.
|
||||||
|
We need the custom __eq__ for some reason.
|
||||||
|
"""
|
||||||
|
self.assertEqual(_count_all_callbacks(), 0)
|
||||||
|
config = SimpleConfig({'electrum_path': self.electrum_path})
|
||||||
|
wallet = restore_wallet_from_text__for_unittest(
|
||||||
|
"9dk", path=None, config=config,
|
||||||
|
)["wallet"]
|
||||||
|
assert wallet.lnworker is not None
|
||||||
|
# now delete the wallet, and wait for it to get GC-ed
|
||||||
|
# note: need to wait for cyclic GC. example: wallet.lnworker.wallet
|
||||||
|
wr = weakref.ref(wallet)
|
||||||
|
del wallet
|
||||||
|
async with util.async_timeout(5):
|
||||||
|
await wait_until_obj_is_garbage_collected(wr)
|
||||||
|
# by now, all callbacks must have been cleaned up:
|
||||||
|
self.assertEqual(_count_all_callbacks(), 0)
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Iterable
|
from typing import Optional, Iterable
|
||||||
|
|
||||||
@@ -5,7 +7,10 @@ from electrum.commands import Commands
|
|||||||
from electrum.daemon import Daemon
|
from electrum.daemon import Daemon
|
||||||
from electrum.simple_config import SimpleConfig
|
from electrum.simple_config import SimpleConfig
|
||||||
from electrum.wallet import Abstract_Wallet
|
from electrum.wallet import Abstract_Wallet
|
||||||
|
from electrum.lnworker import LNWallet, LNPeerManager
|
||||||
|
from electrum.lnwatcher import LNWatcher
|
||||||
from electrum import util
|
from electrum import util
|
||||||
|
from electrum.utils.memory_leak import count_objects_in_memory
|
||||||
|
|
||||||
from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest
|
from . import ElectrumTestCase, as_testnet, restore_wallet_from_text__for_unittest
|
||||||
|
|
||||||
@@ -30,7 +35,7 @@ class DaemonTestCase(ElectrumTestCase):
|
|||||||
await self.daemon.stop()
|
await self.daemon.stop()
|
||||||
await super().asyncTearDown()
|
await super().asyncTearDown()
|
||||||
|
|
||||||
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None) -> str:
|
def _restore_wallet_from_text(self, text, *, password: Optional[str], encrypt_file: bool = None, **kwargs) -> str:
|
||||||
"""Returns path for created wallet."""
|
"""Returns path for created wallet."""
|
||||||
basename = util.get_new_wallet_name(self.wallet_dir)
|
basename = util.get_new_wallet_name(self.wallet_dir)
|
||||||
path = os.path.join(self.wallet_dir, basename)
|
path = os.path.join(self.wallet_dir, basename)
|
||||||
@@ -40,6 +45,7 @@ class DaemonTestCase(ElectrumTestCase):
|
|||||||
password=password,
|
password=password,
|
||||||
encrypt_file=encrypt_file,
|
encrypt_file=encrypt_file,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
# We return the path instead of the wallet object, as extreme
|
# We return the path instead of the wallet object, as extreme
|
||||||
# care would be needed to use the wallet object directly:
|
# care would be needed to use the wallet object directly:
|
||||||
@@ -188,6 +194,40 @@ class TestUnifiedPassword(DaemonTestCase):
|
|||||||
self.assertTrue(is_unified)
|
self.assertTrue(is_unified)
|
||||||
self._run_post_unif_sanity_checks(paths, password="123456")
|
self._run_post_unif_sanity_checks(paths, password="123456")
|
||||||
|
|
||||||
|
# misc --->
|
||||||
|
|
||||||
|
async def test_wallet_objects_are_properly_garbage_collected_after_check_pw_for_dir(self):
|
||||||
|
orig_cb_count = util.callback_mgr.count_all_callbacks()
|
||||||
|
# GC sanity-check:
|
||||||
|
mclasses = [Abstract_Wallet, LNWallet, LNWatcher, LNPeerManager]
|
||||||
|
objmap = count_objects_in_memory(mclasses)
|
||||||
|
for mcls in mclasses:
|
||||||
|
self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
|
||||||
|
# restore some wallets
|
||||||
|
paths = []
|
||||||
|
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True))
|
||||||
|
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=False))
|
||||||
|
paths.append(self._restore_wallet_from_text("9dk", password=None))
|
||||||
|
paths.append(self._restore_wallet_from_text("9dk", password="123456", encrypt_file=True, passphrase="hunter2"))
|
||||||
|
paths.append(self._restore_wallet_from_text("9dk", password="999999", encrypt_file=False, passphrase="hunter2"))
|
||||||
|
paths.append(self._restore_wallet_from_text("9dk", password=None, passphrase="hunter2"))
|
||||||
|
# test unification
|
||||||
|
can_be_unified, is_unified, paths_succeeded = self.daemon.check_password_for_directory(old_password="123456", wallet_dir=self.wallet_dir)
|
||||||
|
self.assertEqual((False, False, 5), (can_be_unified, is_unified, len(paths_succeeded)))
|
||||||
|
# gc
|
||||||
|
try:
|
||||||
|
async with util.async_timeout(5):
|
||||||
|
while True:
|
||||||
|
objmap = count_objects_in_memory(mclasses)
|
||||||
|
if sum(len(lst) for lst in objmap.values()) == 0:
|
||||||
|
break # all "mclasses"-type objects have been GC-ed
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
for mcls in mclasses:
|
||||||
|
self.assertEqual(len(objmap[mcls]), 0, msg=f"too many lingering objs of type={mcls}")
|
||||||
|
# also check callbacks have been cleaned up:
|
||||||
|
self.assertEqual(orig_cb_count, util.callback_mgr.count_all_callbacks())
|
||||||
|
|
||||||
|
|
||||||
class TestCommandsWithDaemon(DaemonTestCase):
|
class TestCommandsWithDaemon(DaemonTestCase):
|
||||||
TESTNET = True
|
TESTNET = True
|
||||||
|
|||||||
Reference in New Issue
Block a user