diff --git a/electrum/address_synchronizer.py b/electrum/address_synchronizer.py index 1d559a762..e1212a527 100644 --- a/electrum/address_synchronizer.py +++ b/electrum/address_synchronizer.py @@ -54,6 +54,16 @@ TX_TIMESTAMP_INF = 999_999_999_999 TX_HEIGHT_INF = 10 ** 9 +from enum import IntEnum, auto + +class TxMinedDepth(IntEnum): + """ IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """ + DEEP = auto() + SHALLOW = auto() + MEMPOOL = auto() + FREE = auto() + + class HistoryItem(NamedTuple): txid: str tx_mined_status: TxMinedInfo @@ -990,3 +1000,46 @@ class AddressSynchronizer(Logger, EventListener): tx_age = self.get_local_height() - tx_height + 1 max_conf = max(max_conf, tx_age) return max_conf >= req_conf + + def get_spender(self, outpoint: str) -> str: + """ + returns txid spending outpoint. + subscribes to addresses as a side effect. + """ + prev_txid, index = outpoint.split(':') + spender_txid = self.db.get_spent_outpoint(prev_txid, int(index)) + # discard local spenders + tx_mined_status = self.get_tx_height(spender_txid) + if tx_mined_status.height in [TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE]: + spender_txid = None + if not spender_txid: + return + spender_tx = self.get_transaction(spender_txid) + for i, o in enumerate(spender_tx.outputs()): + if o.address is None: + continue + if not self.is_mine(o.address): + self.add_address(o.address) + return spender_txid + + def get_tx_mined_depth(self, txid: str): + if not txid: + return TxMinedDepth.FREE + tx_mined_depth = self.get_tx_height(txid) + height, conf = tx_mined_depth.height, tx_mined_depth.conf + if conf > 20: + return TxMinedDepth.DEEP + elif conf > 0: + return TxMinedDepth.SHALLOW + elif height in (TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT): + return TxMinedDepth.MEMPOOL + elif height in (TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE): + return TxMinedDepth.FREE + elif height > 0 and conf == 0: + # unverified but claimed to be mined + return TxMinedDepth.MEMPOOL + else: + raise NotImplementedError() + + def is_deeply_mined(self, txid): + return self.get_tx_mined_depth(txid) == TxMinedDepth.DEEP diff --git a/electrum/lnwatcher.py b/electrum/lnwatcher.py index 61bd3c08f..d9e9b1734 100644 --- a/electrum/lnwatcher.py +++ b/electrum/lnwatcher.py @@ -3,11 +3,9 @@ # file LICENCE or http://www.opensource.org/licenses/mit-license.php from typing import TYPE_CHECKING -from enum import IntEnum, auto -from .util import log_exceptions, TxMinedInfo, BelowDustLimit +from .util import TxMinedInfo, BelowDustLimit from .util import EventListener, event_listener -from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_FUTURE from .transaction import Transaction, TxOutpoint from .logging import Logger @@ -17,27 +15,18 @@ if TYPE_CHECKING: from .lnsweep import SweepInfo from .lnworker import LNWallet from .lnchannel import AbstractChannel - from .simple_config import SimpleConfig - - -class TxMinedDepth(IntEnum): - """ IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """ - DEEP = auto() - SHALLOW = auto() - MEMPOOL = auto() - FREE = auto() class LNWatcher(Logger, EventListener): LOGGING_SHORTCUT = 'W' - def __init__(self, adb: 'AddressSynchronizer', config: 'SimpleConfig'): - + def __init__(self, lnworker: 'LNWallet'): + self.lnworker = lnworker Logger.__init__(self) - self.adb = adb - self.config = config - self.callbacks = {} # address -> lambda: coroutine + self.adb = lnworker.wallet.adb + self.config = lnworker.config + self.callbacks = {} # address -> lambda function self.network = None self.register_callbacks() # status gets populated when we run @@ -46,16 +35,12 @@ class LNWatcher(Logger, EventListener): def start_network(self, network: 'Network'): self.network = network - async def stop(self): + def stop(self): self.unregister_callbacks() def get_channel_status(self, outpoint): return self.channel_status.get(outpoint, 'unknown') - def unwatch_channel(self, address, funding_outpoint): - self.logger.info(f'unwatching {funding_outpoint}') - self.remove_callback(address) - def remove_callback(self, address): self.callbacks.pop(address, None) @@ -63,87 +48,40 @@ class LNWatcher(Logger, EventListener): self.adb.add_address(address) self.callbacks[address] = callback - @event_listener - async def on_event_blockchain_updated(self, *args): - await self.trigger_callbacks() - - @event_listener - async def on_event_wallet_updated(self, wallet): - # called if we add local tx - if wallet.adb != self.adb: - return - await self.trigger_callbacks() - - @event_listener - async def on_event_adb_added_verified_tx(self, adb, tx_hash): - if adb != self.adb: - return - await self.trigger_callbacks() - - @event_listener - async def on_event_adb_set_up_to_date(self, adb): - if adb != self.adb: - return - await self.trigger_callbacks() - - @log_exceptions - async def trigger_callbacks(self): + def trigger_callbacks(self): if not self.adb.synchronizer: self.logger.info("synchronizer not set yet") return for address, callback in list(self.callbacks.items()): callback() - def get_spender(self, outpoint) -> str: - """ - returns txid spending outpoint. - subscribes to addresses as a side effect. - """ - prev_txid, index = outpoint.split(':') - spender_txid = self.adb.db.get_spent_outpoint(prev_txid, int(index)) - # discard local spenders - tx_mined_status = self.adb.get_tx_height(spender_txid) - if tx_mined_status.height in [TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE]: - spender_txid = None - if not spender_txid: + @event_listener + async def on_event_blockchain_updated(self, *args): + # we invalidate the cache on each new block because + # some processes affect the list of sweep transactions + # (hold invoice preimage revealed, MPP completed, etc) + for chan in self.lnworker.channels.values(): + chan._sweep_info.clear() + self.trigger_callbacks() + + @event_listener + def on_event_wallet_updated(self, wallet): + # called if we add local tx + if wallet.adb != self.adb: return - spender_tx = self.adb.get_transaction(spender_txid) - for i, o in enumerate(spender_tx.outputs()): - if o.address is None: - continue - if not self.adb.is_mine(o.address): - self.adb.add_address(o.address) - return spender_txid + self.trigger_callbacks() - def get_tx_mined_depth(self, txid: str): - if not txid: - return TxMinedDepth.FREE - tx_mined_depth = self.adb.get_tx_height(txid) - height, conf = tx_mined_depth.height, tx_mined_depth.conf - if conf > 20: - return TxMinedDepth.DEEP - elif conf > 0: - return TxMinedDepth.SHALLOW - elif height in (TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_UNCONF_PARENT): - return TxMinedDepth.MEMPOOL - elif height in (TX_HEIGHT_LOCAL, TX_HEIGHT_FUTURE): - return TxMinedDepth.FREE - elif height > 0 and conf == 0: - # unverified but claimed to be mined - return TxMinedDepth.MEMPOOL - else: - raise NotImplementedError() + @event_listener + def on_event_adb_added_verified_tx(self, adb, tx_hash): + if adb != self.adb: + return + self.trigger_callbacks() - def is_deeply_mined(self, txid): - return self.get_tx_mined_depth(txid) == TxMinedDepth.DEEP - - - -class LNWalletWatcher(LNWatcher): - - def __init__(self, lnworker: 'LNWallet'): - self.lnworker = lnworker - LNWatcher.__init__(self, lnworker.wallet.adb, lnworker.config) + @event_listener + def on_event_adb_set_up_to_date(self, adb): + if adb != self.adb: + return + self.trigger_callbacks() def add_channel(self, chan: 'AbstractChannel') -> None: outpoint = chan.funding_outpoint.to_str() @@ -153,6 +91,10 @@ class LNWalletWatcher(LNWatcher): if chan.need_to_subscribe(): self.add_callback(address, callback) + def unwatch_channel(self, address, funding_outpoint): + self.logger.info(f'unwatching {funding_outpoint}') + self.remove_callback(address) + def check_onchain_situation(self, address, funding_outpoint): # early return if address has not been added yet if not self.adb.is_mine(address): @@ -160,7 +102,7 @@ class LNWalletWatcher(LNWatcher): # inspect_tx_candidate might have added new addresses, in which case we return early funding_txid = funding_outpoint.split(':')[0] funding_height = self.adb.get_tx_height(funding_txid) - closing_txid = self.get_spender(funding_outpoint) + closing_txid = self.adb.get_spender(funding_outpoint) closing_height = self.adb.get_tx_height(closing_txid) if closing_txid: closing_tx = self.adb.get_transaction(closing_txid) @@ -181,16 +123,6 @@ class LNWalletWatcher(LNWatcher): if not keep_watching: self.unwatch_channel(address, funding_outpoint) - @event_listener - async def on_event_blockchain_updated(self, *args): - # overload parent method with cache invalidation - # we invalidate the cache on each new block because - # some processes affect the list of sweep transactions - # (hold invoice preimage revealed, MPP completed, etc) - for chan in self.lnworker.channels.values(): - chan._sweep_info.clear() - await self.trigger_callbacks() - def diagnostic_name(self): return f"{self.lnworker.wallet.diagnostic_name()}-LNW" @@ -223,8 +155,7 @@ class LNWalletWatcher(LNWatcher): return False # detect who closed and get information about how to claim outputs sweep_info_dict = chan.sweep_ctx(closing_tx) - #self.logger.info(f"do_breach_remedy: {[x.name for x in sweep_info_dict.values()]}") - keep_watching = False if sweep_info_dict else not self.is_deeply_mined(closing_tx.txid()) + keep_watching = False if sweep_info_dict else not self.adb.is_deeply_mined(closing_tx.txid()) # create and broadcast transactions for prevout, sweep_info in sweep_info_dict.items(): prev_txid, prev_index = prevout.split(':') @@ -234,19 +165,19 @@ class LNWalletWatcher(LNWatcher): # do not keep watching if prevout does not exist self.logger.info(f'prevout does not exist for {name}: {prevout}') continue - spender_txid = self.get_spender(prevout) + spender_txid = self.adb.get_spender(prevout) spender_tx = self.adb.get_transaction(spender_txid) if spender_txid else None if spender_tx: # the spender might be the remote, revoked or not htlc_sweepinfo = chan.maybe_sweep_htlcs(closing_tx, spender_tx) for prevout2, htlc_sweep_info in htlc_sweepinfo.items(): - htlc_tx_spender = self.get_spender(prevout2) + htlc_tx_spender = self.adb.get_spender(prevout2) self.lnworker.wallet.set_default_label(prevout2, htlc_sweep_info.name) if htlc_tx_spender: - keep_watching |= not self.is_deeply_mined(htlc_tx_spender) + keep_watching |= not self.adb.is_deeply_mined(htlc_tx_spender) else: keep_watching |= self.maybe_redeem(htlc_sweep_info) - keep_watching |= not self.is_deeply_mined(spender_txid) + keep_watching |= not self.adb.is_deeply_mined(spender_txid) self.maybe_extract_preimage(chan, spender_tx, prevout) else: keep_watching |= self.maybe_redeem(sweep_info) @@ -266,5 +197,5 @@ class LNWalletWatcher(LNWatcher): spender_txin = spender_tx.inputs()[txin_idx] chan.extract_preimage_from_htlc_txin( spender_txin, - is_deeply_mined=self.is_deeply_mined(spender_tx.txid()), + is_deeply_mined=self.adb.is_deeply_mined(spender_tx.txid()), ) diff --git a/electrum/lnworker.py b/electrum/lnworker.py index 8d72b1d9c..24152a9b4 100644 --- a/electrum/lnworker.py +++ b/electrum/lnworker.py @@ -69,7 +69,7 @@ from .lnmsg import decode_msg from .lnrouter import ( RouteEdge, LNPaymentRoute, LNPaymentPath, is_route_within_budget, NoChannelPolicy, LNPathInconsistent ) -from .lnwatcher import LNWalletWatcher +from .lnwatcher import LNWatcher from .submarine_swaps import SwapManager from .mpp_split import suggest_splits, SplitConfigRating from .trampoline import ( @@ -817,7 +817,7 @@ class PaySession(Logger): class LNWallet(LNWorker): - lnwatcher: Optional['LNWalletWatcher'] + lnwatcher: Optional['LNWatcher'] MPP_EXPIRY = 120 TIMEOUT_SHUTDOWN_FAIL_PENDING_HTLCS = 3 # seconds PAYMENT_TIMEOUT = 120 @@ -842,7 +842,7 @@ class LNWallet(LNWorker): if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP: features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch LNWorker.__init__(self, self.node_keypair, features, config=self.config) - self.lnwatcher = LNWalletWatcher(self) + self.lnwatcher = LNWatcher(self) self.lnrater: LNRater = None self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage @@ -999,7 +999,7 @@ class LNWallet(LNWorker): await self.wait_for_received_pending_htlcs_to_get_removed() await LNWorker.stop(self) if self.lnwatcher: - await self.lnwatcher.stop() + self.lnwatcher.stop() self.lnwatcher = None if self.swap_manager and self.swap_manager.network: # may not be present in tests await self.swap_manager.stop() diff --git a/electrum/plugins/watchtower/watchtower.py b/electrum/plugins/watchtower/watchtower.py index ce8fdabc5..b992a66a4 100644 --- a/electrum/plugins/watchtower/watchtower.py +++ b/electrum/plugins/watchtower/watchtower.py @@ -24,19 +24,21 @@ # SOFTWARE. -import asyncio, os +import asyncio +import os from typing import TYPE_CHECKING -from typing import NamedTuple, Dict +from typing import Dict from electrum.util import log_exceptions, random_shuffled_copy -from electrum.plugin import BasePlugin, hook +from electrum.plugin import BasePlugin from electrum.sql_db import SqlDB, sql -from electrum.lnwatcher import LNWatcher from electrum.transaction import Transaction, match_script_against_template from electrum.network import Network from electrum.address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL from electrum.wallet_db import WalletDB from electrum.lnutil import WITNESS_TEMPLATE_RECEIVED_HTLC, WITNESS_TEMPLATE_OFFERED_HTLC +from electrum.logging import Logger +from electrum.util import EventListener, event_listener from .server import WatchTowerServer @@ -60,24 +62,51 @@ class WatchtowerPlugin(BasePlugin): asyncio.run_coroutine_threadsafe(self.network.taskgroup.spawn(self.server.run), self.network.asyncio_loop) -class WatchTower(LNWatcher): +class WatchTower(Logger, EventListener): LOGGING_SHORTCUT = 'W' def __init__(self, network: 'Network'): - adb = AddressSynchronizer(WalletDB('', storage=None, upgrade=True), network.config, name=self.diagnostic_name()) - adb.start_network(network) - LNWatcher.__init__(self, adb, network) + Logger.__init__(self) + self.adb = AddressSynchronizer(WalletDB('', storage=None, upgrade=True), network.config, name=self.diagnostic_name()) + self.adb.start_network(network) + self.config = network.config + self.callbacks = {} # address -> lambda function + self.register_callbacks() + # status gets populated when we run + self.channel_status = {} self.network = network self.sweepstore = SweepStore(os.path.join(self.network.config.path, "watchtower_db"), network) - async def stop(self): - await super().stop() - await self.adb.stop() + def remove_callback(self, address): + self.callbacks.pop(address, None) - def add_channel(self, outpoint: str, address: str) -> None: - callback = lambda: self.check_onchain_situation(address, outpoint) - self.add_callback(address, callback) + def add_callback(self, address, callback): + self.adb.add_address(address) + self.callbacks[address] = callback + + @event_listener + async def on_event_blockchain_updated(self, *args): + await self.trigger_callbacks() + + @event_listener + async def on_event_wallet_updated(self, wallet): + # called if we add local tx + if wallet.adb != self.adb: + return + await self.trigger_callbacks() + + @event_listener + async def on_event_adb_added_verified_tx(self, adb, tx_hash): + if adb != self.adb: + return + await self.trigger_callbacks() + + @event_listener + async def on_event_adb_set_up_to_date(self, adb): + if adb != self.adb: + return + await self.trigger_callbacks() @log_exceptions async def trigger_callbacks(self): @@ -87,6 +116,14 @@ class WatchTower(LNWatcher): for address, callback in list(self.callbacks.items()): await callback() + async def stop(self): + self.unregister_callbacks() + await self.adb.stop() + + def add_channel(self, outpoint: str, address: str) -> None: + callback = lambda: self.check_onchain_situation(address, outpoint) + self.add_callback(address, callback) + def diagnostic_name(self): return "watchtower" @@ -102,10 +139,7 @@ class WatchTower(LNWatcher): if not self.adb.is_mine(address): return # inspect_tx_candidate might have added new addresses, in which case we return early - funding_txid = funding_outpoint.split(':')[0] - funding_height = self.adb.get_tx_height(funding_txid) - closing_txid = self.get_spender(funding_outpoint) - closing_height = self.adb.get_tx_height(closing_txid) + closing_txid = self.adb.get_spender(funding_outpoint) if closing_txid: closing_tx = self.adb.get_transaction(closing_txid) if closing_tx: @@ -132,7 +166,7 @@ class WatchTower(LNWatcher): if n == 0: if spender_txid is None: self.channel_status[outpoint] = 'open' - elif not self.is_deeply_mined(spender_txid): + elif not self.adb.is_deeply_mined(spender_txid): self.channel_status[outpoint] = 'closed (%d)' % self.adb.get_tx_height(spender_txid).conf else: self.channel_status[outpoint] = 'closed (deep)' @@ -166,7 +200,7 @@ class WatchTower(LNWatcher): if not self.adb.is_mine(o.address): self.adb.add_address(o.address) elif n < 2: - r = self.inspect_tx_candidate(spender_txid+':%d'%i, n+1) + r = self.inspect_tx_candidate(spender_txid + ':%d' % i, n + 1) result.update(r) return result @@ -175,7 +209,7 @@ class WatchTower(LNWatcher): keep_watching = False for prevout, spender in spenders.items(): if spender is not None: - keep_watching |= not self.is_deeply_mined(spender) + keep_watching |= not self.adb.is_deeply_mined(spender) continue sweep_txns = await self.sweepstore.get_sweep_tx(funding_outpoint, prevout) for tx in sweep_txns: @@ -217,13 +251,10 @@ class WatchTower(LNWatcher): return self.network.run_from_another_thread(f()) async def unwatch_channel(self, address, funding_outpoint): - await super().unwatch_channel(address, funding_outpoint) await self.sweepstore.remove_sweep_tx(funding_outpoint) await self.sweepstore.remove_channel(funding_outpoint) - - create_sweep_txs=""" CREATE TABLE IF NOT EXISTS sweep_txs ( funding_outpoint VARCHAR(34) NOT NULL, @@ -319,5 +350,3 @@ class SweepStore(SqlDB): c = self.conn.cursor() c.execute("SELECT outpoint, address FROM channel_info") return [(r[0], r[1]) for r in c.fetchall()] - -