daemon/wallet/network: make stop() methods async
This commit is contained in:
@@ -252,6 +252,11 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
default_server: ServerAddr
|
||||
_recent_servers: List[ServerAddr]
|
||||
|
||||
channel_blacklist: 'ChannelBlackList'
|
||||
channel_db: Optional['ChannelDB'] = None
|
||||
lngossip: Optional['LNGossip'] = None
|
||||
local_watchtower: Optional['WatchTower'] = None
|
||||
|
||||
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
|
||||
global _INSTANCE
|
||||
assert _INSTANCE is None, "Network is a singleton!"
|
||||
@@ -344,9 +349,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
|
||||
# lightning network
|
||||
self.channel_blacklist = ChannelBlackList()
|
||||
self.channel_db = None # type: Optional[ChannelDB]
|
||||
self.lngossip = None # type: Optional[LNGossip]
|
||||
self.local_watchtower = None # type: Optional[WatchTower]
|
||||
if self.config.get('run_local_watchtower', False):
|
||||
from . import lnwatcher
|
||||
self.local_watchtower = lnwatcher.WatchTower(self)
|
||||
@@ -373,11 +375,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
self.lngossip = lnworker.LNGossip()
|
||||
self.lngossip.start_network(self)
|
||||
|
||||
def stop_gossip(self):
|
||||
async def stop_gossip(self, *, full_shutdown: bool = False):
|
||||
if self.lngossip:
|
||||
self.lngossip.stop()
|
||||
await self.lngossip.stop()
|
||||
self.lngossip = None
|
||||
self.channel_db.stop()
|
||||
if full_shutdown:
|
||||
await self.channel_db.stopped_event.wait()
|
||||
self.channel_db = None
|
||||
|
||||
def run_from_another_thread(self, coro, *, timeout=None):
|
||||
@@ -623,7 +627,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
self.auto_connect = net_params.auto_connect
|
||||
if self.proxy != proxy or self.oneserver != net_params.oneserver:
|
||||
# Restart the network defaulting to the given server
|
||||
await self._stop()
|
||||
await self.stop(full_shutdown=False)
|
||||
self.default_server = server
|
||||
await self._start()
|
||||
elif self.default_server != server:
|
||||
@@ -1217,13 +1221,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
asyncio.run_coroutine_threadsafe(self._start(), self.asyncio_loop)
|
||||
|
||||
@log_exceptions
|
||||
async def _stop(self, full_shutdown=False):
|
||||
async def stop(self, *, full_shutdown: bool = True):
|
||||
self.logger.info("stopping network")
|
||||
try:
|
||||
# note: cancel_remaining ~cannot be cancelled, it suppresses CancelledError
|
||||
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
|
||||
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=1)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
|
||||
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
|
||||
self.logger.info(f"exc during taskgroup cancellation: {repr(e)}")
|
||||
self.taskgroup = None
|
||||
self.interface = None
|
||||
self.interfaces = {}
|
||||
@@ -1231,13 +1235,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
self._closing_ifaces.clear()
|
||||
if not full_shutdown:
|
||||
util.trigger_callback('network_updated')
|
||||
|
||||
def stop(self):
|
||||
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
|
||||
fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
|
||||
try:
|
||||
fut.result(timeout=2)
|
||||
except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError): pass
|
||||
if full_shutdown:
|
||||
await self.stop_gossip(full_shutdown=full_shutdown)
|
||||
|
||||
async def _ensure_there_is_a_main_interface(self):
|
||||
if self.is_connected():
|
||||
|
||||
Reference in New Issue
Block a user