1
0

Make lnwatcher not async

This fixes offline history not having the proper labels
This commit is contained in:
ThomasV
2025-02-07 09:52:03 +01:00
parent 42b072aca8
commit fbebe7de1a
5 changed files with 117 additions and 87 deletions

View File

@@ -5,7 +5,7 @@
from typing import TYPE_CHECKING
from enum import IntEnum, auto
from .util import log_exceptions, ignore_exceptions, TxMinedInfo, BelowDustLimit
from .util import log_exceptions, TxMinedInfo, BelowDustLimit
from .util import EventListener, event_listener
from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_FUTURE
from .transaction import Transaction, TxOutpoint
@@ -17,6 +17,8 @@ if TYPE_CHECKING:
from .lnsweep import SweepInfo
from .lnworker import LNWallet
from .lnchannel import AbstractChannel
from .simple_config import SimpleConfig
class TxMinedDepth(IntEnum):
""" IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """
@@ -30,30 +32,27 @@ class LNWatcher(Logger, EventListener):
LOGGING_SHORTCUT = 'W'
def __init__(self, adb: 'AddressSynchronizer', network: 'Network'):
def __init__(self, adb: 'AddressSynchronizer', config: 'SimpleConfig'):
Logger.__init__(self)
self.adb = adb
self.config = network.config
self.callbacks = {} # address -> lambda: coroutine
self.network = network
self.config = config
self.callbacks = {} # address -> lambda: coroutine
self.network = None
self.register_callbacks()
# status gets populated when we run
self.channel_status = {}
def start_network(self, network: 'Network'):
self.network = network
async def stop(self):
self.unregister_callbacks()
def get_channel_status(self, outpoint):
return self.channel_status.get(outpoint, 'unknown')
def add_channel(self, outpoint: str, address: str) -> None:
assert isinstance(outpoint, str)
assert isinstance(address, str)
cb = lambda: self.check_onchain_situation(address, outpoint)
self.add_callback(address, cb)
async def unwatch_channel(self, address, funding_outpoint):
def unwatch_channel(self, address, funding_outpoint):
self.logger.info(f'unwatching {funding_outpoint}')
self.remove_callback(address)
@@ -93,46 +92,7 @@ class LNWatcher(Logger, EventListener):
self.logger.info("synchronizer not set yet")
return
for address, callback in list(self.callbacks.items()):
await callback()
async def check_onchain_situation(self, address, funding_outpoint):
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
# inspect_tx_candidate might have added new addresses, in which case we return early
if not self.adb.is_up_to_date():
return
funding_txid = funding_outpoint.split(':')[0]
funding_height = self.adb.get_tx_height(funding_txid)
closing_txid = self.get_spender(funding_outpoint)
closing_height = self.adb.get_tx_height(closing_txid)
if closing_txid:
closing_tx = self.adb.get_transaction(closing_txid)
if closing_tx:
keep_watching = await self.sweep_commitment_transaction(funding_outpoint, closing_tx)
else:
self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...")
keep_watching = True
else:
keep_watching = True
await self.update_channel_state(
funding_outpoint=funding_outpoint,
funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
if not keep_watching:
await self.unwatch_channel(address, funding_outpoint)
async def sweep_commitment_transaction(self, funding_outpoint: str, closing_tx: Transaction) -> bool:
raise NotImplementedError() # implemented by subclasses
async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
raise NotImplementedError() # implemented by subclasses
callback()
def get_spender(self, outpoint) -> str:
"""
@@ -181,10 +141,45 @@ class LNWatcher(Logger, EventListener):
class LNWalletWatcher(LNWatcher):
def __init__(self, lnworker: 'LNWallet', network: 'Network'):
self.network = network
def __init__(self, lnworker: 'LNWallet'):
self.lnworker = lnworker
LNWatcher.__init__(self, lnworker.wallet.adb, network)
LNWatcher.__init__(self, lnworker.wallet.adb, lnworker.config)
def add_channel(self, chan: 'AbstractChannel') -> None:
outpoint = chan.funding_outpoint.to_str()
address = chan.get_funding_address()
callback = lambda: self.check_onchain_situation(address, outpoint)
callback() # run once, for side effects
if chan.need_to_subscribe():
self.add_callback(address, callback)
def check_onchain_situation(self, address, funding_outpoint):
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
# inspect_tx_candidate might have added new addresses, in which case we return early
funding_txid = funding_outpoint.split(':')[0]
funding_height = self.adb.get_tx_height(funding_txid)
closing_txid = self.get_spender(funding_outpoint)
closing_height = self.adb.get_tx_height(closing_txid)
if closing_txid:
closing_tx = self.adb.get_transaction(closing_txid)
if closing_tx:
keep_watching = self.sweep_commitment_transaction(funding_outpoint, closing_tx)
else:
self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...")
keep_watching = True
else:
keep_watching = True
self.update_channel_state(
funding_outpoint=funding_outpoint,
funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
if not keep_watching:
self.unwatch_channel(address, funding_outpoint)
@event_listener
async def on_event_blockchain_updated(self, *args):
@@ -199,11 +194,9 @@ class LNWalletWatcher(LNWatcher):
def diagnostic_name(self):
return f"{self.lnworker.wallet.diagnostic_name()}-LNW"
@ignore_exceptions
@log_exceptions
async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
return
@@ -213,15 +206,18 @@ class LNWalletWatcher(LNWatcher):
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
await self.lnworker.handle_onchain_state(chan)
self.lnworker.handle_onchain_state(chan)
@log_exceptions
async def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool:
def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool:
"""This function is called when a channel was closed. In this case
we need to check for redeemable outputs of the commitment transaction
or spenders down the line (HTLC-timeout/success transactions).
Returns whether we should continue to monitor."""
Returns whether we should continue to monitor.
Side-effécts:
- sets defaults labels
"""
chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan:
return False

View File

@@ -842,7 +842,7 @@ class LNWallet(LNWorker):
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
self.lnwatcher = None
self.lnwatcher = LNWalletWatcher(self)
self.lnrater: LNRater = None
self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
@@ -890,6 +890,13 @@ class LNWallet(LNWorker):
self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY)
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
self.onion_message_manager = OnionMessageManager(self)
self.subscribe_to_channels()
def subscribe_to_channels(self):
for chan in self.channels.values():
self.lnwatcher.add_channel(chan)
for cb in self.channel_backups.values():
self.lnwatcher.add_channel(cb)
def has_deterministic_node_id(self) -> bool:
return bool(self.db.get('lightning_xprv'))
@@ -970,18 +977,11 @@ class LNWallet(LNWorker):
def start_network(self, network: 'Network'):
super().start_network(network)
self.lnwatcher = LNWalletWatcher(self, network)
self.lnwatcher.start_network(network)
self.swap_manager.start_network(network)
self.lnrater = LNRater(self, network)
self.onion_message_manager.start_network(network=network)
for chan in self.channels.values():
if chan.need_to_subscribe():
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
for cb in self.channel_backups.values():
if cb.need_to_subscribe():
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
for coro in [
self.maybe_listen(),
self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
@@ -1193,15 +1193,15 @@ class LNWallet(LNWorker):
if chan.funding_outpoint.to_str() == txo:
return chan
async def handle_onchain_state(self, chan: Channel):
def handle_onchain_state(self, chan: Channel):
if type(chan) is ChannelBackup:
util.trigger_callback('channel', self.wallet, chan)
return
if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN)
and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height())):
and chan.should_be_closed_due_to_expiring_htlcs(self.wallet.adb.get_local_height())):
self.logger.info(f"force-closing due to expiring htlcs")
await self.schedule_force_closing(chan.channel_id)
asyncio.ensure_future(self.schedule_force_closing(chan.channel_id))
elif chan.get_state() == ChannelState.FUNDED:
peer = self._peers.get(chan.node_id)
@@ -1220,7 +1220,7 @@ class LNWallet(LNWorker):
height = self.lnwatcher.adb.get_tx_height(txid).height
if height == TX_HEIGHT_LOCAL:
self.logger.info('REBROADCASTING CLOSING TX')
await self.network.try_broadcasting(force_close_tx, 'force-close')
asyncio.ensure_future(self.network.try_broadcasting(force_close_tx, 'force-close'))
def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
for nodeid, peer in self.peers.items():
@@ -1363,7 +1363,7 @@ class LNWallet(LNWorker):
def add_channel(self, chan: Channel):
with self.lock:
self._channels[chan.channel_id] = chan
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
self.lnwatcher.add_channel(chan)
def add_new_channel(self, chan: Channel):
self.add_channel(chan)
@@ -1953,7 +1953,7 @@ class LNWallet(LNWorker):
We first try to conduct the payment over a single channel. If that fails
and mpp is supported by the receiver, we will split the payment."""
trampoline_features = LnFeatures.VAR_ONION_OPT
local_height = self.network.get_local_height()
local_height = self.wallet.adb.get_local_height()
fee_related_error = None # type: Optional[FeeBudgetExceeded]
if channels:
my_active_channels = channels
@@ -3069,7 +3069,7 @@ class LNWallet(LNWorker):
self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
self.wallet.save_db()
util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
self.lnwatcher.add_channel(cb)
def has_conflicting_backup_with(self, remote_node_id: bytes):
""" Returns whether we have an active channel with this node on another device, using same local node id. """
@@ -3186,7 +3186,7 @@ class LNWallet(LNWorker):
with self.lock:
self._channel_backups[bfh(channel_id)] = cb
util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
self.lnwatcher.add_channel(cb)
def save_forwarding_failure(
self, payment_key:str, *,

View File

@@ -75,8 +75,20 @@ class WatchTower(LNWatcher):
await super().stop()
await self.adb.stop()
def add_channel(self, outpoint: str, address: str) -> None:
callback = lambda: self.check_onchain_situation(address, outpoint)
self.add_callback(address, callback)
@log_exceptions
async def trigger_callbacks(self):
if not self.adb.synchronizer:
self.logger.info("synchronizer not set yet")
return
for address, callback in list(self.callbacks.items()):
await callback()
def diagnostic_name(self):
return "local_tower"
return "watchtower"
@log_exceptions
async def start_watching(self):
@@ -85,6 +97,27 @@ class WatchTower(LNWatcher):
for outpoint, address in random_shuffled_copy(lst):
self.add_channel(outpoint, address)
async def check_onchain_situation(self, address, funding_outpoint):
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
# inspect_tx_candidate might have added new addresses, in which case we return early
funding_txid = funding_outpoint.split(':')[0]
funding_height = self.adb.get_tx_height(funding_txid)
closing_txid = self.get_spender(funding_outpoint)
closing_height = self.adb.get_tx_height(closing_txid)
if closing_txid:
closing_tx = self.adb.get_transaction(closing_txid)
if closing_tx:
keep_watching = await self.sweep_commitment_transaction(funding_outpoint, closing_tx)
else:
self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...")
keep_watching = True
else:
keep_watching = True
if not keep_watching:
await self.unwatch_channel(address, funding_outpoint)
def inspect_tx_candidate(self, outpoint, n: int) -> Dict[str, str]:
"""
returns a dict of spenders for a transaction of interest.
@@ -188,8 +221,6 @@ class WatchTower(LNWatcher):
await self.sweepstore.remove_sweep_tx(funding_outpoint)
await self.sweepstore.remove_channel(funding_outpoint)
async def update_channel_state(self, *args, **kwargs):
pass

View File

@@ -179,6 +179,7 @@ class SwapManager(Logger):
self.wallet = wallet
self.config = wallet.config
self.lnworker = lnworker
self.lnwatcher = self.lnworker.lnwatcher
self.config = wallet.config
self.taskgroup = OldTaskGroup()
self.dummy_address = DummyAddress.SWAP
@@ -207,7 +208,6 @@ class SwapManager(Logger):
return
self.logger.info('start_network: starting main loop')
self.network = network
self.lnwatcher = self.lnworker.lnwatcher
for k, swap in self.swaps.items():
if swap.is_redeemed:
continue
@@ -321,8 +321,7 @@ class SwapManager(Logger):
if sha256(preimage) == swap.payment_hash:
return preimage
@log_exceptions
async def _claim_swap(self, swap: SwapData) -> None:
def _claim_swap(self, swap: SwapData) -> None:
assert self.network
assert self.lnwatcher
if not self.lnwatcher.adb.is_up_to_date():

View File

@@ -108,8 +108,12 @@ class MockBlockchain:
class MockADB:
def __init__(self):
self._blockchain = MockBlockchain()
def add_transaction(self, tx):
pass
def get_local_height(self):
return self._blockchain.height()
class MockWallet:
receive_requests = {}