1
0

Make lnwatcher not async

This fixes offline history not having the proper labels
This commit is contained in:
ThomasV
2025-02-07 09:52:03 +01:00
parent 42b072aca8
commit fbebe7de1a
5 changed files with 117 additions and 87 deletions

View File

@@ -5,7 +5,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from enum import IntEnum, auto from enum import IntEnum, auto
from .util import log_exceptions, ignore_exceptions, TxMinedInfo, BelowDustLimit from .util import log_exceptions, TxMinedInfo, BelowDustLimit
from .util import EventListener, event_listener from .util import EventListener, event_listener
from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_FUTURE from .address_synchronizer import AddressSynchronizer, TX_HEIGHT_LOCAL, TX_HEIGHT_UNCONF_PARENT, TX_HEIGHT_UNCONFIRMED, TX_HEIGHT_FUTURE
from .transaction import Transaction, TxOutpoint from .transaction import Transaction, TxOutpoint
@@ -17,6 +17,8 @@ if TYPE_CHECKING:
from .lnsweep import SweepInfo from .lnsweep import SweepInfo
from .lnworker import LNWallet from .lnworker import LNWallet
from .lnchannel import AbstractChannel from .lnchannel import AbstractChannel
from .simple_config import SimpleConfig
class TxMinedDepth(IntEnum): class TxMinedDepth(IntEnum):
""" IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """ """ IntEnum because we call min() in get_deepest_tx_mined_depth_for_txids """
@@ -30,30 +32,27 @@ class LNWatcher(Logger, EventListener):
LOGGING_SHORTCUT = 'W' LOGGING_SHORTCUT = 'W'
def __init__(self, adb: 'AddressSynchronizer', network: 'Network'): def __init__(self, adb: 'AddressSynchronizer', config: 'SimpleConfig'):
Logger.__init__(self) Logger.__init__(self)
self.adb = adb self.adb = adb
self.config = network.config self.config = config
self.callbacks = {} # address -> lambda: coroutine self.callbacks = {} # address -> lambda: coroutine
self.network = network self.network = None
self.register_callbacks() self.register_callbacks()
# status gets populated when we run # status gets populated when we run
self.channel_status = {} self.channel_status = {}
def start_network(self, network: 'Network'):
self.network = network
async def stop(self): async def stop(self):
self.unregister_callbacks() self.unregister_callbacks()
def get_channel_status(self, outpoint): def get_channel_status(self, outpoint):
return self.channel_status.get(outpoint, 'unknown') return self.channel_status.get(outpoint, 'unknown')
def add_channel(self, outpoint: str, address: str) -> None: def unwatch_channel(self, address, funding_outpoint):
assert isinstance(outpoint, str)
assert isinstance(address, str)
cb = lambda: self.check_onchain_situation(address, outpoint)
self.add_callback(address, cb)
async def unwatch_channel(self, address, funding_outpoint):
self.logger.info(f'unwatching {funding_outpoint}') self.logger.info(f'unwatching {funding_outpoint}')
self.remove_callback(address) self.remove_callback(address)
@@ -93,46 +92,7 @@ class LNWatcher(Logger, EventListener):
self.logger.info("synchronizer not set yet") self.logger.info("synchronizer not set yet")
return return
for address, callback in list(self.callbacks.items()): for address, callback in list(self.callbacks.items()):
await callback() callback()
async def check_onchain_situation(self, address, funding_outpoint):
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
# inspect_tx_candidate might have added new addresses, in which case we return early
if not self.adb.is_up_to_date():
return
funding_txid = funding_outpoint.split(':')[0]
funding_height = self.adb.get_tx_height(funding_txid)
closing_txid = self.get_spender(funding_outpoint)
closing_height = self.adb.get_tx_height(closing_txid)
if closing_txid:
closing_tx = self.adb.get_transaction(closing_txid)
if closing_tx:
keep_watching = await self.sweep_commitment_transaction(funding_outpoint, closing_tx)
else:
self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...")
keep_watching = True
else:
keep_watching = True
await self.update_channel_state(
funding_outpoint=funding_outpoint,
funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
if not keep_watching:
await self.unwatch_channel(address, funding_outpoint)
async def sweep_commitment_transaction(self, funding_outpoint: str, closing_tx: Transaction) -> bool:
raise NotImplementedError() # implemented by subclasses
async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
raise NotImplementedError() # implemented by subclasses
def get_spender(self, outpoint) -> str: def get_spender(self, outpoint) -> str:
""" """
@@ -181,10 +141,45 @@ class LNWatcher(Logger, EventListener):
class LNWalletWatcher(LNWatcher): class LNWalletWatcher(LNWatcher):
def __init__(self, lnworker: 'LNWallet', network: 'Network'): def __init__(self, lnworker: 'LNWallet'):
self.network = network
self.lnworker = lnworker self.lnworker = lnworker
LNWatcher.__init__(self, lnworker.wallet.adb, network) LNWatcher.__init__(self, lnworker.wallet.adb, lnworker.config)
def add_channel(self, chan: 'AbstractChannel') -> None:
outpoint = chan.funding_outpoint.to_str()
address = chan.get_funding_address()
callback = lambda: self.check_onchain_situation(address, outpoint)
callback() # run once, for side effects
if chan.need_to_subscribe():
self.add_callback(address, callback)
def check_onchain_situation(self, address, funding_outpoint):
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
# inspect_tx_candidate might have added new addresses, in which case we return early
funding_txid = funding_outpoint.split(':')[0]
funding_height = self.adb.get_tx_height(funding_txid)
closing_txid = self.get_spender(funding_outpoint)
closing_height = self.adb.get_tx_height(closing_txid)
if closing_txid:
closing_tx = self.adb.get_transaction(closing_txid)
if closing_tx:
keep_watching = self.sweep_commitment_transaction(funding_outpoint, closing_tx)
else:
self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...")
keep_watching = True
else:
keep_watching = True
self.update_channel_state(
funding_outpoint=funding_outpoint,
funding_txid=funding_txid,
funding_height=funding_height,
closing_txid=closing_txid,
closing_height=closing_height,
keep_watching=keep_watching)
if not keep_watching:
self.unwatch_channel(address, funding_outpoint)
@event_listener @event_listener
async def on_event_blockchain_updated(self, *args): async def on_event_blockchain_updated(self, *args):
@@ -199,11 +194,9 @@ class LNWalletWatcher(LNWatcher):
def diagnostic_name(self): def diagnostic_name(self):
return f"{self.lnworker.wallet.diagnostic_name()}-LNW" return f"{self.lnworker.wallet.diagnostic_name()}-LNW"
@ignore_exceptions def update_channel_state(self, *, funding_outpoint: str, funding_txid: str,
@log_exceptions funding_height: TxMinedInfo, closing_txid: str,
async def update_channel_state(self, *, funding_outpoint: str, funding_txid: str, closing_height: TxMinedInfo, keep_watching: bool) -> None:
funding_height: TxMinedInfo, closing_txid: str,
closing_height: TxMinedInfo, keep_watching: bool) -> None:
chan = self.lnworker.channel_by_txo(funding_outpoint) chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan: if not chan:
return return
@@ -213,15 +206,18 @@ class LNWalletWatcher(LNWatcher):
closing_txid=closing_txid, closing_txid=closing_txid,
closing_height=closing_height, closing_height=closing_height,
keep_watching=keep_watching) keep_watching=keep_watching)
await self.lnworker.handle_onchain_state(chan) self.lnworker.handle_onchain_state(chan)
@log_exceptions def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool:
async def sweep_commitment_transaction(self, funding_outpoint, closing_tx) -> bool:
"""This function is called when a channel was closed. In this case """This function is called when a channel was closed. In this case
we need to check for redeemable outputs of the commitment transaction we need to check for redeemable outputs of the commitment transaction
or spenders down the line (HTLC-timeout/success transactions). or spenders down the line (HTLC-timeout/success transactions).
Returns whether we should continue to monitor.""" Returns whether we should continue to monitor.
Side-effécts:
- sets defaults labels
"""
chan = self.lnworker.channel_by_txo(funding_outpoint) chan = self.lnworker.channel_by_txo(funding_outpoint)
if not chan: if not chan:
return False return False

