daemon/wallet/network: make stop() methods async
This commit is contained in:
@@ -28,6 +28,8 @@ import itertools
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple, NamedTuple, Sequence, List
|
||||
|
||||
from aiorpcx import TaskGroup
|
||||
|
||||
from . import bitcoin, util
|
||||
from .bitcoin import COINBASE_MATURITY
|
||||
from .util import profiler, bfh, TxMinedInfo, UnrelatedTransactionException
|
||||
@@ -197,16 +199,19 @@ class AddressSynchronizer(Logger):
|
||||
def on_blockchain_updated(self, event, *args):
|
||||
self._get_addr_balance_cache = {} # invalidate cache
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
if self.network:
|
||||
if self.synchronizer:
|
||||
asyncio.run_coroutine_threadsafe(self.synchronizer.stop(), self.network.asyncio_loop)
|
||||
try:
|
||||
async with TaskGroup() as group:
|
||||
if self.synchronizer:
|
||||
await group.spawn(self.synchronizer.stop())
|
||||
if self.verifier:
|
||||
await group.spawn(self.verifier.stop())
|
||||
finally: # even if we get cancelled
|
||||
self.synchronizer = None
|
||||
if self.verifier:
|
||||
asyncio.run_coroutine_threadsafe(self.verifier.stop(), self.network.asyncio_loop)
|
||||
self.verifier = None
|
||||
util.unregister_callback(self.on_blockchain_updated)
|
||||
self.db.put('stored_height', self.get_local_height())
|
||||
util.unregister_callback(self.on_blockchain_updated)
|
||||
self.db.put('stored_height', self.get_local_height())
|
||||
|
||||
def add_address(self, address):
|
||||
if not self.db.get_addr_history(address):
|
||||
|
||||
@@ -29,7 +29,7 @@ import time
|
||||
import traceback
|
||||
import sys
|
||||
import threading
|
||||
from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping
|
||||
from typing import Dict, Optional, Tuple, Iterable, Callable, Union, Sequence, Mapping, TYPE_CHECKING
|
||||
from base64 import b64decode, b64encode
|
||||
from collections import defaultdict
|
||||
import concurrent
|
||||
@@ -38,7 +38,7 @@ import json
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web, client_exceptions
|
||||
from aiorpcx import TaskGroup
|
||||
from aiorpcx import TaskGroup, timeout_after, TaskTimeout
|
||||
|
||||
from . import util
|
||||
from .network import Network
|
||||
@@ -53,6 +53,9 @@ from .simple_config import SimpleConfig
|
||||
from .exchange_rate import FxThread
|
||||
from .logging import get_logger, Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from electrum import gui
|
||||
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
@@ -407,6 +410,7 @@ class PayServer(Logger):
|
||||
class Daemon(Logger):
|
||||
|
||||
network: Optional[Network]
|
||||
gui_object: Optional[Union['gui.qt.ElectrumGui', 'gui.kivy.ElectrumGui']]
|
||||
|
||||
@profiler
|
||||
def __init__(self, config: SimpleConfig, fd=None, *, listen_jsonrpc=True):
|
||||
@@ -523,7 +527,8 @@ class Daemon(Logger):
|
||||
wallet = self._wallets.pop(path, None)
|
||||
if not wallet:
|
||||
return False
|
||||
wallet.stop()
|
||||
fut = asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop)
|
||||
fut.result()
|
||||
return True
|
||||
|
||||
def run_daemon(self):
|
||||
@@ -544,20 +549,28 @@ class Daemon(Logger):
|
||||
self.running = False
|
||||
|
||||
def on_stop(self):
|
||||
self.logger.info("on_stop() entered. initiating shutdown")
|
||||
if self.gui_object:
|
||||
self.gui_object.stop()
|
||||
# stop network/wallets
|
||||
for k, wallet in self._wallets.items():
|
||||
wallet.stop()
|
||||
if self.network:
|
||||
self.logger.info("shutting down network")
|
||||
self.network.stop()
|
||||
self.logger.info("stopping taskgroup")
|
||||
fut = asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.asyncio_loop)
|
||||
try:
|
||||
fut.result(timeout=2)
|
||||
except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
@log_exceptions
|
||||
async def stop_async():
|
||||
self.logger.info("stopping all wallets")
|
||||
async with TaskGroup() as group:
|
||||
for k, wallet in self._wallets.items():
|
||||
await group.spawn(wallet.stop())
|
||||
self.logger.info("stopping network and taskgroup")
|
||||
try:
|
||||
async with timeout_after(2):
|
||||
async with TaskGroup() as group:
|
||||
if self.network:
|
||||
await group.spawn(self.network.stop(full_shutdown=True))
|
||||
await group.spawn(self.taskgroup.cancel_remaining())
|
||||
except TaskTimeout:
|
||||
pass
|
||||
|
||||
fut = asyncio.run_coroutine_threadsafe(stop_async(), self.asyncio_loop)
|
||||
fut.result()
|
||||
self.logger.info("removing lockfile")
|
||||
remove_lockfile(get_lockfile(self.config))
|
||||
self.logger.info("stopped")
|
||||
|
||||
@@ -3,3 +3,9 @@
|
||||
# The Wallet object is instantiated by the GUI
|
||||
|
||||
# Notifications about network events are sent to the GUI by using network.register_callback()
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import qt
|
||||
from . import kivy
|
||||
|
||||
@@ -190,7 +190,8 @@ class ElectrumWindow(App, Logger):
|
||||
if self.use_gossip:
|
||||
self.network.start_gossip()
|
||||
else:
|
||||
self.network.stop_gossip()
|
||||
self.network.run_from_another_thread(
|
||||
self.network.stop_gossip())
|
||||
|
||||
android_backups = BooleanProperty(False)
|
||||
def on_android_backups(self, instance, x):
|
||||
|
||||
@@ -141,7 +141,8 @@ channels graph and compute payment path locally, instead of using trampoline pay
|
||||
if use_gossip:
|
||||
self.window.network.start_gossip()
|
||||
else:
|
||||
self.window.network.stop_gossip()
|
||||
self.window.network.run_from_another_thread(
|
||||
self.window.network.stop_gossip())
|
||||
util.trigger_callback('ln_gossip_sync_progress')
|
||||
# FIXME: update all wallet windows
|
||||
util.trigger_callback('channels_updated', self.wallet)
|
||||
|
||||
@@ -695,7 +695,7 @@ class Interface(Logger):
|
||||
# We give up after a while and just abort the connection.
|
||||
# Note: specifically if the server is running Fulcrum, waiting seems hopeless,
|
||||
# the connection must be aborted (see https://github.com/cculianu/Fulcrum/issues/76)
|
||||
force_after = 2 # seconds
|
||||
force_after = 1 # seconds
|
||||
if self.session:
|
||||
await self.session.close(force_after=force_after)
|
||||
# monitor_connection will cancel tasks
|
||||
|
||||
@@ -147,8 +147,8 @@ class LNWatcher(AddressSynchronizer):
|
||||
# status gets populated when we run
|
||||
self.channel_status = {}
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
util.unregister_callback(self.on_network_update)
|
||||
|
||||
def get_channel_status(self, outpoint):
|
||||
|
||||
@@ -311,11 +311,11 @@ class LNWorker(Logger, NetworkRetryManager[LNPeerAddr]):
|
||||
self._add_peers_from_config()
|
||||
asyncio.run_coroutine_threadsafe(self.main_loop(), self.network.asyncio_loop)
|
||||
|
||||
def stop(self):
|
||||
async def stop(self):
|
||||
if self.listen_server:
|
||||
self.network.asyncio_loop.call_soon_threadsafe(self.listen_server.close)
|
||||
asyncio.run_coroutine_threadsafe(self.taskgroup.cancel_remaining(), self.network.asyncio_loop)
|
||||
self.listen_server.close()
|
||||
util.unregister_callback(self.on_proxy_changed)
|
||||
await self.taskgroup.cancel_remaining()
|
||||
|
||||
def _add_peers_from_config(self):
|
||||
peer_list = self.config.get('lightning_peers', [])
|
||||
@@ -704,9 +704,9 @@ class LNWallet(LNWorker):
|
||||
tg_coro = self.taskgroup.spawn(coro)
|
||||
asyncio.run_coroutine_threadsafe(tg_coro, self.network.asyncio_loop)
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
self.lnwatcher.stop()
|
||||
async def stop(self):
|
||||
await super().stop()
|
||||
await self.lnwatcher.stop()
|
||||
self.lnwatcher = None
|
||||
|
||||
def peer_closed(self, peer):
|
||||
|
||||
@@ -252,6 +252,11 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
default_server: ServerAddr
|
||||
_recent_servers: List[ServerAddr]
|
||||
|
||||
channel_blacklist: 'ChannelBlackList'
|
||||
channel_db: Optional['ChannelDB'] = None
|
||||
lngossip: Optional['LNGossip'] = None
|
||||
local_watchtower: Optional['WatchTower'] = None
|
||||
|
||||
def __init__(self, config: SimpleConfig, *, daemon: 'Daemon' = None):
|
||||
global _INSTANCE
|
||||
assert _INSTANCE is None, "Network is a singleton!"
|
||||
@@ -344,9 +349,6 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
|
||||
# lightning network
|
||||
self.channel_blacklist = ChannelBlackList()
|
||||
self.channel_db = None # type: Optional[ChannelDB]
|
||||
self.lngossip = None # type: Optional[LNGossip]
|
||||
self.local_watchtower = None # type: Optional[WatchTower]
|
||||
if self.config.get('run_local_watchtower', False):
|
||||
from . import lnwatcher
|
||||
self.local_watchtower = lnwatcher.WatchTower(self)
|
||||
@@ -373,11 +375,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
self.lngossip = lnworker.LNGossip()
|
||||
self.lngossip.start_network(self)
|
||||
|
||||
def stop_gossip(self):
|
||||
async def stop_gossip(self, *, full_shutdown: bool = False):
|
||||
if self.lngossip:
|
||||
self.lngossip.stop()
|
||||
await self.lngossip.stop()
|
||||
self.lngossip = None
|
||||
self.channel_db.stop()
|
||||
if full_shutdown:
|
||||
await self.channel_db.stopped_event.wait()
|
||||
self.channel_db = None
|
||||
|
||||
def run_from_another_thread(self, coro, *, timeout=None):
|
||||
@@ -623,7 +627,7 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
self.auto_connect = net_params.auto_connect
|
||||
if self.proxy != proxy or self.oneserver != net_params.oneserver:
|
||||
# Restart the network defaulting to the given server
|
||||
await self._stop()
|
||||
await self.stop(full_shutdown=False)
|
||||
self.default_server = server
|
||||
await self._start()
|
||||
elif self.default_server != server:
|
||||
@@ -1217,13 +1221,13 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
asyncio.run_coroutine_threadsafe(self._start(), self.asyncio_loop)
|
||||
|
||||
@log_exceptions
|
||||
async def _stop(self, full_shutdown=False):
|
||||
async def stop(self, *, full_shutdown: bool = True):
|
||||
self.logger.info("stopping network")
|
||||
try:
|
||||
# note: cancel_remaining ~cannot be cancelled, it suppresses CancelledError
|
||||
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=2)
|
||||
await asyncio.wait_for(self.taskgroup.cancel_remaining(), timeout=1)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError) as e:
|
||||
self.logger.info(f"exc during main_taskgroup cancellation: {repr(e)}")
|
||||
self.logger.info(f"exc during taskgroup cancellation: {repr(e)}")
|
||||
self.taskgroup = None
|
||||
self.interface = None
|
||||
self.interfaces = {}
|
||||
@@ -1231,13 +1235,8 @@ class Network(Logger, NetworkRetryManager[ServerAddr]):
|
||||
self._closing_ifaces.clear()
|
||||
if not full_shutdown:
|
||||
util.trigger_callback('network_updated')
|
||||
|
||||
def stop(self):
|
||||
assert self._loop_thread != threading.current_thread(), 'must not be called from network thread'
|
||||
fut = asyncio.run_coroutine_threadsafe(self._stop(full_shutdown=True), self.asyncio_loop)
|
||||
try:
|
||||
fut.result(timeout=2)
|
||||
except (concurrent.futures.TimeoutError, concurrent.futures.CancelledError): pass
|
||||
if full_shutdown:
|
||||
await self.stop_gossip(full_shutdown=full_shutdown)
|
||||
|
||||
async def _ensure_there_is_a_main_interface(self):
|
||||
if self.is_connected():
|
||||
|
||||
@@ -25,6 +25,7 @@ class SqlDB(Logger):
|
||||
Logger.__init__(self)
|
||||
self.asyncio_loop = asyncio_loop
|
||||
self.stopping = False
|
||||
self.stopped_event = asyncio.Event()
|
||||
self.path = path
|
||||
test_read_write_permissions(path)
|
||||
self.commit_interval = commit_interval
|
||||
@@ -65,6 +66,8 @@ class SqlDB(Logger):
|
||||
# write
|
||||
self.conn.commit()
|
||||
self.conn.close()
|
||||
|
||||
self.asyncio_loop.call_soon_threadsafe(self.stopped_event.set)
|
||||
self.logger.info("SQL thread terminated")
|
||||
|
||||
def create_database(self):
|
||||
|
||||
@@ -3,10 +3,12 @@ import tempfile
|
||||
import os
|
||||
import json
|
||||
from typing import Optional
|
||||
import asyncio
|
||||
|
||||
from electrum.wallet_db import WalletDB
|
||||
from electrum.wallet import Wallet
|
||||
from electrum import constants
|
||||
from electrum import util
|
||||
|
||||
from .test_wallet import WalletTestCase
|
||||
|
||||
@@ -15,6 +17,15 @@ from .test_wallet import WalletTestCase
|
||||
# TODO hw wallet with client version 2.6.x (single-, and multiacc)
|
||||
class TestStorageUpgrade(WalletTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
|
||||
self._loop_thread.join(timeout=1)
|
||||
|
||||
def testnet_wallet(func):
|
||||
# note: it's ok to modify global network constants in subclasses of SequentialTestCase
|
||||
def wrapper(self, *args, **kwargs):
|
||||
@@ -281,7 +292,7 @@ class TestStorageUpgrade(WalletTestCase):
|
||||
# to simulate ks.opportunistically_fill_in_missing_info_from_device():
|
||||
ks._root_fingerprint = "deadbeef"
|
||||
ks.is_requesting_to_be_rewritten_to_wallet_file = True
|
||||
wallet.stop()
|
||||
asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
|
||||
|
||||
def test_upgrade_from_client_2_9_3_importedkeys_keystore_changes(self):
|
||||
# see #6401
|
||||
@@ -292,7 +303,7 @@ class TestStorageUpgrade(WalletTestCase):
|
||||
["p2wpkh:L1cgMEnShp73r9iCukoPE3MogLeueNYRD9JVsfT1zVHyPBR3KqBY"],
|
||||
password=None
|
||||
)
|
||||
wallet.stop()
|
||||
asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
|
||||
|
||||
@testnet_wallet
|
||||
def test_upgrade_from_client_3_3_8_xpub_with_realistic_history(self):
|
||||
|
||||
@@ -5,8 +5,9 @@ import os
|
||||
import json
|
||||
from decimal import Decimal
|
||||
import time
|
||||
|
||||
from io import StringIO
|
||||
import asyncio
|
||||
|
||||
from electrum.storage import WalletStorage
|
||||
from electrum.wallet_db import FINAL_SEED_VERSION
|
||||
from electrum.wallet import (Abstract_Wallet, Standard_Wallet, create_new_wallet,
|
||||
@@ -16,6 +17,7 @@ from electrum.util import TxMinedInfo, InvalidPassword
|
||||
from electrum.bitcoin import COIN
|
||||
from electrum.wallet_db import WalletDB
|
||||
from electrum.simple_config import SimpleConfig
|
||||
from electrum import util
|
||||
|
||||
from . import ElectrumTestCase
|
||||
|
||||
@@ -237,6 +239,15 @@ class TestCreateRestoreWallet(WalletTestCase):
|
||||
|
||||
class TestWalletPassword(WalletTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.asyncio_loop, self._stop_loop, self._loop_thread = util.create_and_start_event_loop()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.asyncio_loop.call_soon_threadsafe(self._stop_loop.set_result, 1)
|
||||
self._loop_thread.join(timeout=1)
|
||||
|
||||
def test_update_password_of_imported_wallet(self):
|
||||
wallet_str = '{"addr_history":{"1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr":[],"15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA":[],"1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6":[]},"addresses":{"change":[],"receiving":["1364Js2VG66BwRdkaoxAaFtdPb1eQgn8Dr","1Exet2BhHsFxKTwhnfdsBMkPYLGvobxuW6","15CyDgLffJsJgQrhcyooFH4gnVDG82pUrA"]},"keystore":{"keypairs":{"0344b1588589958b0bcab03435061539e9bcf54677c104904044e4f8901f4ebdf5":"L2sED74axVXC4H8szBJ4rQJrkfem7UMc6usLCPUoEWxDCFGUaGUM","0389508c13999d08ffae0f434a085f4185922d64765c0bff2f66e36ad7f745cc5f":"L3Gi6EQLvYw8gEEUckmqawkevfj9s8hxoQDFveQJGZHTfyWnbk1U","04575f52b82f159fa649d2a4c353eb7435f30206f0a6cb9674fbd659f45082c37d559ffd19bea9c0d3b7dcc07a7b79f4cffb76026d5d4dff35341efe99056e22d2":"5JyVyXU1LiRXATvRTQvR9Kp8Rx1X84j2x49iGkjSsXipydtByUq"},"type":"imported"},"pruned_txo":{},"seed_version":13,"stored_height":-1,"transactions":{},"tx_fees":{},"txi":{},"txo":{},"use_encryption":false,"verified_tx3":{},"wallet_type":"standard","winpos-qt":[100,100,840,405]}'
|
||||
db = WalletDB(wallet_str, manual_upgrades=False)
|
||||
@@ -273,7 +284,7 @@ class TestWalletPassword(WalletTestCase):
|
||||
db = WalletDB(wallet_str, manual_upgrades=False)
|
||||
storage = WalletStorage(self.wallet_path)
|
||||
wallet = Wallet(db, storage, config=self.config)
|
||||
wallet.stop()
|
||||
asyncio.run_coroutine_threadsafe(wallet.stop(), self.asyncio_loop).result()
|
||||
|
||||
storage = WalletStorage(self.wallet_path)
|
||||
# if storage.is_encrypted():
|
||||
|
||||
@@ -1205,11 +1205,9 @@ class NetworkJobOnDefaultServer(Logger, ABC):
|
||||
if taskgroup != self.taskgroup:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
async def stop(self):
|
||||
unregister_callback(self._restart)
|
||||
await self._stop()
|
||||
|
||||
async def _stop(self):
|
||||
async def stop(self, *, full_shutdown: bool = True):
|
||||
if full_shutdown:
|
||||
unregister_callback(self._restart)
|
||||
await self.taskgroup.cancel_remaining()
|
||||
|
||||
@log_exceptions
|
||||
@@ -1219,7 +1217,7 @@ class NetworkJobOnDefaultServer(Logger, ABC):
|
||||
return # we should get called again soon
|
||||
|
||||
async with self._restart_lock:
|
||||
await self._stop()
|
||||
await self.stop(full_shutdown=False)
|
||||
self._reset()
|
||||
await self._start(interface)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ import itertools
|
||||
import threading
|
||||
import enum
|
||||
|
||||
from aiorpcx import TaskGroup
|
||||
from aiorpcx import TaskGroup, timeout_after, TaskTimeout
|
||||
|
||||
from .i18n import _
|
||||
from .bip32 import BIP32Node, convert_bip32_intpath_to_strpath, convert_bip32_path_to_list_of_uint32
|
||||
@@ -353,15 +353,21 @@ class Abstract_Wallet(AddressSynchronizer, ABC):
|
||||
ln_xprv = node.to_xprv()
|
||||
self.db.put('lightning_privkey2', ln_xprv)
|
||||
|
||||
def stop(self):
|
||||
super().stop()
|
||||
if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
|
||||
self.save_keystore()
|
||||
if self.network:
|
||||
if self.lnworker:
|
||||
self.lnworker.stop()
|
||||
self.lnworker = None
|
||||
self.save_db()
|
||||
async def stop(self):
|
||||
"""Stop all networking and save DB to disk."""
|
||||
try:
|
||||
async with timeout_after(5):
|
||||
await super().stop()
|
||||
if self.network:
|
||||
if self.lnworker:
|
||||
await self.lnworker.stop()
|
||||
self.lnworker = None
|
||||
except TaskTimeout:
|
||||
pass
|
||||
finally: # even if we get cancelled
|
||||
if any([ks.is_requesting_to_be_rewritten_to_wallet_file for ks in self.get_keystores()]):
|
||||
self.save_keystore()
|
||||
self.save_db()
|
||||
|
||||
def set_up_to_date(self, b):
|
||||
super().set_up_to_date(b)
|
||||
|
||||
@@ -345,7 +345,6 @@ def main():
|
||||
print_stderr('unknown command:', uri)
|
||||
sys.exit(1)
|
||||
|
||||
# singleton
|
||||
config = SimpleConfig(config_options)
|
||||
|
||||
if config.get('testnet'):
|
||||
|
||||
Reference in New Issue
Block a user