diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 086000362..61bd3c08f 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING from enum import IntEnum, auto -from .util import log_exceptions, ignore_exceptions, TxMinedInfo, BelowDustLimit +from .util import log_exceptions, TxMinedInfo, BelowDustLimit from .util import EventListener, event_listener from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_FUTURE from .transaction import Transaction, TxOutpoint @@ -17,6 +17,8 @@ if TYPE_CHECKING: from .lnsweep import SweepInfo from .lnworker import LNWallet from .lnchannel import AbstractChannel + from .simple_config import SimpleConfig + class TxMinedDepth(IntEnum): """ IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """ @@ -30,30 +32,27 @@ class LNWatcher(Logger, EventListener): LOGGING_SHORTCUT = 'W' - def __init__(self, adb: 'AddressSynchronizer', network: 'Network'): + def __init__(self, adb: 'AddressSynchronizer', config: 'SimpleConfig'): Logger.__init__(self) self.adb = adb - self.config = network.config - self.callbacks = {} # address -> lambda: coroutine - self.network = network + self.config = config + self.callbacks = {} # address -> lambda: coroutine + self.network = None self.register_callbacks() # status gets populated when we run self.channel_status = {} + def start_network(self, network: 'Network'): + self.network = network + async def stop(self): self.unregister_callbacks() def get_channel_status(self, outpoint): return self.channel_status.get(outpoint, 'unknown') - def add_channel(self, outpoint: str, address: str) -> None: - assert isinstance(outpoint, str) - assert isinstance(address, str) - cb = lambda: self.check_onchain_situation(address, outpoint) - self.add_callback(address, cb) - - async def unwatch_channel(self, address, funding_outpoint): + def unwatch_channel(self, address, funding_outpoint): self.logger.info(f'unwatching {funding_outpoint}') self.remove_callback(address) @@ -93,46 +92,7 @@ class LNWatcher(Logger, EventListener): self.logger.info("synchronizer not set yet") return for address, callback in list(self.callbacks.items()): - await callback() - - async def check_onchain_situation(self, address, funding_outpoint): - # early return if address has not been added yet - if not self.adb.is_mine(address): - return - # inspect_tx_candidate might have added new addresses, in which case we return early - if not self.adb.is_up_to_date(): - return - funding_txid = funding_outpoint.split(':')[0] - funding_height = self.adb.get_tx_height(funding_txid) - closing_txid = self.get_spender(funding_outpoint) - closing_height = self.adb.get_tx_height(closing_txid) - if closing_txid: - closing_tx = self.adb.get_transaction(closing_txid) - if closing_tx: - keep_watching = await self.sweep_commitment_transaction(funding_outpoint, closing_tx) - else: - self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...") - keep_watching = True - else: - keep_watching = True - await self.update_channel_state( - funding_outpoint=funding_outpoint, - funding_txid=funding_txid, - funding_height=funding_height, - closing_txid=closing_txid, - closing_height=closing_height, - keep_watching=keep_watching) - if not keep_watching: - await self.unwatch_channel(address, funding_outpoint) - - async def sweep_commitment_transaction(self, funding_outpoint: str, closing_tx: Transaction) -> bool: - raise NotImplementedError() # implemented by subclasses - - async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str, - funding_height: TxMinedInfo, closing_txid: str, - closing_height: TxMinedInfo, keep_watching: bool) -> None: - raise NotImplementedError() # implemented by subclasses - + callback() def get_spender(self, outpoint) -> str: """ @@ -181,10 +141,45 @@ class LNWatcher(Logger, EventListener): class LNWalletWatcher(LNWatcher): - def __init__(self, lnworker: 'LNWallet', network: 'Network'): - self.network = network + def __init__(self, lnworker: 'LNWallet'): self.lnworker = lnworker - LNWatcher.__init__(self, lnworker.wallet.adb, network) + LNWatcher.__init__(self, lnworker.wallet.adb, lnworker.config) + + def add_channel(self, chan: 'AbstractChannel') -> None: + outpoint = chan.funding_outpoint.to_str() + address = chan.get_funding_address() + callback = lambda: self.check_onchain_situation(address, outpoint) + callback() # run once, for side effects + if chan.need_to_subscribe(): + self.add_callback(address, callback) + + def check_onchain_situation(self, address, funding_outpoint): + # early return if address has not been added yet + if not self.adb.is_mine(address): + return + # inspect_tx_candidate might have added new addresses, in which case we return early + funding_txid = funding_outpoint.split(':')[0] + funding_height = self.adb.get_tx_height(funding_txid) + closing_txid = self.get_spender(funding_outpoint) + closing_height = self.adb.get_tx_height(closing_txid) + if closing_txid: + closing_tx = self.adb.get_transaction(closing_txid) + if closing_tx: + keep_watching = self.sweep_commitment_transaction(funding_outpoint, closing_tx) + else: + self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...") + keep_watching = True + else: + keep_watching = True + self.update_channel_state( + funding_outpoint=funding_outpoint, + funding_txid=funding_txid, + funding_height=funding_height, + closing_txid=closing_txid, + closing_height=closing_height, + keep_watching=keep_watching) + if not keep_watching: + self.unwatch_channel(address, funding_outpoint) @event_listener async def on_event_blockchain_updated(self, *args): @@ -199,11 +194,9 @@ class LNWalletWatcher(LNWatcher): def diagnostic_name(self): return f"{self.lnworker.wallet.diagnostic_name()}-LNW" - @ignore_exceptions - @log_exceptions - async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str, - funding_height: TxMinedInfo, closing_txid: str, - closing_height: TxMinedInfo, keep_watching: bool) -> None: + def update_channel_state(self, *, funding_outpoint: str, funding_txid: str, + funding_height: TxMinedInfo, closing_txid: str, + closing_height: TxMinedInfo, keep_watching: bool) -> None: chan = self.lnworker.channel_by_txo(funding_outpoint) if not chan: return @@ -213,15 +206,18 @@ class LNWalletWatcher(LNWatcher): closing_txid=closing_txid, closing_height=closing_height, keep_watching=keep_watching) - await self.lnworker.handle_onchain_state(chan) + self.lnworker.handle_onchain_state(chan) - @log_exceptions - async def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool: + def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool: """This function is called when a channel was closed. In this case we need to check for redeemable outputs of the commitment transaction or spenders down the line (HTLC-timeout/success transactions). - Returns whether we should continue to monitor.""" + Returns whether we should continue to monitor. + + Side-effécts: + - sets defaults labels + """ chan = self.lnworker.channel_by_txo(funding_outpoint) if not chan: return False diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 525d531fa..8d72b1d9c 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -842,7 +842,7 @@ class LNWallet(LNWorker): if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP: features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch LNWorker.__init__(self, self.node_keypair, features, config=self.config) - self.lnwatcher = None + self.lnwatcher = LNWalletWatcher(self) self.lnrater: LNRater = None self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage @@ -890,6 +890,13 @@ class LNWallet(LNWorker): self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) self.onion_message_manager = OnionMessageManager(self) + self.subscribe_to_channels() + + def subscribe_to_channels(self): + for chan in self.channels.values(): + self.lnwatcher.add_channel(chan) + for cb in self.channel_backups.values(): + self.lnwatcher.add_channel(cb) def has_deterministic_node_id(self) -> bool: return bool(self.db.get('lightning_xprv')) @@ -970,18 +977,11 @@ class LNWallet(LNWorker): def start_network(self, network: 'Network'): super().start_network(network) - self.lnwatcher = LNWalletWatcher(self, network) + self.lnwatcher.start_network(network) self.swap_manager.start_network(network) self.lnrater = LNRater(self, network) self.onion_message_manager.start_network(network=network) - for chan in self.channels.values(): - if chan.need_to_subscribe(): - self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) - for cb in self.channel_backups.values(): - if cb.need_to_subscribe(): - self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) - for coro in [ self.maybe_listen(), self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified @@ -1193,15 +1193,15 @@ class LNWallet(LNWorker): if chan.funding_outpoint.to_str() == txo: return chan - async def handle_onchain_state(self, chan: Channel): + def handle_onchain_state(self, chan: Channel): if type(chan) is ChannelBackup: util.trigger_callback('channel', self.wallet, chan) return if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN) - and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height())): + and chan.should_be_closed_due_to_expiring_htlcs(self.wallet.adb.get_local_height())): self.logger.info(f"force-closing due to expiring htlcs") - await self.schedule_force_closing(chan.channel_id) + asyncio.ensure_future(self.schedule_force_closing(chan.channel_id)) elif chan.get_state() == ChannelState.FUNDED: peer = self._peers.get(chan.node_id) @@ -1220,7 +1220,7 @@ class LNWallet(LNWorker): height = self.lnwatcher.adb.get_tx_height(txid).height if height == TX_HEIGHT_LOCAL: self.logger.info('REBROADCASTING CLOSING TX') - await self.network.try_broadcasting(force_close_tx, 'force-close') + asyncio.ensure_future(self.network.try_broadcasting(force_close_tx, 'force-close')) def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]: for nodeid, peer in self.peers.items(): @@ -1363,7 +1363,7 @@ class LNWallet(LNWorker): def add_channel(self, chan: Channel): with self.lock: self._channels[chan.channel_id] = chan - self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) + self.lnwatcher.add_channel(chan) def add_new_channel(self, chan: Channel): self.add_channel(chan) @@ -1953,7 +1953,7 @@ class LNWallet(LNWorker): We first try to conduct the payment over a single channel. If that fails and mpp is supported by the receiver, we will split the payment.""" trampoline_features = LnFeatures.VAR_ONION_OPT - local_height = self.network.get_local_height() + local_height = self.wallet.adb.get_local_height() fee_related_error = None # type: Optional[FeeBudgetExceeded] if channels: my_active_channels = channels @@ -3069,7 +3069,7 @@ class LNWallet(LNWorker): self.wallet.set_reserved_addresses_for_chan(cb, reserved=True) self.wallet.save_db() util.trigger_callback('channels_updated', self.wallet) - self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) + self.lnwatcher.add_channel(cb) def has_conflicting_backup_with(self, remote_node_id: bytes): """ Returns whether we have an active channel with this node on another device, using same local node id. """ @@ -3186,7 +3186,7 @@ class LNWallet(LNWorker): with self.lock: self._channel_backups[bfh(channel_id)] = cb util.trigger_callback('channels_updated', self.wallet) - self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) + self.lnwatcher.add_channel(cb) def save_forwarding_failure( self, payment_key:str, *, diff --git a/electrum/plugins/watchtower/watchtower.py b/electrum/plugins/watchtower/watchtower.py index 1b59d328c..ce8fdabc5 100644 --- a/electrum/plugins/watchtower/watchtower.py +++ b/electrum/plugins/watchtower/watchtower.py @@ -75,8 +75,20 @@ class WatchTower(LNWatcher): await super().stop() await self.adb.stop() + def add_channel(self, outpoint: str, address: str) -> None: + callback = lambda: self.check_onchain_situation(address, outpoint) + self.add_callback(address, callback) + + @log_exceptions + async def trigger_callbacks(self): + if not self.adb.synchronizer: + self.logger.info("synchronizer not set yet") + return + for address, callback in list(self.callbacks.items()): + await callback() + def diagnostic_name(self): - return "local_tower" + return "watchtower" @log_exceptions async def start_watching(self): @@ -85,6 +97,27 @@ class WatchTower(LNWatcher): for outpoint, address in random_shuffled_copy(lst): self.add_channel(outpoint, address) + async def check_onchain_situation(self, address, funding_outpoint): + # early return if address has not been added yet + if not self.adb.is_mine(address): + return + # inspect_tx_candidate might have added new addresses, in which case we return early + funding_txid = funding_outpoint.split(':')[0] + funding_height = self.adb.get_tx_height(funding_txid) + closing_txid = self.get_spender(funding_outpoint) + closing_height = self.adb.get_tx_height(closing_txid) + if closing_txid: + closing_tx = self.adb.get_transaction(closing_txid) + if closing_tx: + keep_watching = await self.sweep_commitment_transaction(funding_outpoint, closing_tx) + else: + self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...") + keep_watching = True + else: + keep_watching = True + if not keep_watching: + await self.unwatch_channel(address, funding_outpoint) + def inspect_tx_candidate(self, outpoint, n: int) -> Dict[str, str]: """ returns a dict of spenders for a transaction of interest. @@ -188,8 +221,6 @@ class WatchTower(LNWatcher): await self.sweepstore.remove_sweep_tx(funding_outpoint) await self.sweepstore.remove_channel(funding_outpoint) - async def update_channel_state(self, *args, **kwargs): - pass diff --git a/electrum/submarine_swaps.py b/electrum/submarine_swaps.py index 3704a88ca..25f8653dd 100644 --- a/electrum/submarine_swaps.py +++ b/electrum/submarine_swaps.py @@ -179,6 +179,7 @@ class SwapManager(Logger): self.wallet = wallet self.config = wallet.config self.lnworker = lnworker + self.lnwatcher = self.lnworker.lnwatcher self.config = wallet.config self.taskgroup = OldTaskGroup() self.dummy_address = DummyAddress.SWAP @@ -207,7 +208,6 @@ class SwapManager(Logger): return self.logger.info('start_network: starting main loop') self.network = network - self.lnwatcher = self.lnworker.lnwatcher for k, swap in self.swaps.items(): if swap.is_redeemed: continue @@ -321,8 +321,7 @@ class SwapManager(Logger): if sha256(preimage) == swap.payment_hash: return preimage - @log_exceptions - async def _claim_swap(self, swap: SwapData) -> None: + def _claim_swap(self, swap: SwapData) -> None: assert self.network assert self.lnwatcher if not self.lnwatcher.adb.is_up_to_date(): diff --git a/tests/test_lnpeer.py b/tests/test_lnpeer.py index 940e72247..3b1630cbf 100644 --- a/tests/test_lnpeer.py +++ b/tests/test_lnpeer.py @@ -108,8 +108,12 @@ class MockBlockchain: class MockADB: + def __init__(self): + self._blockchain = MockBlockchain() def add_transaction(self, tx): pass + def get_local_height(self): + return self._blockchain.height() class MockWallet: receive_requests = {}