View File

@@ -842,7 +842,7 @@ class LNWallet(LNWorker):
if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP: if self.config.EXPERIMENTAL_LN_FORWARD_PAYMENTS and self.config.LIGHTNING_USE_GOSSIP:
features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch features |= LnFeatures.GOSSIP_QUERIES_OPT # signal we have gossip to fetch
LNWorker.__init__(self, self.node_keypair, features, config=self.config) LNWorker.__init__(self, self.node_keypair, features, config=self.config)
self.lnwatcher = None self.lnwatcher = LNWalletWatcher(self)
self.lnrater: LNRater = None self.lnrater: LNRater = None
self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage self.preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
@@ -890,6 +890,13 @@ class LNWallet(LNWorker):
self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY) self.nostr_keypair = generate_keypair(BIP32Node.from_xkey(xprv), LnKeyFamily.NOSTR_KEY)
self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self) self.swap_manager = SwapManager(wallet=self.wallet, lnworker=self)
self.onion_message_manager = OnionMessageManager(self) self.onion_message_manager = OnionMessageManager(self)
self.subscribe_to_channels()
def subscribe_to_channels(self):
for chan in self.channels.values():
self.lnwatcher.add_channel(chan)
for cb in self.channel_backups.values():
self.lnwatcher.add_channel(cb)
def has_deterministic_node_id(self) -> bool: def has_deterministic_node_id(self) -> bool:
return bool(self.db.get('lightning_xprv')) return bool(self.db.get('lightning_xprv'))
@@ -970,18 +977,11 @@ class LNWallet(LNWorker):
def start_network(self, network: 'Network'): def start_network(self, network: 'Network'):
super().start_network(network) super().start_network(network)
self.lnwatcher = LNWalletWatcher(self, network) self.lnwatcher.start_network(network)
self.swap_manager.start_network(network) self.swap_manager.start_network(network)
self.lnrater = LNRater(self, network) self.lnrater = LNRater(self, network)
self.onion_message_manager.start_network(network=network) self.onion_message_manager.start_network(network=network)
for chan in self.channels.values():
if chan.need_to_subscribe():
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address())
for cb in self.channel_backups.values():
if cb.need_to_subscribe():
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address())
for coro in [ for coro in [
self.maybe_listen(), self.maybe_listen(),
self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified self.lnwatcher.trigger_callbacks(), # shortcut (don't block) if funding tx locked and verified
@@ -1193,15 +1193,15 @@ class LNWallet(LNWorker):
if chan.funding_outpoint.to_str() == txo: if chan.funding_outpoint.to_str() == txo:
return chan return chan
async def handle_onchain_state(self, chan: Channel): def handle_onchain_state(self, chan: Channel):
if type(chan) is ChannelBackup: if type(chan) is ChannelBackup:
util.trigger_callback('channel', self.wallet, chan) util.trigger_callback('channel', self.wallet, chan)
return return
if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN) if (chan.get_state() in (ChannelState.OPEN, ChannelState.SHUTDOWN)
and chan.should_be_closed_due_to_expiring_htlcs(self.network.get_local_height())): and chan.should_be_closed_due_to_expiring_htlcs(self.wallet.adb.get_local_height())):
self.logger.info(f"force-closing due to expiring htlcs") self.logger.info(f"force-closing due to expiring htlcs")
await self.schedule_force_closing(chan.channel_id) asyncio.ensure_future(self.schedule_force_closing(chan.channel_id))
elif chan.get_state() == ChannelState.FUNDED: elif chan.get_state() == ChannelState.FUNDED:
peer = self._peers.get(chan.node_id) peer = self._peers.get(chan.node_id)
@@ -1220,7 +1220,7 @@ class LNWallet(LNWorker):
height = self.lnwatcher.adb.get_tx_height(txid).height height = self.lnwatcher.adb.get_tx_height(txid).height
if height == TX_HEIGHT_LOCAL: if height == TX_HEIGHT_LOCAL:
self.logger.info('REBROADCASTING CLOSING TX') self.logger.info('REBROADCASTING CLOSING TX')
await self.network.try_broadcasting(force_close_tx, 'force-close') asyncio.ensure_future(self.network.try_broadcasting(force_close_tx, 'force-close'))
def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]: def get_peer_by_static_jit_scid_alias(self, scid_alias: bytes) -> Optional[Peer]:
for nodeid, peer in self.peers.items(): for nodeid, peer in self.peers.items():
@@ -1363,7 +1363,7 @@ class LNWallet(LNWorker):
def add_channel(self, chan: Channel): def add_channel(self, chan: Channel):
with self.lock: with self.lock:
self._channels[chan.channel_id] = chan self._channels[chan.channel_id] = chan
self.lnwatcher.add_channel(chan.funding_outpoint.to_str(), chan.get_funding_address()) self.lnwatcher.add_channel(chan)
def add_new_channel(self, chan: Channel): def add_new_channel(self, chan: Channel):
self.add_channel(chan) self.add_channel(chan)
@@ -1953,7 +1953,7 @@ class LNWallet(LNWorker):
We first try to conduct the payment over a single channel. If that fails We first try to conduct the payment over a single channel. If that fails
and mpp is supported by the receiver, we will split the payment.""" and mpp is supported by the receiver, we will split the payment."""
trampoline_features = LnFeatures.VAR_ONION_OPT trampoline_features = LnFeatures.VAR_ONION_OPT
local_height = self.network.get_local_height() local_height = self.wallet.adb.get_local_height()
fee_related_error = None # type: Optional[FeeBudgetExceeded] fee_related_error = None # type: Optional[FeeBudgetExceeded]
if channels: if channels:
my_active_channels = channels my_active_channels = channels
@@ -3069,7 +3069,7 @@ class LNWallet(LNWorker):
self.wallet.set_reserved_addresses_for_chan(cb, reserved=True) self.wallet.set_reserved_addresses_for_chan(cb, reserved=True)
self.wallet.save_db() self.wallet.save_db()
util.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) self.lnwatcher.add_channel(cb)
def has_conflicting_backup_with(self, remote_node_id: bytes): def has_conflicting_backup_with(self, remote_node_id: bytes):
""" Returns whether we have an active channel with this node on another device, using same local node id. """ """ Returns whether we have an active channel with this node on another device, using same local node id. """
@@ -3186,7 +3186,7 @@ class LNWallet(LNWorker):
with self.lock: with self.lock:
self._channel_backups[bfh(channel_id)] = cb self._channel_backups[bfh(channel_id)] = cb
util.trigger_callback('channels_updated', self.wallet) util.trigger_callback('channels_updated', self.wallet)
self.lnwatcher.add_channel(cb.funding_outpoint.to_str(), cb.get_funding_address()) self.lnwatcher.add_channel(cb)
def save_forwarding_failure( def save_forwarding_failure(
self, payment_key:str, *, self, payment_key:str, *,

View File

@@ -75,8 +75,20 @@ class WatchTower(LNWatcher):
await super().stop() await super().stop()
await self.adb.stop() await self.adb.stop()
def add_channel(self, outpoint: str, address: str) -> None:
callback = lambda: self.check_onchain_situation(address, outpoint)
self.add_callback(address, callback)
@log_exceptions
async def trigger_callbacks(self):
if not self.adb.synchronizer:
self.logger.info("synchronizer not set yet")
return
for address, callback in list(self.callbacks.items()):
await callback()
def diagnostic_name(self): def diagnostic_name(self):
return "local_tower" return "watchtower"
@log_exceptions @log_exceptions
async def start_watching(self): async def start_watching(self):
@@ -85,6 +97,27 @@ class WatchTower(LNWatcher):
for outpoint, address in random_shuffled_copy(lst): for outpoint, address in random_shuffled_copy(lst):
self.add_channel(outpoint, address) self.add_channel(outpoint, address)
async def check_onchain_situation(self, address, funding_outpoint):
# early return if address has not been added yet
if not self.adb.is_mine(address):
return
# inspect_tx_candidate might have added new addresses, in which case we return early
funding_txid = funding_outpoint.split(':')[0]
funding_height = self.adb.get_tx_height(funding_txid)
closing_txid = self.get_spender(funding_outpoint)
closing_height = self.adb.get_tx_height(closing_txid)
if closing_txid:
closing_tx = self.adb.get_transaction(closing_txid)
if closing_tx:
keep_watching = await self.sweep_commitment_transaction(funding_outpoint, closing_tx)
else:
self.logger.info(f"channel {funding_outpoint} closed by {closing_txid}. still waiting for tx itself...")
keep_watching = True
else:
keep_watching = True
if not keep_watching:
await self.unwatch_channel(address, funding_outpoint)
def inspect_tx_candidate(self, outpoint, n: int) -> Dict[str, str]: def inspect_tx_candidate(self, outpoint, n: int) -> Dict[str, str]:
""" """
returns a dict of spenders for a transaction of interest. returns a dict of spenders for a transaction of interest.
@@ -188,8 +221,6 @@ class WatchTower(LNWatcher):
await self.sweepstore.remove_sweep_tx(funding_outpoint) await self.sweepstore.remove_sweep_tx(funding_outpoint)
await self.sweepstore.remove_channel(funding_outpoint) await self.sweepstore.remove_channel(funding_outpoint)
async def update_channel_state(self, *args, **kwargs):
pass

View File

@@ -179,6 +179,7 @@ class SwapManager(Logger):
self.wallet = wallet self.wallet = wallet
self.config = wallet.config self.config = wallet.config
self.lnworker = lnworker self.lnworker = lnworker
self.lnwatcher = self.lnworker.lnwatcher
self.config = wallet.config self.config = wallet.config
self.taskgroup = OldTaskGroup() self.taskgroup = OldTaskGroup()
self.dummy_address = DummyAddress.SWAP self.dummy_address = DummyAddress.SWAP
@@ -207,7 +208,6 @@ class SwapManager(Logger):
return return
self.logger.info('start_network: starting main loop') self.logger.info('start_network: starting main loop')
self.network = network self.network = network
self.lnwatcher = self.lnworker.lnwatcher
for k, swap in self.swaps.items(): for k, swap in self.swaps.items():
if swap.is_redeemed: if swap.is_redeemed:
continue continue
@@ -321,8 +321,7 @@ class SwapManager(Logger):
if sha256(preimage) == swap.payment_hash: if sha256(preimage) == swap.payment_hash:
return preimage return preimage
@log_exceptions def _claim_swap(self, swap: SwapData) -> None:
async def _claim_swap(self, swap: SwapData) -> None:
assert self.network assert self.network
assert self.lnwatcher assert self.lnwatcher
if not self.lnwatcher.adb.is_up_to_date(): if not self.lnwatcher.adb.is_up_to_date():

View File

@@ -108,8 +108,12 @@ class MockBlockchain:
class MockADB: class MockADB:
def __init__(self):
self._blockchain = MockBlockchain()
def add_transaction(self, tx): def add_transaction(self, tx):
pass pass
def get_local_height(self):
return self._blockchain.height()
class MockWallet: class MockWallet:
receive_requests = {} receive_requests = {}